import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from jax import lax
from pfjax import sde as sde
# --- main functions -----------------------------------------------------------
class PGNETModel(sde.SDEModel):
    r"""
    Prokaryotic auto-regulatory gene network model.

    The base model consists of the chemical reactions

    ::

        DNA + P2 --> DNA_P2
        DNA_P2   --> DNA + P2
        DNA      --> DNA + RNA
        RNA      --> RNA + P
        P + P    --> P2
        P2       --> P + P
        RNA      --> 0
        P        --> 0

    where the ``i``-th reaction in this list (counting from 0) has rate
    constant ``theta_i``, giving ``theta = (theta0, ..., theta7)``.

    Following Golightly & Wilkinson (2005), the reaction network is
    approximated by an SDE.  The conservation law ``DNA + DNA_P2 = K``
    eliminates ``DNA_P2``, so the SDE state on the original scale is
    ``X = (RNA, P, P2, DNA)``.

    To allow for unconstrained variables, the model is formulated on the
    log scale, ``x = log(X)``; the drift ``mu`` and diffusion ``Sigma`` are
    the original-scale ones transformed by Ito's lemma.  With ``m = n_res``
    latent sub-steps per observation interval, ``x_n`` denoting the state at
    time ``n dt/m``, the discretized model is

    ::

        x_n = x_{n-1} + mu(x_{n-1}) dt/m + Sigma(x_{n-1})^{1/2} (dt/m)^{1/2} z_n,
        z_n ~iid N(0, I),
        y_t ~ N( exp(x_{mt}), diag(tau^2) ).

    - Model parameters: ``theta = (theta0, ..., theta7, tau0, ..., tau3)``.
    - Global constants: ``dt``, ``n_res`` (i.e., ``m``) and ``K``.
    - State dimensions: ``n_state = (n_res, 4)``.
    - Measurement dimensions: ``n_meas = 4``.

    Args:
        dt: SDE interobservation time.
        n_res: SDE resolution number. There are ``n_res`` latent variables per
            observation, equally spaced with interobservation time ``dt/n_res``.
        bootstrap (bool): If ``True``, ``pf_step`` uses the base class's
            bootstrap particle filter step; otherwise it uses a bridge proposal.
        K: Total number of DNA copies, ``DNA + DNA_P2``.  Defaults to 10.
    """

    def __init__(self, dt, n_res, bootstrap=True, K=10):
        # creates "private" variables self._dt and self._n_res;
        # diff_diag=False because Sigma has off-diagonal entries (see _diff).
        super().__init__(dt, n_res, diff_diag=False)
        self._n_state = (self._n_res, 4)
        # conservation constant: DNA + DNA_P2 = K
        self._K = K
        self._bootstrap = bootstrap

    def _drift(self, x, theta):
        """
        Calculate the SDE drift on the original scale.

        Args:
            x: State on the original scale, ordered as ``(RNA, P, P2, DNA)``.
            theta: Parameter value; only the rate constants ``theta[0:8]`` are used.

        Returns:
            The drift vector ``(mu_RNA, mu_P, mu_P2, mu_DNA)``.
        """
        rna, p, p2, dna = x[0], x[1], x[2], x[3]
        k = self._K
        mu_rna = theta[2]*dna - theta[6]*rna
        mu_p = 2*theta[5]*p2 - theta[7]*p + \
            theta[3]*rna - theta[4]*p*(p - 1)
        mu_p2 = theta[1]*(k - dna) - theta[0]*dna*p2 - \
            theta[5]*p2 + 0.5*theta[4]*p*(p - 1)
        mu_dna = theta[1]*(k - dna) - theta[0]*dna*p2
        return jnp.stack([mu_rna, mu_p, mu_p2, mu_dna])

    def _diff(self, x, theta):
        """
        Calculate the SDE diffusion matrix on the original scale.

        Args:
            x: State on the original scale, ordered as ``(RNA, P, P2, DNA)``.
            theta: Parameter value; only the rate constants ``theta[0:8]`` are used.

        Returns:
            The symmetric 4x4 diffusion matrix.  RNA is uncorrelated with the
            other species; P-P2 and P2-DNA have nonzero covariance.
        """
        rna, p, p2, dna = x[0], x[1], x[2], x[3]
        # Combined propensity of the binding/unbinding reactions 0 and 1.
        a = theta[0]*dna*p2 + theta[1]*(self._K - dna)
        # NOTE(review): P(P-1) is negative for 0 < P < 1, which can make Sigma
        # indefinite.  The previous code also computed a clamped version
        # (jnp.where(0 < pp, pp, 0)) but immediately overwrote it, so the
        # unclamped value is what is used here; decide whether clamping is intended.
        p_pairs = p*(p - 1)
        s11 = theta[2]*dna + theta[6]*rna
        s22 = theta[7]*p + 4*theta[5]*p2 + \
            theta[3]*rna + 2*theta[4]*p_pairs
        s23 = -2*theta[5]*p2 - theta[4]*p_pairs
        s33 = a + theta[5]*p2 + 0.5*theta[4]*p_pairs
        return jnp.array([[s11, 0., 0., 0.],
                          [0., s22, s23, 0.],
                          [0., s23, s33, a],
                          [0., 0., a, a]])

    def drift(self, x, theta):
        """
        Calculate the SDE drift function on the log scale.

        By Ito's lemma, if ``X`` has drift ``mu`` and diffusion ``Sigma`` then
        ``log(X_i)`` has drift ``mu_i / X_i - Sigma_ii / (2 X_i^2)``.

        Args:
            x: State on the log scale.
            theta: Parameter value.

        Returns:
            The drift vector on the log scale.
        """
        x = jnp.exp(x)
        mu = self._drift(x, theta)
        Sigma = self._diff(x, theta)
        f_p = 1 / x       # d/dX log(X)
        f_pp = -f_p / x   # d^2/dX^2 log(X) = -1/X^2
        return f_p * mu + 0.5 * f_pp * jnp.diag(Sigma)

    def diff(self, x, theta):
        """
        Calculate the SDE diffusion function on the log scale.

        By Ito's lemma the diffusion of ``log(X)`` is
        ``Sigma_ij / (X_i X_j)``.

        Args:
            x: State on the log scale.
            theta: Parameter value.

        Returns:
            The 4x4 diffusion matrix on the log scale.
        """
        x = jnp.exp(x)
        f_p = 1 / x
        return jnp.outer(f_p, f_p) * self._diff(x, theta)

    def meas_lpdf(self, y_curr, x_curr, theta):
        """
        Log-density of ``p(y_curr | x_curr, theta)``.

        The measurement is Gaussian around the original-scale state at the
        last sub-step, ``exp(x_curr[-1])``, with independent components and
        standard deviations ``tau = theta[8:12]``.

        Args:
            y_curr: Measurement variable at current time ``t``.
            x_curr: State variable at current time ``t``.
            theta: Parameter value.

        Returns:
            The log-density of ``p(y_curr | x_curr, theta)``.
        """
        tau = theta[8:12]
        return jnp.sum(
            jsp.stats.norm.logpdf(y_curr, loc=jnp.exp(x_curr[-1]), scale=tau)
        )

    def meas_sample(self, key, x_curr, theta):
        """
        Sample from ``p(y_curr | x_curr, theta)``.

        Args:
            key: PRNG key.
            x_curr: State variable at current time ``t``.
            theta: Parameter value.

        Returns:
            Sample of the measurement variable at current time ``t``:
            ``y_curr ~ p(y_curr | x_curr, theta)``.
        """
        tau = theta[8:12]
        return jnp.exp(x_curr[-1]) + tau * random.normal(key, (self._n_state[1],))

    def pf_init(self, key, y_init, theta):
        """
        Particle filter calculation for ``x_init``.

        Samples from an importance sampling proposal distribution
        ::

            x_init ~ q(x_init) = q(x_init | y_init, theta)

        and calculates the log weight
        ::

            logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)

        The proposal draws each component of ``X_init = exp(x_init)``
        independently from ``N(y_init, tau^2)`` truncated to ``(0, inf)``.
        This is the posterior of ``X_init`` under the measurement model with a
        flat (improper) prior on ``X_init``.  The returned weight corresponds to
        that prior, which on the log scale is ``p(x_init | theta)`` proportional
        to ``exp(x_init)``.  With it, the Gaussian densities and the Jacobian of
        the log transform cancel, leaving only the truncation normalizing
        constants
        ::

            logw = sum_i log Phi(y_init_i / tau_i),

        which does not depend on ``x_init``.

        Args:
            key: PRNG key.
            y_init: Measurement variable at initial time ``t = 0``.
            theta: Parameter value.

        Returns:
            Tuple:

            - x_init: Array of shape ``(n_res, 4)`` whose last row is the
              sample from the proposal; earlier rows are zero placeholders.
            - logw: The log-weight of ``x_init``.
        """
        tau = theta[8:12]
        _, subkey = random.split(key)
        # lower=-y_init/tau keeps y_init + tau * z > 0, so the log is finite.
        z = random.truncated_normal(
            subkey,
            lower=-y_init/tau,
            upper=jnp.inf,
            shape=(self._n_state[1],)
        )
        x_last = jnp.log(y_init + tau * z)
        # log P(z > -y_init/tau) = log Phi(y_init/tau), summed over components.
        logw = jnp.sum(jsp.stats.norm.logcdf(y_init/tau))
        x_init = jnp.zeros((self._n_res,) + x_last.shape).at[-1].set(x_last)
        return x_init, logw

    def pf_step(self, key, x_prev, y_curr, theta):
        """
        Get particles at subsequent steps.

        Args:
            key: PRNG key.
            x_prev: State variable at previous time ``t-1``.
            y_curr: Measurement variable at current time ``t``.
            theta: Parameter value.

        Returns:
            Tuple:

            - x_curr: Sample of the state variable at current time ``t``:
              ``x_curr ~ q(x_curr)``.
            - logw: The log-weight of ``x_curr``.
        """
        if self._bootstrap:
            x_curr, logw = super().pf_step(key, x_prev, y_curr, theta)
        else:
            # log(y_curr) is passed as a pseudo-observation of the log-scale
            # state with identity matrix and variance (tau / y_curr)^2, the
            # first-order (delta-method) variance of log(y).
            # NOTE(review): assumes bridge_prop treats these as a linear-Gaussian
            # observation model -- confirm against sde.SDEModel.  Also,
            # log(y_curr) is NaN for non-positive measurements.
            omega = (theta[8:12] / y_curr)**2
            x_curr, logw = self.bridge_prop(
                key, x_prev, y_curr, theta,
                jnp.log(y_curr), jnp.eye(self._n_state[1]), jnp.diag(omega)
            )
        return x_curr, logw