# Source code for pfjax.loglik_full

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from jax import random
from jax import lax
from pfjax.utils import *


def loglik_full(model, y_meas, x_state, theta):
    """
    Calculate the complete data loglikelihood for a state space model.

    Computes the log of
    `p(y_{0:T} | x_{0:T}, theta) * prod_{t=1}^T p(x_t | x_{t-1}, theta)`,
    i.e.,

        sum_{t=0}^T meas_lpdf(y_t | x_t, theta)
        + sum_{t=1}^T state_lpdf(x_t | x_{t-1}, theta).

    No density term for the initial state `x_0` is included: `x_0` is
    conditioned on rather than scored.

    Args:
        model: Object specifying the state-space model having the following methods:
            - `state_lpdf : (x_curr, x_prev, theta) -> lpdf`: Calculates the log-density of the state model.
            - `meas_lpdf : (y_curr, x_curr, theta) -> lpdf`: Calculates the log-density of the measurement model.
        y_meas: The sequence of `n_obs` measurement variables
            `y_meas = (y_0, ..., y_T)`, where `T = n_obs-1`.  Either an array
            or a pytree whose leaves all share the leading (time) dimension
            `n_obs`.
        x_state: The sequence of `n_obs` state variables
            `x_state = (x_0, ..., x_T)`, following the same leading-dimension
            convention as `y_meas`.
        theta: Parameter value.

    Returns:
        The value of the complete data loglikelihood.
    """
    def _first(tree):
        # Element 0 along the leading (time) axis of every leaf.
        return jtu.tree_map(lambda leaf: leaf[0], tree)

    def _drop_first(tree):
        # Elements 1, ..., T along the leading axis of every leaf.
        return jtu.tree_map(lambda leaf: leaf[1:], tree)

    def _drop_last(tree):
        # Elements 0, ..., T-1 along the leading axis of every leaf.
        return jtu.tree_map(lambda leaf: leaf[:-1], tree)

    # initial measurement: log p(y_0 | x_0, theta)
    ll_init = model.meas_lpdf(
        y_curr=_first(y_meas), x_curr=_first(x_state), theta=theta
    )

    # steps t = 1, ..., T: state transition plus measurement log-densities.
    def _step(x_curr, x_prev, y_curr):
        return (
            model.state_lpdf(x_curr=x_curr, x_prev=x_prev, theta=theta)
            + model.meas_lpdf(y_curr=y_curr, x_curr=x_curr, theta=theta)
        )

    # Pair x_t with x_{t-1} and y_t by shifting the sequences against each
    # other, then vectorize the per-step term over the leading (time) axis.
    ll_step = jax.vmap(_step)(
        _drop_first(x_state), _drop_last(x_state), _drop_first(y_meas)
    )
    return ll_init + jnp.sum(ll_step)