pfjax

Package Contents

Classes

BaseModel

Base model for particle filters.

Functions

simulate(model, key, n_obs, x_init, theta)

Simulate data from the state-space model.

loglik_full(model, y_meas, x_state, theta)

Calculate the complete data loglikelihood for a state space model.

particle_filter(model, key, y_meas, theta, n_particles)

Basic particle filter.

particle_filter_rb(model, key, y_meas, theta, n_particles)

Rao-Blackwellized particle filter.

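particle_smooth(key, logw, x_particles, ancestors)

Draw a sample from p(x_state | y_meas, theta) from a particle filter with multinomial resampling.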

Attributes

__version__

__author__

pfjax.simulate(model, key, n_obs, x_init, theta)

Simulate data from the state-space model.

Parameters
  • model

    Object specifying the state-space model having the following methods.

    • state_sim : (key, x_prev, theta) -> x_curr: Sample from the state model.

    • meas_sim : (key, x_curr, theta) -> y_curr: Sample from the measurement model.

  • key – PRNG key.

  • n_obs – Number of observations to generate.

  • x_init – Initial state value at time t = 0.

  • theta – Parameter value.

Returns

  • y_meas - The sequence of measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.

  • x_state - The sequence of state variables x_state = (x_0, …, x_T), where T = n_obs-1.

Return type

Tuple
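The generative recursion described above can be sketched as follows. This is a minimal illustration, not pfjax's implementation: the model class ToyAR1 (an AR(1) state with Gaussian measurement noise, theta = (rho, sigma, tau)) and the function simulate_sketch are hypothetical, and the sketch assumes x_init is used directly as x_0, with a fresh PRNG key for every sampling call.

```python
import jax
import jax.numpy as jnp


class ToyAR1:
    """Hypothetical model: x_t ~ N(rho * x_{t-1}, sigma^2), y_t ~ N(x_t, tau^2).

    theta = (rho, sigma, tau). Method names follow the docstring above.
    """

    def state_sim(self, key, x_prev, theta):
        rho, sigma, _ = theta
        return rho * x_prev + sigma * jax.random.normal(key)

    def meas_sim(self, key, x_curr, theta):
        _, _, tau = theta
        return x_curr + tau * jax.random.normal(key)


def simulate_sketch(model, key, n_obs, x_init, theta):
    """Sketch only: returns (y_meas, x_state), each with leading dimension n_obs."""
    theta = jnp.asarray(theta)
    # Match the dtype of the scan carry to the parameters.
    x_init = jnp.asarray(x_init, dtype=theta.dtype)
    key, subkey = jax.random.split(key)
    # t = 0: the state is fixed at x_init; only the measurement is sampled.
    y_init = model.meas_sim(subkey, x_init, theta)

    def step(x_prev, key_t):
        key_state, key_meas = jax.random.split(key_t)
        x_curr = model.state_sim(key_state, x_prev, theta)
        y_curr = model.meas_sim(key_meas, x_curr, theta)
        return x_curr, (y_curr, x_curr)

    # t = 1, ..., n_obs-1: one PRNG key per time step.
    keys = jax.random.split(key, n_obs - 1)
    _, (y_rest, x_rest) = jax.lax.scan(step, x_init, keys)
    y_meas = jnp.concatenate([y_init[None], y_rest])
    x_state = jnp.concatenate([x_init[None], x_rest])
    return y_meas, x_state
```

With sigma = tau = 0 the simulation becomes deterministic (x_t = rho^t * x_init and y_t = x_t), which is a quick way to check the indexing.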

pfjax.loglik_full(model, y_meas, x_state, theta)

Calculate the complete data loglikelihood for a state space model.

Calculates log[p(y_{0:T} | x_{0:T}, theta) * p(x_{0:T} | theta)].

Parameters
  • model

    Object specifying the state-space model having the following methods:

    • state_lpdf : (x_curr, x_prev, theta) -> lpdf: Calculates the log-density of the state model.

    • meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Calculates the log-density of the measurement model.

  • y_meas – The sequence of n_obs measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.

  • x_state – The sequence of n_obs state variables x_state = (x_0, …, x_T).

  • theta – Parameter value.

Returns

The value of the complete data loglikelihood.
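A minimal sketch of the complete data loglikelihood for a hypothetical AR(1) model (ToyAR1, with theta = (rho, sigma, tau)). Since only state_lpdf and meas_lpdf are listed above, the sketch assumes x_0 is conditioned on and adds no separate prior term for it; loglik_full_sketch is an illustrative name, and the library's handling of t = 0 may differ.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class ToyAR1:
    """Hypothetical AR(1) model with theta = (rho, sigma, tau)."""

    def state_lpdf(self, x_curr, x_prev, theta):
        rho, sigma, _ = theta
        return norm.logpdf(x_curr, loc=rho * x_prev, scale=sigma)

    def meas_lpdf(self, y_curr, x_curr, theta):
        _, _, tau = theta
        return norm.logpdf(y_curr, loc=x_curr, scale=tau)


def loglik_full_sketch(model, y_meas, x_state, theta):
    # Measurement terms log p(y_t | x_t, theta) for t = 0, ..., T.
    ll_meas = jax.vmap(model.meas_lpdf, in_axes=(0, 0, None))(y_meas, x_state, theta)
    # State transition terms log p(x_t | x_{t-1}, theta) for t = 1, ..., T.
    ll_state = jax.vmap(model.state_lpdf, in_axes=(0, 0, None))(
        x_state[1:], x_state[:-1], theta)
    return jnp.sum(ll_meas) + jnp.sum(ll_state)
```

Vectorizing the per-time-point densities with jax.vmap avoids a Python loop over t, since every term depends only on adjacent time points.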

pfjax.particle_filter(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)

Basic particle filter.

Notes:

  • Can optionally use an auxiliary particle filter if model.pf_aux() is provided.

  • The score and Fisher information are estimated using Algorithm 1 of Poyiadjis et al. (2011).

  • A possible extension is to allow new measurements to be added as they arrive, e.g., via an argument init=None which, if not None, is the carry from lax.scan(). The function would then also return the carry as an output…

Parameters
  • model

    Object specifying the state-space model having the following methods:

    • pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.

    • pf_step : (key, x_prev, y_curr, theta) -> (x_particles, logw): For sampling and calculating log-weights for each subsequent latent variable.

    • pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.

    • state_lpdf : (x_curr, x_prev, theta) -> lpdf: Optional method specifying the log-density of the state model. Only required if score=True or fisher=True.

    • meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Optional method specifying the log-density of the measurement model. Only required if score=True or fisher=True.

  • key – PRNG key.

  • y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.

  • theta – Parameter value.

  • n_particles – Number of particles.

  • resampler – Function used at step t to obtain a sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with mandatory element x_particles and optional elements that get carried to the next step t+1 via lax.scan().

  • score – Whether or not to return an estimate of the score function at theta. Only works if resampler has an output element named ancestors.

  • fisher – Whether or not to return an estimate of the Fisher information at theta. Only works if resampler has an output element named ancestors. If True, the score is returned as well.

  • history – Whether to output the history of the filter or only the last step.
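The resampler contract above can be illustrated with a minimal multinomial resampler. resample_multinomial_sketch is a hypothetical stand-in for the default resample_multinomial, not its source. It returns an ancestors element, which is what the score and fisher options require.

```python
import jax
import jax.numpy as jnp


def resample_multinomial_sketch(x_particles, logw, key):
    """Hypothetical multinomial resampler following the documented contract.

    Draws n_particles ancestor indices with probabilities proportional to
    exp(logw) and returns the resampled particles along with those indices.
    """
    n_particles = logw.shape[0]
    # categorical() accepts unnormalized log-probabilities as logits.
    ancestors = jax.random.categorical(key, logits=logw, shape=(n_particles,))
    return {"x_particles": x_particles[ancestors], "ancestors": ancestors}
```

Because logits are unnormalized, the weights never need to be normalized explicitly before resampling.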

Returns

  • x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or at all time points (leading dimensions (n_obs, n_particles)) if history=True.

  • logw - JAX array containing unnormalized log weights at the last time point (dimensions n_particles) or at all time points (dimensions (n_obs, n_particles)) if history=True.

  • loglik - The particle filter loglikelihood evaluated at theta.

  • score - Optional 1D JAX array of size n_theta = length(theta) containing the estimated score at theta.

  • fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.

  • resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.

Return type

Dictionary
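A minimal bootstrap version of the filter can be sketched as follows, assuming (hypothetically) that pf_init and pf_step act on a single particle and are vectorized with jax.vmap, and that multinomial resampling is applied at every step. It omits pf_aux, score, fisher, and history, and accumulates the loglikelihood as the log of the average unnormalized weight at each step. ToyAR1Bootstrap, its N(0, 1) prior on x_0, and particle_filter_sketch are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


class ToyAR1Bootstrap:
    """Hypothetical AR(1) model, theta = (rho, sigma, tau), with bootstrap methods.

    Both methods act on a single particle; the filter vmaps them.
    """

    def pf_init(self, key, y_init, theta):
        # Assumed prior x_0 ~ N(0, 1); the bootstrap weight is the measurement density.
        x_init = jax.random.normal(key)
        return x_init, norm.logpdf(y_init, loc=x_init, scale=theta[2])

    def pf_step(self, key, x_prev, y_curr, theta):
        x_curr = theta[0] * x_prev + theta[1] * jax.random.normal(key)
        return x_curr, norm.logpdf(y_curr, loc=x_curr, scale=theta[2])


def particle_filter_sketch(model, key, y_meas, theta, n_particles):
    """Bootstrap filter with multinomial resampling at every step (sketch)."""
    log_n = jnp.log(n_particles)
    key, subkey = jax.random.split(key)
    init_keys = jax.random.split(subkey, n_particles)
    x0, logw0 = jax.vmap(model.pf_init, in_axes=(0, None, None))(
        init_keys, y_meas[0], theta)

    def step(carry, inputs):
        x_prev, logw_prev = carry
        key_t, y_curr = inputs
        key_resample, key_prop = jax.random.split(key_t)
        # Ancestor indices drawn with probability proportional to exp(logw_prev).
        ancestors = jax.random.categorical(
            key_resample, logits=logw_prev, shape=(n_particles,))
        prop_keys = jax.random.split(key_prop, n_particles)
        x_curr, logw = jax.vmap(model.pf_step, in_axes=(0, 0, None, None))(
            prop_keys, x_prev[ancestors], y_curr, theta)
        # Incremental loglikelihood: log of the average unnormalized weight.
        return (x_curr, logw), logsumexp(logw) - log_n

    keys = jax.random.split(key, y_meas.shape[0] - 1)
    (x_last, logw_last), ll_steps = jax.lax.scan(
        step, (x0, logw0), (keys, y_meas[1:]))
    loglik = logsumexp(logw0) - log_n + jnp.sum(ll_steps)
    return {"x_particles": x_last, "logw": logw_last, "loglik": loglik}
```

For this linear-Gaussian toy model the exact loglikelihood is available from a Kalman filter, which makes it a convenient check on the particle estimate.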

pfjax.particle_filter_rb(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)

Rao-Blackwellized particle filter.

Notes

  • Implements Algorithm 2 of Poyiadjis et al. (2011).

  • Can optionally use an auxiliary particle filter if model.pf_aux() is provided.

Parameters
  • model

    Object specifying the state-space model having the following methods:

    • pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.

    • step_sample : (key, x_prev, y_curr, theta) -> x_curr: Samples from the proposal distribution for each subsequent latent variable.

    • step_lpdf : (x_curr, x_prev, y_curr, theta) -> logw: Calculates log-weights for each subsequent latent variable.

    • state_lpdf : (x_curr, x_prev, theta) -> lpdf: Calculates the log-density of the state model.

    • meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Calculates the log-density of the measurement model.

    • pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.

  • key – PRNG key.

  • y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.

  • theta – Parameter value.

  • n_particles – Number of particles.

  • resampler – Function used at step t to obtain a sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with mandatory element x_particles and optional elements that get carried to the next step t+1 via lax.scan().

  • score – Whether or not to return an estimate of the score function at theta.

  • fisher – Whether or not to return an estimate of the Fisher information at theta. If True, the score is returned as well.

  • history – Whether to output the history of the filter or only the last step.

Returns

  • x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or at all time points (leading dimensions (n_obs, n_particles)) if history=True.

  • logw_bar - JAX array containing unnormalized log weights at the last time point (dimensions n_particles) or at all time points (dimensions (n_obs, n_particles)) if history=True.

  • loglik - The particle filter loglikelihood evaluated at theta.

  • score - Optional 1D JAX array of size n_theta = length(theta) containing the estimated score at theta.

  • fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.

  • resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.

Return type

Dictionary
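In a marginalized weighting scheme of the kind used in Algorithm 2 of Poyiadjis et al. (2011), each new particle is weighted against the whole previous particle population rather than only its own ancestor, at a cost of O(n_particles^2) per step. The sketch below shows one way to compute such log-weights from state_lpdf, meas_lpdf, and step_lpdf; rb_logw_sketch and the ToyAR1 model are hypothetical, and the library's implementation may be organized differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


class ToyAR1:
    """Hypothetical AR(1) model, theta = (rho, sigma, tau), with a bootstrap proposal."""

    def state_lpdf(self, x_curr, x_prev, theta):
        return norm.logpdf(x_curr, loc=theta[0] * x_prev, scale=theta[1])

    def meas_lpdf(self, y_curr, x_curr, theta):
        return norm.logpdf(y_curr, loc=x_curr, scale=theta[2])

    def step_lpdf(self, x_curr, x_prev, y_curr, theta):
        # Bootstrap choice: q(x_curr | x_prev, y_curr, theta) = p(x_curr | x_prev, theta).
        return self.state_lpdf(x_curr, x_prev, theta)


def rb_logw_sketch(model, x_curr, x_prev, logw_prev, y_curr, theta):
    """Marginalized log-weight for each new particle x_curr[i] (sketch).

    logw_bar[i] = log p(y_curr | x_curr[i])
                  + logsumexp_j(logw_prev[j] + log p(x_curr[i] | x_prev[j]))
                  - logsumexp_j(logw_prev[j] + log q(x_curr[i] | x_prev[j], y_curr))
    """
    def one_particle(x_i):
        # Evaluate transition and proposal densities against every previous particle.
        lp_state = jax.vmap(model.state_lpdf, in_axes=(None, 0, None))(x_i, x_prev, theta)
        lp_prop = jax.vmap(model.step_lpdf, in_axes=(None, 0, None, None))(
            x_i, x_prev, y_curr, theta)
        return (model.meas_lpdf(y_curr, x_i, theta)
                + logsumexp(logw_prev + lp_state)
                - logsumexp(logw_prev + lp_prop))

    return jax.vmap(one_particle)(x_curr)
```

With a bootstrap proposal (step_lpdf equal to state_lpdf), the two log-sum-exp terms cancel and logw_bar reduces to the measurement log-density.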

pfjax.particle_smooth(key, logw, x_particles, ancestors)

Draw a sample from p(x_state | y_meas, theta) from a particle filter with multinomial resampling.

Parameters
  • key – PRNG key.

  • logw – Vector of n_particles unnormalized log-weights at the last time point t = n_obs-1.

  • x_particles – JAX array with leading dimensions (n_obs, n_particles) containing the state variable particles.

  • ancestors – JAX integer array of shape (n_obs-1, n_particles) where each element gives the index of the particle’s ancestor at the previous time point.

Returns

An array with leading dimension n_obs sampled from p(x_{0:T} | y_{0:T}, theta).

Return type

JAX Array
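Sampling from the particle smoother amounts to drawing a final particle index from the weights at t = n_obs-1 and then following the ancestors array backwards in time. particle_smooth_sketch is a hypothetical illustration of that trace-back, assuming ancestors[t, i] holds the index at time t of the ancestor of particle i at time t+1.

```python
import jax
import jax.numpy as jnp


def particle_smooth_sketch(key, logw, x_particles, ancestors):
    """Trace one particle lineage back from the last time point (sketch)."""
    n_obs = x_particles.shape[0]
    # Index of the particle selected at the last time point.
    i_last = jax.random.categorical(key, logits=logw)

    def step(i_next, anc_t):
        # Ancestor at time t of the particle selected at time t+1.
        i_curr = anc_t[i_next]
        return i_curr, i_curr

    # reverse=True walks from t = n_obs-2 down to t = 0 but stores outputs in time order.
    _, i_prev = jax.lax.scan(step, i_last, ancestors, reverse=True)
    idx = jnp.concatenate([i_prev, i_last[None]])
    # Pick particle idx[t] at each time t.
    return x_particles[jnp.arange(n_obs), idx]
```

Using lax.scan with reverse=True keeps the backward pass compilable while returning the sampled path in forward time order.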

class pfjax.BaseModel(bootstrap)

Bases: object

Base model for particle filters.

This class sets up a PF model from a small set of methods:

  • The derived class should provide methods state_lpdf(), meas_lpdf(), state_sample(), and meas_sample() in order to calculate the complete data loglikelihood with pfjax.loglik_full() and simulate data from the model via pfjax.simulate().

  • To use the “basic” particle filter pfjax.particle_filter(), the derived class must provide methods pf_init() and pf_step().

  • To use the Rao-Blackwellized particle filter pfjax.particle_filter_rb(), the derived class must provide methods pf_init(), step_sample(), and step_lpdf().

  • If pf_init() is missing, the base class will automatically construct it from prior_lpdf(), init_sample(), and init_lpdf().

  • If pf_step() is missing, the base class will automatically construct it from state_lpdf(), meas_lpdf(), step_sample(), and step_lpdf().

  • In either of the above cases, if step_sample() and step_lpdf() are missing, the base class assumes a bootstrap particle filter and sets step_sample = state_sample and step_lpdf = state_lpdf.

  • In either of the above cases, if init_sample() and init_lpdf() are missing, the base class sets init_sample = prior_sample and init_lpdf = prior_lpdf.
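The default-method rules above can be sketched with a small hypothetical base class. BaseModelSketch is not pfjax.BaseModel: it ignores the bootstrap constructor argument, and the prior_sample(key, theta) and prior_lpdf(x_init, theta) signatures are assumptions. It builds pf_step and pf_init from the log-weight formulas documented for pf_step() and pf_init(), using bootstrap defaults for the proposals.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class BaseModelSketch:
    """Hypothetical stand-in for the default-method logic described above."""

    # Bootstrap defaults: the proposals are the state model and the prior.
    def step_sample(self, key, x_prev, y_curr, theta):
        return self.state_sample(key, x_prev, theta)

    def step_lpdf(self, x_curr, x_prev, y_curr, theta):
        return self.state_lpdf(x_curr, x_prev, theta)

    def init_sample(self, key, y_init, theta):
        return self.prior_sample(key, theta)

    def init_lpdf(self, x_init, y_init, theta):
        return self.prior_lpdf(x_init, theta)

    def pf_step(self, key, x_prev, y_curr, theta):
        x_curr = self.step_sample(key, x_prev, y_curr, theta)
        # logw = log p(y_curr | x_curr) + log p(x_curr | x_prev) - log q(x_curr)
        logw = (self.meas_lpdf(y_curr, x_curr, theta)
                + self.state_lpdf(x_curr, x_prev, theta)
                - self.step_lpdf(x_curr, x_prev, y_curr, theta))
        return x_curr, logw

    def pf_init(self, key, y_init, theta):
        x_init = self.init_sample(key, y_init, theta)
        # logw = log p(y_init | x_init) + log p(x_init) - log q(x_init)
        logw = (self.meas_lpdf(y_init, x_init, theta)
                + self.prior_lpdf(x_init, theta)
                - self.init_lpdf(x_init, y_init, theta))
        return x_init, logw


class ToyAR1(BaseModelSketch):
    """Hypothetical derived model, theta = (rho, sigma, tau): densities and samplers only."""

    def prior_sample(self, key, theta):
        return jax.random.normal(key)

    def prior_lpdf(self, x_init, theta):
        return norm.logpdf(x_init)

    def state_sample(self, key, x_prev, theta):
        return theta[0] * x_prev + theta[1] * jax.random.normal(key)

    def state_lpdf(self, x_curr, x_prev, theta):
        return norm.logpdf(x_curr, loc=theta[0] * x_prev, scale=theta[1])

    def meas_lpdf(self, y_curr, x_curr, theta):
        return norm.logpdf(y_curr, loc=x_curr, scale=theta[2])
```

With these defaults the state and prior terms cancel, so both log-weights reduce to meas_lpdf(), the usual bootstrap weight.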

Notes:

  • pf_step() and pf_init() could be computed more efficiently for bootstrap sampling. Perhaps this could be specified with an argument bootstrap to the constructor.

  • In general, nothing prevents the user from defining, e.g., a pf_step() that is inconsistent with step_sample(), etc.

Parameters

bootstrap – Boolean for whether or not to create a bootstrap particle filter.

step_sample(key, x_prev, y_curr, theta)

Sample from the default proposal distribution

q(x_curr | x_prev, y_curr, theta) = p(x_curr | x_prev, theta)

Parameters
  • key – PRNG key.

  • x_prev – State variable at previous time t-1.

  • y_curr – Measurement variable at current time t.

  • theta – Parameter value.

Returns

Sample of the state variable x_curr at current time t.

step_lpdf(x_curr, x_prev, y_curr, theta)

Calculate the log-density of the default proposal distribution

q(x_curr | x_prev, y_curr, theta) = p(x_curr | x_prev, theta)

Parameters
  • x_curr – State variable at current time t.

  • x_prev – State variable at previous time t-1.

  • y_curr – Measurement variable at current time t.

  • theta – Parameter value.

Returns

Log-density of the state variable x_curr at current time t.

init_sample(key, y_init, theta)

Sample from the default initial proposal distribution

q(x_init | y_init, theta) = p(x_init | theta)

Parameters
  • key – PRNG key.

  • y_init – Measurement variable at initial time t = 0.

  • theta – Parameter value.

Returns

Sample of the state variable x_init at initial time t = 0.

init_lpdf(x_init, y_init, theta)

Calculate the log-density of the default initial proposal distribution

q(x_init | y_init, theta) = p(x_init | theta)

Parameters
  • x_init – State variable at initial time t = 0.

  • y_init – Measurement variable at initial time t = 0.

  • theta – Parameter value.

Returns

Log-density of the state variable x_init at initial time t = 0.

pf_step(key, x_prev, y_curr, theta)

Particle filter update.

Returns a draw from the proposal distribution

x_curr ~ q(x_curr) = q(x_curr | x_prev, y_curr, theta)

and the log weight

logw = log p(y_curr | x_curr, theta) + log p(x_curr | x_prev, theta) - log q(x_curr)

Parameters
  • key – PRNG key.

  • x_prev – State variable at previous time t-1.

  • y_curr – Measurement variable at current time t.

  • theta – Parameter value.

Returns

  • x_curr - A sample from the proposal distribution at current time t.

  • logw - The log-weight of x_curr.

Return type

Tuple

pf_init(key, y_init, theta)

Initial step of particle filter.

Returns a draw from the proposal distribution

x_init ~ q(x_init) = q(x_init | y_init, theta)

and calculates the log weight

logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)

Parameters
  • key – PRNG key.

  • y_init – Measurement variable at initial time t = 0.

  • theta – Parameter value.

Returns

  • x_init - A sample from the proposal distribution at initial time t = 0.

  • logw - The log-weight of x_init.

Return type

Tuple

pfjax.__version__ = 0.0.2
pfjax.__author__ = Martin Lysy, Pranav Subramani, Jonathan Ramkissoon, Mohan Wu, Michelle Ko, Kanika Choptra