pfjax.particle_filter
Particle filters which approximate the score and Fisher information.
Module Contents
Functions
particle_filter – Basic particle filter.
particle_filter_rb – Rao-Blackwellized particle filter.
- pfjax.particle_filter.particle_filter(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)
Basic particle filter.
Notes:
Can optionally use an auxiliary particle filter if model.pf_aux() is provided.
The score and Fisher information are estimated using Algorithm 1 of Poyiadjis et al. (2011).
A possible extension is an option to add new measurements as they arrive. For example, an argument init=None which, if not None, is the carry from a previous lax.scan() call; the filter should then also return the carry as an output…
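To make the Algorithm 1 note above concrete, here is a minimal sketch of a basic (non-auxiliary) filter that carries the per-particle path statistics used for the score. It assumes per-particle model methods with the signatures listed under Parameters (the sketch vmaps them over particles), multinomial resampling at every step, and an initial distribution for x_0 that does not depend on theta. The function name pf_score_sketch is hypothetical; this is not pfjax's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def pf_score_sketch(model, key, y_meas, theta, n_particles):
    """Bootstrap-style particle filter with an Algorithm-1-style score estimate.

    Illustrative only. Assumes model.pf_init, model.pf_step, model.state_lpdf
    and model.meas_lpdf act on a single particle.
    """
    # Gradients of the state and measurement log-densities with respect to theta.
    grad_state = jax.grad(model.state_lpdf, argnums=2)
    grad_meas = jax.grad(model.meas_lpdf, argnums=2)

    # Initial step: one PRNG key per particle, vmapped over particles.
    key, subkey = jax.random.split(key)
    x, logw = jax.vmap(model.pf_init, in_axes=(0, None, None))(
        jax.random.split(subkey, n_particles), y_meas[0], theta)
    # alpha[i] accumulates grad log p(x_{0:t}, y_{0:t} | theta) along particle i's path
    # (the x_0 prior term is dropped because it is assumed theta-free).
    alpha = jax.vmap(grad_meas, in_axes=(None, 0, None))(y_meas[0], x, theta)
    loglik = logsumexp(logw) - jnp.log(n_particles)

    def step(carry, y_curr):
        x_prev, logw_prev, alpha_prev, loglik, key = carry
        key, k_res, k_prop = jax.random.split(key, 3)
        # Multinomial resampling: ancestor indices drawn with probabilities W_{t-1}.
        anc = jax.random.categorical(k_res, logw_prev, shape=(n_particles,))
        x_anc, alpha_anc = x_prev[anc], alpha_prev[anc]
        # Propagate each resampled particle and compute its log-weight.
        x_curr, logw = jax.vmap(model.pf_step, in_axes=(0, 0, None, None))(
            jax.random.split(k_prop, n_particles), x_anc, y_curr, theta)
        # alpha_t^i = alpha_{t-1}^{a_i} + grad log f(x_t^i | x_{t-1}^{a_i}) + grad log g(y_t | x_t^i)
        alpha = (alpha_anc
                 + jax.vmap(grad_state, in_axes=(0, 0, None))(x_curr, x_anc, theta)
                 + jax.vmap(grad_meas, in_axes=(None, 0, None))(y_curr, x_curr, theta))
        # Log-likelihood increment log((1/N) sum_i w_t^i), valid when resampling every step.
        loglik = loglik + logsumexp(logw) - jnp.log(n_particles)
        return (x_curr, logw, alpha, loglik, key), None

    (x, logw, alpha, loglik, _), _ = jax.lax.scan(
        step, (x, logw, alpha, loglik, key), y_meas[1:])
    # Score estimate: normalized-weight average of the path gradients.
    score = jax.nn.softmax(logw) @ alpha
    return {"x_particles": x, "logw": logw, "loglik": loglik, "score": score}
```

Because each alpha is inherited through the ancestor index, this estimate degenerates for long series; the Rao-Blackwellized filter below trades that for an O(n_particles²) cost per step.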
- Parameters
model –
Object specifying the state-space model having the following methods:
pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.
pf_step : (key, x_prev, y_curr, theta) -> (x_particles, logw): For sampling and calculating log-weights for each subsequent latent variable.
pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.
state_lpdf : (x_curr, x_prev, theta) -> lpdf: Optional method specifying the log-density of the state model. Only required if score=True or fisher=True.
meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Optional method specifying the log-density of the measurement model. Only required if score=True or fisher=True.
key – PRNG key.
y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.
theta – Parameter value.
n_particles – Number of particles.
resampler – Function used at step t to obtain a sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with the mandatory element x_particles and optional elements that get carried to the next step t+1 via lax.scan().
score – Whether or not to return an estimate of the score function at theta. Only works if resampler has an output element named ancestors.
fisher – Whether or not to return an estimate of the Fisher information at theta. Only works if resampler has an output element named ancestors. If True, the score is returned as well.
history – Whether to output the history of the filter or only the last step.
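A minimal sketch of a model object and a resampler that satisfy the interfaces described above. Everything here is hypothetical and for illustration: a Gaussian random walk observed with Gaussian noise, with theta holding the log standard deviations. It assumes each model method acts on a single particle (no n_particles argument appears in the documented signatures, so the filter presumably vmaps over particles), and that the resampler takes its arguments in the documented order (x_particles, logw, key).

```python
import jax
import jax.numpy as jnp


def norm_lpdf(x, loc, log_scale):
    # Log-density of N(loc, exp(log_scale)^2), written out to stay self-contained.
    return (-0.5 * ((x - loc) / jnp.exp(log_scale)) ** 2
            - log_scale - 0.5 * jnp.log(2 * jnp.pi))


class RandomWalkModel:
    """Hypothetical model, for illustration only:
        x_0 ~ N(0, 1),  x_t ~ N(x_{t-1}, exp(theta[0])^2),  y_t ~ N(x_t, exp(theta[1])^2).
    Each method acts on a single particle.
    """

    def state_lpdf(self, x_curr, x_prev, theta):
        return norm_lpdf(x_curr, x_prev, theta[0])

    def meas_lpdf(self, y_curr, x_curr, theta):
        return norm_lpdf(y_curr, x_curr, theta[1])

    def pf_init(self, key, y_init, theta):
        # Draw x_0 from its (theta-free) prior and weight by the measurement density.
        x_init = jax.random.normal(key)
        return x_init, self.meas_lpdf(y_init, x_init, theta)

    def pf_step(self, key, x_prev, y_curr, theta):
        # Bootstrap proposal: sample from the state model, so the log-weight
        # reduces to the measurement log-density.
        x_curr = x_prev + jnp.exp(theta[0]) * jax.random.normal(key)
        return x_curr, self.meas_lpdf(y_curr, x_curr, theta)


def resample_multinomial_sketch(x_particles, logw, key):
    """Hypothetical multinomial resampler using the documented argument order."""
    n_particles = logw.shape[0]
    # categorical() accepts unnormalized log-probabilities as logits.
    ancestors = jax.random.categorical(key, logw, shape=(n_particles,))
    return {"x_particles": jnp.take(x_particles, ancestors, axis=0),
            "ancestors": ancestors}  # the ancestors element is what score/fisher rely on
```

Since pf_step here samples from the state model, no pf_aux is defined, so this model would run as a plain bootstrap filter.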
- Returns
x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or, if history=True, at all time points (leading dimensions (n_obs, n_particles)).
logw - JAX array containing the unnormalized log-weights at the last time point (dimension n_particles) or, if history=True, at all time points (dimensions (n_obs, n_particles)).
loglik - The particle filter estimate of the log-likelihood, evaluated at theta.
score - Optional 1D JAX array of size n_theta = len(theta) containing the estimated score at theta.
fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.
resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.
- Return type
Dictionary
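Under the Algorithm 1 bookkeeping, each particle i carries alpha_i, the accumulated gradient of log p(x_{0:T}, y_{0:T} | theta) along its ancestral path, and, for the Fisher information, beta_i, the accumulated Hessian. A common way to turn these into score and fisher is Louis' identity: score ≈ Σ_i W_i alpha_i and fisher ≈ -(Σ_i W_i (alpha_i alpha_i^T + beta_i) - score score^T), where W_i are the normalized weights. The helper below is a hypothetical sketch of that final step, not pfjax's code.

```python
import jax
import jax.numpy as jnp


def score_and_fisher_from_particles(logw, alpha, beta):
    """Hypothetical helper combining per-particle path statistics.

    logw:  unnormalized log-weights, shape (n_particles,)
    alpha: accumulated gradients, shape (n_particles, n_theta)
    beta:  accumulated Hessians, shape (n_particles, n_theta, n_theta)
    """
    # Normalize the log-weights in a numerically stable way.
    W = jax.nn.softmax(logw)
    # Score: weighted mean of the path gradients.
    score = jnp.einsum("i,ij->j", W, alpha)
    # Weighted mean of alpha alpha^T + beta.
    second = (jnp.einsum("i,ij,ik->jk", W, alpha, alpha)
              + jnp.einsum("i,ijk->jk", W, beta))
    # Observed information = -(Hessian of the log-likelihood) via Louis' identity.
    fisher = -(second - jnp.outer(score, score))
    return score, fisher
```

For example, two equally weighted particles with alpha = 1 and 3 and beta = -2 give score 2 and fisher 1: the negative mean Hessian (2) minus the spread of the path gradients (1).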
- pfjax.particle_filter.particle_filter_rb(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)
Rao-Blackwellized particle filter.
Notes
Implements Algorithm 2 of Poyiadjis et al. (2011).
Can optionally use an auxiliary particle filter if model.pf_aux() is provided.
- Parameters
model –
Object specifying the state-space model having the following methods:
pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.
step_sample : (key, x_prev, y_curr, theta) -> x_curr: Samples from the proposal distribution for each subsequent latent variable.
step_lpdf : (x_curr, x_prev, y_curr, theta) -> logw: Calculates log-weights for each subsequent latent variable.
state_lpdf : (x_curr, x_prev, theta) -> lpdf: Calculates the log-density of the state model.
meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Calculates the log-density of the measurement model.
pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.
key – PRNG key.
y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.
theta – Parameter value.
n_particles – Number of particles.
resampler – Function used at step t to obtain a sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with the mandatory element x_particles and optional elements that get carried to the next step t+1 via lax.scan().
score – Whether or not to return an estimate of the score function at theta.
fisher – Whether or not to return an estimate of the Fisher information at theta. If True, the score is returned as well.
history – Whether to output the history of the filter or only the last step.
- Returns
x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or, if history=True, at all time points (leading dimensions (n_obs, n_particles)).
logw_bar - JAX array containing the unnormalized log-weights at the last time point (dimension n_particles) or, if history=True, at all time points (dimensions (n_obs, n_particles)).
loglik - The particle filter estimate of the log-likelihood, evaluated at theta.
score - Optional 1D JAX array of size n_theta = len(theta) containing the estimated score at theta.
fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.
resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.
- Return type
Dictionary