pfjax.particle_filter

Particle filters which approximate the score and Fisher information.

Module Contents

Functions

particle_filter(model, key, y_meas, theta, n_particles)

Basic particle filter.

particle_filter_rb(model, key, y_meas, theta, n_particles)

Rao-Blackwellized particle filter.

pfjax.particle_filter.particle_filter(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)

Basic particle filter.

Notes

  • Can optionally use an auxiliary particle filter if model.pf_aux() is provided.

  • The score and Fisher information are estimated using the method described in Algorithm 1 of Poyiadjis et al. (2011).

  • Planned extension: allow new measurements to be added incrementally. For example, an argument init=None could, when not None, supply the carry from lax.scan(); the function would then also return the carry as an output…
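The path-space estimates of the score and observed Fisher information can be sketched as follows. This is a minimal sketch of our reading of Algorithm 1 of Poyiadjis et al. (2011), not the pfjax code: each particle carries the accumulated gradient (alpha) and Hessian (beta) of the complete-data log-density along its path, inherited through the indices returned by the resampler, which is why score=True and fisher=True require an ancestors output. The helpers score_fisher_step and score_fisher_estimate are hypothetical names, and alpha and beta are assumed to be initialized at t=0 from the gradient and Hessian of the initial log-density (zero if it does not depend on theta).

```python
import jax
import jax.numpy as jnp


def score_fisher_step(alpha_prev, beta_prev, ancestors, x_prev, x_curr, y_curr,
                      theta, state_lpdf, meas_lpdf):
    """One step of the path-space gradient/Hessian recursion (hypothetical helper).

    alpha_prev: (n_particles, n_theta) accumulated gradients up to t-1.
    beta_prev: (n_particles, n_theta, n_theta) accumulated Hessians up to t-1.
    ancestors: (n_particles,) integer indices returned by the resampler at step t.
    """
    def log_incr(th, xc, xp):
        # log f(x_t | x_{t-1}, theta) + log g(y_t | x_t, theta) for a single particle.
        return state_lpdf(xc, xp, th) + meas_lpdf(y_curr, xc, th)

    grad_fn = jax.vmap(jax.grad(log_incr), in_axes=(None, 0, 0))
    hess_fn = jax.vmap(jax.hessian(log_incr), in_axes=(None, 0, 0))
    x_anc = x_prev[ancestors]
    # Each particle inherits its ancestor's path statistics, then adds the new terms.
    alpha = alpha_prev[ancestors] + grad_fn(theta, x_curr, x_anc)
    beta = beta_prev[ancestors] + hess_fn(theta, x_curr, x_anc)
    return alpha, beta


def score_fisher_estimate(alpha, beta, logw):
    """Weighted score and observed-information estimates from the path statistics."""
    W = jax.nn.softmax(logw)  # normalized particle weights
    score = jnp.sum(W[:, None] * alpha, axis=0)
    # Louis identity: observed information = -(E[alpha alpha^T + beta] - score score^T).
    second = jnp.sum(W[:, None, None] * (alpha[:, :, None] * alpha[:, None, :] + beta), axis=0)
    fisher = -(second - jnp.outer(score, score))
    return score, fisher
```

Only the gradient recursion (alpha) is needed for the score; the Hessian recursion (beta) is the extra cost paid when fisher=True.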

Parameters
  • model

    Object specifying the state-space model having the following methods:

    • pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.

    • pf_step : (key, x_prev, y_curr, theta) -> (x_particles, logw): For sampling and calculating log-weights for each subsequent latent variable.

    • pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.

    • state_lpdf : (x_curr, x_prev, theta) -> lpdf: Optional method specifying the log-density of the state model. Only required if score=True or fisher=True.

    • meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Optional method specifying the log-density of the measurement model. Only required if score=True or fisher=True.

  • key – PRNG key.

  • y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.

  • theta – Parameter value.

  • n_particles – Number of particles.

  • resampler – Function used at step t to obtain a sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with the mandatory element x_particles and optional elements that are carried to the next step t+1 via lax.scan().

  • score – Whether or not to return an estimate of the score function at theta. Only works if the resampler output has an element named ancestors.

  • fisher – Whether or not to return an estimate of the Fisher information at theta. Only works if the resampler output has an element named ancestors. If True, the score is returned as well.

  • history – Whether to output the history of the filter or only the last step.
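A minimal sketch of an object satisfying the model interface above, together with a resampler following the documented resampler(x_particles, logw, key) contract. AR1Model and resample_multinomial_sketch are hypothetical illustrations, not part of pfjax. We assume each model method acts on a single particle and that the filter vectorizes over particles (e.g. with jax.vmap); the exact convention should be checked against the pfjax source. Because pf_step samples from the state model, its log-weight reduces to the measurement log-density, i.e. a bootstrap filter.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class AR1Model:
    """Hypothetical model: x_0 ~ N(0, 1), x_t = rho * x_{t-1} + N(0, 1),
    y_t = x_t + N(0, sigma^2), with theta = (rho, log_sigma).

    Each method handles a single particle (assumed convention)."""

    def state_lpdf(self, x_curr, x_prev, theta):
        return norm.logpdf(x_curr, theta[0] * x_prev, 1.0)

    def meas_lpdf(self, y_curr, x_curr, theta):
        return norm.logpdf(y_curr, x_curr, jnp.exp(theta[1]))

    def pf_init(self, key, y_init, theta):
        # Sample x_0 from its prior; the log-weight is then the measurement log-density.
        x_init = jax.random.normal(key)
        return x_init, self.meas_lpdf(y_init, x_init, theta)

    def pf_step(self, key, x_prev, y_curr, theta):
        # Bootstrap proposal: sample from the state model, weight by the measurement model.
        x_curr = theta[0] * x_prev + jax.random.normal(key)
        return x_curr, self.meas_lpdf(y_curr, x_curr, theta)


def resample_multinomial_sketch(x_particles, logw, key):
    """Multinomial resampling using the documented argument order."""
    n_particles = logw.shape[0]
    # Draw ancestor indices with probabilities proportional to exp(logw).
    ancestors = jax.random.categorical(key, logw, shape=(n_particles,))
    # ancestors is an optional extra element; score/fisher estimation relies on it.
    return {"x_particles": x_particles[ancestors], "ancestors": ancestors}
```

For example, jax.vmap(model.pf_init, in_axes=(0, None, None))(jax.random.split(key, n_particles), y_meas[0], theta) would produce the initial particles and log-weights under this per-particle convention.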

Returns

  • x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or at all time points (leading dimensions (n_obs, n_particles)) if history=True.

  • logw - JAX array containing unnormalized log-weights at the last time point (dimensions n_particles) or at all time points (dimensions (n_obs, n_particles)) if history=True.

  • loglik - The particle filter estimate of the log-likelihood, evaluated at theta.

  • score - Optional 1D JAX array of size n_theta = len(theta) containing the estimated score at theta.

  • fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.

  • resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.

Return type

Dictionary
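How loglik relates to the history of logw can be sketched as below, assuming resampling at every step and no auxiliary look-forward weights (with pf_aux the estimator needs an additional correction). Under these assumptions, log p(y_t | y_{0:t-1}, theta) is estimated by the log of the average unnormalized weight at time t, and loglik is the sum of these increments. The function names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def pf_loglik(logw_hist):
    """Log-likelihood estimate from unnormalized log-weights of shape (n_obs, n_particles).

    Row t holds the log-weights computed after resampling at step t
    (resampling at every step is assumed)."""
    n_particles = logw_hist.shape[1]
    # log mean_i exp(logw[t, i]), computed stably, then summed over t = 0, ..., T.
    return jnp.sum(logsumexp(logw_hist, axis=1) - jnp.log(n_particles))


def normalized_weights(logw):
    """Convert unnormalized log-weights at one time point to weights summing to 1."""
    return jnp.exp(logw - logsumexp(logw))
```

Working on the log scale with logsumexp avoids underflow when individual weights are tiny, which is why the filter returns log-weights rather than weights.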

pfjax.particle_filter.particle_filter_rb(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)

Rao-Blackwellized particle filter.

Notes

  • Based on Algorithm 2 of Poyiadjis et al. (2011).

  • Can optionally use an auxiliary particle filter if model.pf_aux() is provided.
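The Rao-Blackwellized weights can be sketched as below, assuming they take the marginal particle filter form, where each new particle is weighted against the whole previous particle cloud rather than only its ancestor. This costs O(n_particles^2) density evaluations per step. The helper marginal_logw and its proposal log-density argument prop_lpdf are hypothetical names; how prop_lpdf corresponds to the model's step_sample/step_lpdf methods is an assumption to be checked against the pfjax source.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def marginal_logw(x_curr, x_prev, logw_prev, y_curr, theta,
                  state_lpdf, meas_lpdf, prop_lpdf):
    """Marginal (Rao-Blackwellized) log-weights, hypothetical helper:

    logw_bar[i] = log g(y_t | x_t[i])
                  + log sum_j W[j] f(x_t[i] | x_{t-1}[j])
                  - log sum_j W[j] q(x_t[i] | x_{t-1}[j], y_t)
    """
    # log_f[i, j] = log f(x_curr[i] | x_prev[j], theta)
    log_f = jax.vmap(jax.vmap(state_lpdf, in_axes=(None, 0, None)),
                     in_axes=(0, None, None))(x_curr, x_prev, theta)
    # log_q[i, j] = log q(x_curr[i] | x_prev[j], y_curr, theta)
    log_q = jax.vmap(jax.vmap(prop_lpdf, in_axes=(None, 0, None, None)),
                     in_axes=(0, None, None, None))(x_curr, x_prev, y_curr, theta)
    log_g = jax.vmap(meas_lpdf, in_axes=(None, 0, None))(y_curr, x_curr, theta)
    # Unnormalized previous weights suffice: their normalizing constant cancels in the ratio.
    return (log_g
            + logsumexp(logw_prev[None, :] + log_f, axis=1)
            - logsumexp(logw_prev[None, :] + log_q, axis=1))
```

When the proposal equals the state model, the two sums cancel and the weights reduce to the bootstrap weights log g(y_t | x_t).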

Parameters
  • model

    Object specifying the state-space model having the following methods:

    • pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.

    • step_sample : (key, x_prev, y_curr, theta) -> x_curr: Sampling from the proposal distribution for each subsequent latent variable.

    • step_lpdf : (x_curr, x_prev, y_curr, theta) -> logw: Calculate log-weights for each subsequent latent variable.

    • state_lpdf : (x_curr, x_prev, theta) -> lpdf: Calculates the log-density of the state model.

    • meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Calculates the log-density of the measurement model.

    • pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.

  • key – PRNG key.

  • y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.

  • theta – Parameter value.

  • n_particles – Number of particles.

  • resampler – Function used at step t to obtain a sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with the mandatory element x_particles and optional elements that are carried to the next step t+1 via lax.scan().

  • score – Whether or not to return an estimate of the score function at theta.

  • fisher – Whether or not to return an estimate of the Fisher information at theta. If True, the score is returned as well.

  • history – Whether to output the history of the filter or only the last step.

Returns

  • x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or at all time points (leading dimensions (n_obs, n_particles)) if history=True.

  • logw_bar - JAX array containing unnormalized log-weights at the last time point (dimensions n_particles) or at all time points (dimensions (n_obs, n_particles)) if history=True.

  • loglik - The particle filter estimate of the log-likelihood, evaluated at theta.

  • score - Optional 1D JAX array of size n_theta = len(theta) containing the estimated score at theta.

  • fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.

  • resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.

Return type

Dictionary