"""
Generic methods for SDEs.
**Notes:**
- `euler_{sim/lpdf}_{diag/var}()` are designed to be used independently of the SDE base class. If we abandon this requirement then we can avoid quite a bit of code duplication by having these functions do one data point only and putting the `vmap`/`scan` constructs into the `state_{sample/lpdf}` methods.
"""
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from jax import lax
from pfjax import mvn_bridge as mb
from pfjax.utils import *
[docs]def euler_sim_diag(key, x, dt, drift, diff, theta):
"""
Simulate SDE with diagonal diffusion using Euler-Maruyama discretization.
Args:
key: PRNG key.
x: Initial value of the SDE. A vector of size `n_dims`.
dt: Interobservation time.
drift: Drift function having signature `drift(x, theta)` and returning a vector of size `n_dims`.
diff: Diffusion function having signature `diff(x, theta)` and returning a vector of size `n_dims`.
theta: Parameter value.
Returns:
Simulated SDE values. A vector of size `n_dims`.
"""
dr = x + drift(x, theta) * dt
df = diff(x, theta) * jnp.sqrt(dt)
return dr + df * random.normal(key, (x.shape[0],))
[docs]def euler_lpdf_diag(x_curr, x_prev, dt, drift, diff, theta):
"""
Calculate the log PDF of observations from an SDE with diagonal diffusion using the Euler-Maruyama discretization.
Args:
x_curr: Current SDE observations. A vector of size `n_dims`.
x_prev: Previous SDE observations. A vector of size `n_dims`.
dt: Interobservation time.
drift: Drift function having signature `drift(x, theta)` and returning a vector of size `n_dims`.
diff: Diffusion function having signature `diff(x, theta)` and returning a vector of size `n_dims`.
theta: Parameter value.
Returns:
The log-density of the SDE observations.
"""
return jsp.stats.norm.logpdf(
x=x_curr,
loc=x_prev + drift(x_prev, theta) * dt,
scale=diff(x_prev, theta) * jnp.sqrt(dt)
)
[docs]def euler_sim_var(key, x, dt, drift, diff, theta):
"""
Simulate SDE with dense diffusion using Euler-Maruyama discretization.
Args:
key: PRNG key.
x: Initial value of the SDE. A vector of size `n_dims`.
dt: Interobservation time.
drift: Drift function having signature `drift(x, theta)` and returning a vector of size `n_dims`.
diff: Diffusion function having signature `diff(x, theta)` and returning a vector of size `n_dims`.
theta: Parameter value.
Returns:
Simulated SDE values. A vector of size `n_dims`.
"""
# dr = x + drift(x, theta) * dt
# chol_Sigma = jnp.linalg.cholesky(diff(x, theta))
# df = chol_Sigma * jnp.sqrt(dt)
# return dr + jnp.matmul(df, random.normal(key, (x.shape[0],)))
return jax.random.multivariate_normal(
key,
mean=x + drift(x, theta) * dt,
cov=diff(x, theta) * dt
)
[docs]def euler_lpdf_var(x_curr, x_prev, dt, drift, diff, theta):
"""
Calculate the log PDF of observations from an SDE with dense diffusion using the Euler-Maruyama discretization.
Args:
x_curr: Current SDE observations. A vector of size `n_dims`.
x_prev: Previous SDE observations. A vector of size `n_dims`.
dt: Interobservation time.
drift: Drift function having signature `drift(x, theta)` and returning a vector of size `n_dims`.
diff: Diffusion function having signature `diff(x, theta)` and returning a vector of size `n_dims`.
theta: Parameter value.
Returns:
The log-density of the SDE observations.
"""
return jsp.stats.multivariate_normal.logpdf(
x=x_curr,
mean=x_prev + drift(x_prev, theta) * dt,
cov=diff(x_prev, theta) * dt
)
[docs]class SDEModel(object):
"""
Base class for SDE models.
This class sets up a PF model with methods `state_lpdf()`, `state_sim()`, and `pf_step()` automatically determined from user-specified SDE drift and diffusion functions.
For computational efficiency, the user can also specify whether or not the diffusion is diagonal. This will set up methods `euler_sim()` and `euler_lpdf()` supplied at instantiation time from either `euler_{sim/lpdf}_diag()` or `euler_{sim/lpdf}_var()`, with arguments identical to those of the free functions except `drift` and `diff`, which are supplied by `self.drift()` and `self.diff()`.
For `pf_step()`, a bootstrap filter is assumed by default, for which the user needs to specify `meas_lpdf()`.
**Notes:**
- Currently contains `_state_sample_for()` and `_state_lpdf_for()` for testing purposes. May want to move these elsewhere at some point to obfuscate from users...
- Should derive from `pf.BaseModel`.
"""
def __init__(self, dt, n_res, diff_diag):
r"""
Class constructor.
Args:
dt: SDE interobservation time.
n_res: SDE resolution number. There are `n_res` latent variables per observation, equally spaced with interobservation time `dt/n_res`.
diff_diag: Whether or not the diffusion matrix is assumed to be diagonal.
"""
self._dt = dt
self._n_res = n_res
# the following members are only used for testing
self._diff_diag = diff_diag
# also have self._n_state for testing,
# defined in the derived model class
if diff_diag:
# instantiate methods depending on whether the diffusion is diagonal
def euler_sim(self, key, x, dt, theta):
return euler_sim_diag(key, x, dt,
self.drift, self.diff, theta)
def euler_lpdf(self, x_curr, x_prev, dt, theta):
return euler_lpdf_diag(x_curr, x_prev, dt,
self.drift, self.diff, theta)
def diff_full(self, x, theta):
return jnp.diag(self.diff(x, theta))
setattr(self.__class__, 'euler_sim', euler_sim)
setattr(self.__class__, 'euler_lpdf', euler_lpdf)
setattr(self.__class__, 'diff_full', diff_full)
else:
def euler_sim(self, key, x, dt, theta):
return euler_sim_var(key, x, dt,
self.drift, self.diff, theta)
def euler_lpdf(self, x_curr, x_prev, dt, theta):
return euler_lpdf_var(x_curr, x_prev, dt,
self.drift, self.diff, theta)
def diff_full(self, x, theta):
return self.diff(x, theta)
setattr(self.__class__, 'euler_sim', euler_sim)
setattr(self.__class__, 'euler_lpdf', euler_lpdf)
setattr(self.__class__, 'diff_full', diff_full)
[docs] def is_valid(self, x, theta):
"""
Checks whether SDE observations are valid.
Args:
x: SDE variables. A vector of size `n_dims`.
theta: Parameter value.
Returns:
Whether or not `x` is valid SDE data.
"""
return jnp.array(True)
[docs] def is_valid_state(self, x, theta):
"""
Checks whether SDE latent variables are valid.
Applies `is_valid()` to each of the `n_res` SDE variables, and also checks for nans.
Args:
x: State variable of size `n_res x n_dims`.
theta: Parameter value.
Returns:
Whether or not all `n_res` latent variables are valid.
"""
valid_x = jax.vmap(self.is_valid, in_axes=(0, None))(x, theta)
nan_x = jnp.any(jnp.isnan(x), axis=1)
return jnp.alltrue(valid_x, where=~nan_x) & jnp.alltrue(~nan_x)
[docs] def state_lpdf(self, x_curr, x_prev, theta):
"""
Calculates the log-density of `p(x_curr | x_prev, theta)`.
Args:
x_curr: State variable at current time `t`.
x_prev: State variable at previous time `t-1`.
theta: Parameter value.
Returns:
The log-density of `p(x_curr | x_prev, theta)`.
"""
# No need to do tree version since SDE has 2D x_state
# x0 = tree_append_first(
# x=tree_remove_last(x_curr),
# first=tree_keep_last(x_prev)
# )
x0 = jnp.concatenate([x_prev[-1][None], x_curr[:-1]])
x1 = x_curr
lp = jax.vmap(lambda xp, xc:
self.euler_lpdf(
x_curr=xc, x_prev=xp,
dt=self._dt/self._n_res,
theta=theta))(x0, x1)
# x = jnp.append(jnp.expand_dims(x_prev[self._n_res-1], axis=0),
# x_curr, axis=0)
# lp = jax.vmap(lambda t:
# self.euler_lpdf(
# x_curr=x[t+1], x_prev=x[t],
# dt=self._dt/self._n_res,
# theta=theta))(jnp.arange(self._n_res))
return jnp.sum(lp)
# x = jnp.append(jnp.expand_dims(x_prev[self._n_res-1], axis=0),
# x_curr, axis=0)
# return self.euler_lpdf(x, self._dt/self._n_res, theta)
def _state_lpdf_for(self, x_curr, x_prev, theta):
"""
Calculates the log-density of `p(x_curr | x_prev, theta)`.
For-loop version for testing.
Args:
x_curr: State variable at current time `t`.
x_prev: State variable at previous time `t-1`.
theta: Parameter value.
Returns:
The log-density of `p(x_curr | x_prev, theta)`.
"""
dt_res = self._dt/self._n_res
x0 = jnp.append(jnp.expand_dims(
x_prev[self._n_res-1], axis=0), x_curr[:self._n_res-1], axis=0)
x1 = x_curr
lp = jnp.array(0.0)
for t in range(self._n_res):
if self._diff_diag:
lp = lp + jnp.sum(jsp.stats.norm.logpdf(
x=x1[t],
loc=x0[t] + self.drift(x0[t], theta) * dt_res,
scale=self.diff(x0[t], theta) * jnp.sqrt(dt_res)
))
else:
lp = lp + jnp.sum(jsp.stats.multivariate_normal.logpdf(
x=x1[t],
mean=x0[t] + self.drift(x0[t], theta) * dt_res,
cov=self.diff(x0[t], theta) * dt_res
))
return lp
[docs] def state_sample(self, key, x_prev, theta):
"""
Samples from `x_curr ~ p(x_curr | x_prev, theta)`.
Args:
key: PRNG key.
x_prev: State variable at previous time `t-1`.
theta: Parameter value.
Returns:
Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
"""
# lax.scan setup:
# scan function
def fun(carry, t):
key, subkey = random.split(carry["key"])
x = self.euler_sim(
key=subkey, x=carry["x"],
dt=self._dt/self._n_res, theta=theta
)
res = {"x": x, "key": key}
return res, x
# scan initial value
# init = {"x": tree_keep_last(x_prev), "key": key}
init = {"x": x_prev[-1], "key": key}
last, full = lax.scan(fun, init, jnp.arange(self._n_res))
return full
# return self.euler_sim(
# key=key,
# n_steps=self._n_res,
# x=x_prev[self._n_res-1],
# dt=self._dt/self._n_res,
# theta=theta
# )
def _state_sample_for(self, key, x_prev, theta):
"""
Samples from `x_curr ~ p(x_curr | x_prev, theta)`.
For-loop version for testing.
Args:
key: PRNG key.
x_prev: State variable at previous time `t-1`.
theta: Parameter value.
Returns:
Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
"""
dt_res = self._dt/self._n_res
# x_curr = jnp.zeros(self._n_state)
x_curr = []
x_state = x_prev[self._n_res-1]
for t in range(self._n_res):
key, subkey = random.split(key)
if self._diff_diag:
dr = x_state + self.drift(x_state, theta) * dt_res
df = self.diff(x_state, theta) * jnp.sqrt(dt_res)
x_state = dr + df * random.normal(subkey, (x_state.shape[0],))
else:
dr = x_state + self.drift(x_state, theta) * dt_res
df = self.diff(x_state, theta) * dt_res
x_state = random.multivariate_normal(subkey, mean=dr, cov=df)
x_curr.append(x_state)
# x_curr = x_curr.at[t].set(x_state)
return jnp.array(x_curr)
[docs] def pf_step(self, key, x_prev, y_curr, theta):
"""
Update particle and calculate log-weight for a bootstrap particle filter.
**FIXME:** This method is completely generic, i.e., is not specific to SDEs. May wish to put it elsewhere...
Args:
key: PRNG key.
x_prev: State variable at previous time `t-1`.
y_curr: Measurement variable at current time `t`.
theta: Parameter value.
Returns:
- x_curr: Sample of the state variable at current time `t`: `x_curr ~ q(x_curr)`.
- logw: The log-weight of `x_curr`.
"""
x_curr = self.state_sample(key, x_prev, theta)
logw = lax.cond(
self.is_valid_state(x_curr, theta),
lambda _x: self.meas_lpdf(y_curr, x_curr, theta),
lambda _x: -jnp.inf,
0.0
)
# logw = self.meas_lpdf(y_curr, x_curr, theta)
return x_curr, logw
def _bridge_mv(self, x, theta, n, Y, A, Omega):
r"""
Mean and variance of bridge proposal specific for SDEs.
"""
k = self._n_res - n
dt_res = self._dt / self._n_res
dr = self.drift(x, theta) * dt_res
df = self.diff_full(x, theta) * dt_res
return mb.mvn_bridge_mv(
mu_W=x + dr,
Sigma_W=df,
mu_Y=jnp.matmul(A, x + k*dr),
AS_W=jnp.matmul(A, df),
Sigma_Y=k * jnp.linalg.multi_dot([A, df, A.T]) + Omega,
Y=Y
)
[docs] def bridge_step(self, key, x_prev, y_curr, theta, Y, A, Omega):
"""
Update particle and calculate log-weight for a particle filter with MVN bridge proposals.
**Notes:**
- The measurement input is `Y` instead of `y_meas`. The reason is that the proposal can be used with carefully chosen `Y = g(y_meas)` when the measurement error model is not `y_meas ~ N(A x_state, Omega)`.
- Only implements the general version with arbitrary `A` and `Omega`. Specific cases can be done more rapidly, but it's not clear where to put these different methods. Inside the class? As free functions?
- Duplicates a lot of code from `mvn_bridge_pars()`, because `Sigma_Y` and `mu_Y` inside the latter can be computed very easily here.
- Computes the Euler part of the log-weights using `vmap` after the bridge part which used `lax.scan()`. On one core, it's definitely faster to do both inside `lax.scan()`. On multiple cores that may or may not be the case, but probably would need quite a few cores see the speed increase. However, it's unlikely that we'll explicitly parallelize across cores for this, since the parallellization would typically be over particles.
- The drift and diffusion functions are each calculated twice, once for proposal and once for Euler. This is somewhat inefficient, but to circumvent this would need to redesign `euler_lpdf()`...
Args:
key: PRNG key.
x_prev: State variable at previous time `t-1`.
y_curr: Measurement variable at current time `t`.
theta: Parameter value.
Returns:
Tuple:
- **x_curr** - Sample of the state variable at current time `t`: `x_curr ~ q(x_curr)`.
- **logw** - The log-weight of `x_curr`.
"""
# lax.scan setup
def scan_fun(carry, n):
key = carry["key"]
x = carry["x"]
# calculate mean and variance of bridge proposal
mu_bridge, Sigma_bridge = self._bridge_mv(
x=x, theta=theta, n=n,
Y=Y, A=A, Omega=Omega
)
# bridge proposal
key, subkey = random.split(key)
x_prop = random.multivariate_normal(key,
mean=mu_bridge,
cov=Sigma_bridge)
# bridge log-pdf
lp_prop = jsp.stats.multivariate_normal.logpdf(
x=x_prop,
mean=mu_bridge,
cov=Sigma_bridge
)
res_carry = {
"x": x_prop,
"key": key,
"lp": carry["lp"] + lp_prop
}
res_stack = {"x": x_prop}
return res_carry, res_stack
scan_init = {
"x": x_prev[self._n_res-1],
"key": key,
"lp": jnp.array(0.)
}
last, full = lax.scan(scan_fun, scan_init, jnp.arange(self._n_res))
# return last, full
x_prop = full["x"] # bridge proposal
# log-weight for the proposal
logw = self.state_lpdf(
x_curr=x_prop,
x_prev=x_prev,
theta=theta
)
logw = logw + self.meas_lpdf(y_curr, x_prop, theta) - last["lp"]
logw = lax.cond(
self.is_valid_state(x_prop, theta),
lambda _x: logw,
lambda _x: -jnp.inf,
0.0
)
return x_prop, logw
def _bridge_step_for(self, key, x_prev, y_curr, theta, Y, A, Omega):
"""
For-loop version of bridge_step() for testing.
"""
dt_res = self._dt / self._n_res
x = x_prev[self._n_res - 1]
x_prop = []
lp_prop = jnp.array(0.0)
for t in range(self._n_res):
k = self._n_res - t
dr = self.drift(x, theta) * dt_res
df = self.diff_full(x, theta) * dt_res
mu_W = x + dr
Sigma_W = df
mu_Y = jnp.matmul(A, x + k*dr)
AS_W = jnp.matmul(A, Sigma_W)
Sigma_Y = k * jnp.linalg.multi_dot([A, df, A.T]) + Omega
# solve both linear systems simultaneously
sol = jnp.matmul(AS_W.T, jnp.linalg.solve(
Sigma_Y, jnp.hstack([jnp.array([Y-mu_Y]).T, AS_W])))
mu_bridge = mu_W + jnp.squeeze(sol[:, 0])
Sigma_bridge = Sigma_W - sol[:, 1:]
# bridge proposal
key, subkey = random.split(key)
x = random.multivariate_normal(key,
mean=mu_bridge,
cov=Sigma_bridge)
x_prop.append(x)
# bridge log-pdf
lp_prop = lp_prop + jsp.stats.multivariate_normal.logpdf(
x=x,
mean=mu_bridge,
cov=Sigma_bridge
)
x_prop = jnp.array(x_prop)
logw = self.state_lpdf(
x_curr=x_prop,
x_prev=x_prev,
theta=theta
)
logw = logw + self.meas_lpdf(y_curr, x_prop, theta) - lp_prop
return x_prop, logw