import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from .base_model import BaseModel
class BMModel(BaseModel):
    r"""
    Brownian motion state space model.

    The model is:

    ::

        x_0 ~ pi(x_0) \propto 1
        x_t ~ N(x_{t-1} + mu dt, sigma^2 dt)
        y_t ~ N(x_t, tau^2)

    Here `N(m, v)` denotes a normal distribution with mean `m` and variance `v`.

    Args:
        dt: Interobservation time.
        unconstrained_theta: Whether or not to use the regular parameters scale `theta = (mu, sigma, tau)` or the unconstrained scale `theta = (mu, log(sigma), log(tau))`.
    """

    def __init__(self, dt, unconstrained_theta=False):
        # NOTE(review): `bootstrap=True` is forwarded to `BaseModel`; presumably
        # this makes the base class supply the bootstrap-filter `pf_step()` /
        # `prop_lpdf()` -- confirm against `base_model`.
        super().__init__(bootstrap=True)
        self._dt = dt
        self._unconstrained_theta = unconstrained_theta

    def _constrain_theta(self, theta):
        r"""
        Convert `theta` to the constrained scale.

        Args:
            theta: Parameter value on the unconstrained scale `(mu, log(sigma), log(tau))`.

        Returns:
            Array `(mu, sigma, tau)` obtained by exponentiating the last two elements of `theta`.
        """
        return jnp.array([theta[0], jnp.exp(theta[1]), jnp.exp(theta[2])])

    def state_lpdf(self, x_curr, x_prev, theta):
        r"""
        Calculates the log-density of `p(x_curr | x_prev, theta)`.

        Args:
            x_curr: State variable at current time `t`.
            x_prev: State variable at previous time `t-1`.
            theta: Parameter value.

        Returns:
            The log-density of `p(x_curr | x_prev, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        mu = theta[0]
        sigma = theta[1]
        # The transition variance is `sigma^2 dt`, so the standard deviation
        # handed to `norm.logpdf` is `sigma * sqrt(dt)`.
        return jnp.squeeze(
            jsp.stats.norm.logpdf(x_curr, loc=x_prev + mu * self._dt,
                                  scale=sigma * jnp.sqrt(self._dt))
        )

    def state_sample(self, key, x_prev, theta):
        r"""
        Samples from `x_curr ~ p(x_curr | x_prev, theta)`.

        Args:
            key: PRNG key.
            x_prev: State variable at previous time `t-1`.
            theta: Parameter value.

        Returns:
            Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        mu = theta[0]
        sigma = theta[1]
        x_mean = x_prev + mu * self._dt
        x_sd = sigma * jnp.sqrt(self._dt)
        # A single standard normal draw is shifted and scaled.
        return x_mean + x_sd * random.normal(key=key)

    def meas_lpdf(self, y_curr, x_curr, theta):
        r"""
        Log-density of `p(y_curr | x_curr, theta)`.

        Args:
            y_curr: Measurement variable at current time `t`.
            x_curr: State variable at current time `t`.
            theta: Parameter value.

        Returns:
            The log-density of `p(y_curr | x_curr, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        tau = theta[2]
        return jnp.squeeze(
            jsp.stats.norm.logpdf(y_curr, loc=x_curr, scale=tau)
        )

    def meas_sample(self, key, x_curr, theta):
        r"""
        Sample from `p(y_curr | x_curr, theta)`.

        Args:
            key: PRNG key.
            x_curr: State variable at current time `t`.
            theta: Parameter value.

        Returns:
            Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        tau = theta[2]
        return x_curr + tau * random.normal(key=key)

    def pf_init(self, key, y_init, theta):
        r"""
        Get particles for initial state `x_init = x_state[0]`.

        Samples from an importance sampling proposal distribution

        ::

            x_init ~ q(x_init) = q(x_init | y_init, theta)

        and calculates the log weight

        ::

            logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)

        In this case we have an exact proposal

        ::

            q(x_init) = p(x_init | y_init, theta)
                     <=> x_init ~ N(y_init, tau^2)

        Moreover, due to symmetry of arguments we have `q(x_init) = p(y_init | x_init, theta)`, and since `p(x_init | theta) \propto 1` we have `logw = 0`.

        Args:
            key: PRNG key.
            y_init: Measurement variable at initial time `t = 0`.
            theta: Parameter value.

        Returns:
            Tuple:

            - **x_init** - A sample from the proposal distribution for `x_init`.
            - **logw** - The log-weight of `x_init`.
        """
        # The normal density is symmetric in its argument and its mean, so
        # drawing `x_init ~ N(y_init, tau^2)` can reuse `meas_sample()` with
        # `y_init` passed in the `x_curr` slot.
        return self.meas_sample(key, y_init, theta), jnp.zeros(())

    def loglik_exact(self, y_meas, theta):
        r"""
        Marginal loglikelihood of the BM model.

        Actually calculates `log p(y_{1:N} | theta, y_0)`, since for the flat prior on `x_0` the marginal likelihood `p(y_{0:N} | theta)` does not exist.

        Args:
            y_meas: Vector of observations `y_0, ..., y_N`.
            theta: Parameter value.

        Returns:
            float:
            The marginal loglikelihood `log p(y_{1:N} | theta, y_0)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        mu = theta[0]
        sigma2 = theta[1] * theta[1]
        tau2 = theta[2] * theta[2]
        n_obs = y_meas.shape[0]-1  # conditioning on y_0
        # Observation times `dt, 2 dt, ..., n_obs dt` of `y_1, ..., y_N`.
        t_meas = jnp.arange(1, n_obs+1) * self._dt
        # Given `y_0`, `y_i = x_0 + mu t_i + sigma B(t_i) + eps_i`, hence
        #   cov(y_i, y_j) = sigma^2 min(t_i, t_j)   (Brownian motion)
        #                 + tau^2                   (x_0 | y_0 ~ N(y_0, tau^2), shared by all i)
        #                 + tau^2 1{i = j}          (independent measurement error)
        Sigma_y = sigma2 * jax.vmap(lambda t:
                                    jnp.minimum(t, t_meas))(t_meas) + \
            tau2 * (jnp.ones((n_obs, n_obs)) + jnp.eye(n_obs))
        mu_y = y_meas[0] + mu * t_meas
        return jsp.stats.multivariate_normal.logpdf(
            x=jnp.squeeze(y_meas[1:]),
            mean=mu_y,
            cov=Sigma_y
        )
# class BMModel:
# def __init__(self, dt):
# self.n_state = ()
# self.n_meas = ()
# self._dt = dt
# def state_lpdf(self, x_curr, x_prev, theta):
# r"""
# Calculates the log-density of `p(x_curr | x_prev, theta)`.
# Args:
# x_curr: State variable at current time `t`.
# x_prev: State variable at previous time `t-1`.
# theta: Parameter value.
# Returns:
# The log-density of `p(x_curr | x_prev, theta)`.
# """
# mu = theta[0]
# sigma = theta[1]
# return jnp.squeeze(
# jsp.stats.norm.logpdf(x_curr, loc=x_prev + mu * self._dt,
# scale=sigma * jnp.sqrt(self._dt))
# )
# def state_sample(self, key, x_prev, theta):
# r"""
# Samples from `x_curr ~ p(x_curr | x_prev, theta)`.
# Args:
# key: PRNG key.
# x_prev: State variable at previous time `t-1`.
# theta: Parameter value.
# Returns:
# Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
# """
# mu = theta[0]
# sigma = theta[1]
# x_mean = x_prev + mu * self._dt
# x_sd = sigma * jnp.sqrt(self._dt)
# return x_mean + x_sd * random.normal(key=key)
# def meas_lpdf(self, y_curr, x_curr, theta):
# r"""
# Log-density of `p(y_curr | x_curr, theta)`.
# Args:
# y_curr: Measurement variable at current time `t`.
# x_curr: State variable at current time `t`.
# theta: Parameter value.
# Returns:
# The log-density of `p(y_curr | x_curr, theta)`.
# """
# tau = theta[2]
# return jnp.squeeze(
# jsp.stats.norm.logpdf(y_curr, loc=x_curr, scale=tau)
# )
# def meas_sample(self, key, x_curr, theta):
# r"""
# Sample from `p(y_curr | x_curr, theta)`.
# Args:
# key: PRNG key.
# x_curr: State variable at current time `t`.
# theta: Parameter value.
# Returns:
# Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
# """
# tau = theta[2]
# return x_curr + tau * random.normal(key=key)
# def init_logw(self, x_init, y_init, theta):
# r"""
# Log-weight of the importance sampler for initial state variable `x_init`.
# Suppose that
# ```
# x_init ~ q(x_init) = q(x_init | y_init, theta)
# ```
# Then function returns
# ```
# logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)
# ```
# Args:
# x_init: State variable at initial time `t = 0`.
# y_init: Measurement variable at initial time `t = 0`.
# theta: Parameter value.
# Returns:
# The log-weight of the importance sampler for `x_init`.
# """
# # return -meas_lpdf(x_init, y_init, theta)
# return jnp.zeros(())
# def init_sample(self, key, y_init, theta):
# r"""
# Sampling distribution for initial state variable `x_init`.
# Samples from an importance sampling proposal distribution
# ```
# x_init ~ q(x_init) = q(x_init | y_init, theta)
# ```
# See `init_logw()` for details.
# Args:
# key: PRNG key.
# y_init: Measurement variable at initial time `t = 0`.
# theta: Parameter value.
# Returns:
# Sample from the proposal distribution for `x_init`.
# """
# return self.meas_sample(key, y_init, theta)
# def pf_init(self, key, y_init, theta):
# r"""
# Get particles for initial state `x_init = x_state[0]`.
# Samples from an importance sampling proposal distribution
# ```
# x_init ~ q(x_init) = q(x_init | y_init, theta)
# ```
# and calculates the log weight
# ```
# logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)
# ```
# In this case we have an exact proposal
# ```
# q(x_init) = p(x_init | y_init, theta) <=> x_init ~ N(y_init, tau^2)
# ```
# Moreover, due to symmetry of arguments we have `q(x_init) = p(y_init | x_init, theta)`, and since `p(x_init | theta) \propto 1` we have `logw = 0`.
# Args:
# key: PRNG key.
# y_init: Measurement variable at initial time `t = 0`.
# theta: Parameter value.
# Returns:
# - x_init: A sample from the proposal distribution for `x_init`.
# - logw: The log-weight of `x_init`.
# """
# return self.meas_sample(key, y_init, theta), jnp.zeros(())
# def pf_step(self, key, x_prev, y_curr, theta):
# r"""
# Get particles at current step `t_curr` from previous step `t_prev`.
# Samples from an importance sampling proposal distribution
# ```
# x_curr ~ q(x_curr) = q(x_curr | x_prev, y_curr, theta)
# ```
# and calculates the log weight
# ```
# logw = log p(y_curr | x_curr, theta) + log p(x_curr | x_prev, theta) - log q(x_curr)
# ```
# In this case we have a bootstrap particle filter, so `q(x_curr) = p(x_curr | x_prev, theta)` and `logw = log p(y_curr | x_curr, theta)`.
# Args:
# key: PRNG key.
# x_prev: State variable at previous time `t-1`.
# y_curr: Measurement variable at current time `t`.
# theta: Parameter value.
# Returns:
# - x_curr: Sample of the state variable at current time `t`: `x_curr ~ q(x_curr)`.
# - logw: The log-weight of `x_curr`.
# """
# x_curr = self.state_sample(key, x_prev, theta)
# logw = self.meas_lpdf(y_curr, x_curr, theta)
# return x_curr, logw
# def prop_lpdf(self, x_curr, x_prev, y_curr, theta):
# r"""
# Calculate the log-density of the proposal distribution `q(x_curr) = q(x_curr | x_prev, y_curr, theta)`.
# In this case we have a bootstrap filter, so `q(x_curr) = p(x_curr | x_prev, theta)`.
# Args:
# x_curr: State variable at current time `t`.
# x_prev: State variable at previous time `t-1`.
# y_curr: Measurement variable at current time `t`.
# theta: Parameter value.
# """
# return self.state_lpdf(x_curr=x_curr, x_prev=x_prev, theta=theta)
# def loglik_exact(self, y_meas, theta):
# r"""
# Marginal loglikelihood of the BM model.
# Actually calculates `log p(y_{1:N} | theta, y_0)`, since for the flat prior on `x_0` the marginal likelihood `p(y_{0:N} | theta)` does not exist.
# Args:
# y_meas: Vector of observations `y_0, ..., y_N`.
# theta: Parameter value.
# Returns:
# The marginal loglikelihood `log p(y_{1:N} | theta, y_0)`.
# """
# mu = theta[0]
# sigma2 = theta[1] * theta[1]
# tau2 = theta[2] * theta[2]
# n_obs = y_meas.shape[0]-1 # conditioning on y_0
# t_meas = jnp.arange(1, n_obs+1) * self._dt
# Sigma_y = sigma2 * jax.vmap(lambda t:
# jnp.minimum(t, t_meas))(t_meas) + \
# tau2 * (jnp.ones((n_obs, n_obs)) + jnp.eye(n_obs))
# mu_y = y_meas[0] + mu * t_meas
# return jsp.stats.multivariate_normal.logpdf(
# x=jnp.squeeze(y_meas[1:]),
# mean=mu_y,
# cov=Sigma_y
# )