pfjax
Package Contents
Classes
BaseModel | Base model for particle filters.
Functions
simulate | Simulate data from the state-space model.
loglik_full | Calculate the complete data loglikelihood for a state-space model.
particle_filter | Basic particle filter.
particle_filter_rb | Rao-Blackwellized particle filter.
Attributes
__version__
__author__
- pfjax.simulate(model, key, n_obs, x_init, theta)
Simulate data from the state-space model.
- Parameters
model –
Object specifying the state-space model having the following methods.
state_sim : (key, x_prev, theta) -> x_curr: Sample from the state model.
meas_sim : (key, x_curr, theta) -> y_curr: Sample from the measurement model.
key – PRNG key.
n_obs – Number of observations to generate.
x_init – Initial state value at time t = 0.
theta – Parameter value.
- Returns
y_meas - The sequence of measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.
x_state - The sequence of state variables x_state = (x_0, …, x_T), where T = n_obs-1.
- Return type
Tuple
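As an illustration of this interface, the following is a minimal sketch of a simulation loop built on jax.lax.scan; it is not the pfjax source. The Gaussian random walk RandomWalkModel (theta = (sigma, tau)) and the helper simulate_sketch are hypothetical; only the method names state_sim and meas_sim and their argument order come from the description above. Following the Returns description, x_0 = x_init and y_0 is drawn from the measurement model at x_0.

```python
import jax
import jax.numpy as jnp


class RandomWalkModel:
    """Hypothetical model: x_t ~ N(x_{t-1}, sigma^2), y_t ~ N(x_t, tau^2), theta = (sigma, tau)."""

    def state_sim(self, key, x_prev, theta):
        return x_prev + theta[0] * jax.random.normal(key)

    def meas_sim(self, key, x_curr, theta):
        return x_curr + theta[1] * jax.random.normal(key)


def simulate_sketch(model, key, n_obs, x_init, theta):
    x_init = jnp.asarray(x_init, dtype=jnp.float32)
    # y_0 is drawn from the measurement model at x_0 = x_init.
    key, subkey = jax.random.split(key)
    y_init = model.meas_sim(subkey, x_init, theta)

    def step(x_prev, key_t):
        # Draw x_t given x_{t-1}, then y_t given x_t.
        key_state, key_meas = jax.random.split(key_t)
        x_curr = model.state_sim(key_state, x_prev, theta)
        y_curr = model.meas_sim(key_meas, x_curr, theta)
        return x_curr, (y_curr, x_curr)

    keys = jax.random.split(key, n_obs - 1)
    _, (y_rest, x_rest) = jax.lax.scan(step, x_init, keys)
    # Prepend the t = 0 values so both outputs have leading dimension n_obs.
    y_meas = jnp.concatenate([y_init[None], y_rest])
    x_state = jnp.concatenate([x_init[None], x_rest])
    return y_meas, x_state


# Usage: five observations from the hypothetical model, starting at x_0 = 0.
theta = jnp.array([0.5, 0.1])
y_meas, x_state = simulate_sketch(RandomWalkModel(), jax.random.PRNGKey(0), 5, 0.0, theta)
```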
- pfjax.loglik_full(model, y_meas, x_state, theta)
Calculate the complete data loglikelihood for a state-space model.
Calculates log p(y_{0:T} | x_{0:T}, theta) + log p(x_{0:T} | theta).
- Parameters
model –
Object specifying the state-space model having the following methods:
state_lpdf : (x_curr, x_prev, theta) -> lpdf: Calculates the log-density of the state model.
meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Calculates the log-density of the measurement model.
y_meas – The sequence of n_obs measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.
x_state – The sequence of n_obs state variables x_state = (x_0, …, x_T).
theta – Parameter value.
- Returns
The value of the complete data loglikelihood.
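A minimal sketch of the complete data loglikelihood, assuming the same hypothetical Gaussian random walk (theta = (sigma, tau)) and a hypothetical helper loglik_full_sketch. It sums the measurement terms over t = 0 to T and the state transition terms over t = 1 to T. Since only state_lpdf and meas_lpdf are required, the sketch includes no separate prior term for x_0; one would have to be added if p(x_{0:T} | theta) is meant to include it.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class RandomWalkModel:
    """Hypothetical Gaussian random walk; theta = (sigma, tau)."""

    def state_lpdf(self, x_curr, x_prev, theta):
        return norm.logpdf(x_curr, loc=x_prev, scale=theta[0])

    def meas_lpdf(self, y_curr, x_curr, theta):
        return norm.logpdf(y_curr, loc=x_curr, scale=theta[1])


def loglik_full_sketch(model, y_meas, x_state, theta):
    # Measurement terms: sum over t = 0..T of log p(y_t | x_t, theta).
    ll_meas = jnp.sum(
        jax.vmap(model.meas_lpdf, in_axes=(0, 0, None))(y_meas, x_state, theta))
    # State terms: sum over t = 1..T of log p(x_t | x_{t-1}, theta).
    ll_state = jnp.sum(
        jax.vmap(model.state_lpdf, in_axes=(0, 0, None))(x_state[1:], x_state[:-1], theta))
    return ll_meas + ll_state


# Usage: with all values zero and unit scales, each of the 3 + 2 terms is log N(0 | 0, 1).
theta = jnp.array([1.0, 1.0])
ll = loglik_full_sketch(RandomWalkModel(), jnp.zeros(3), jnp.zeros(3), theta)
```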
- pfjax.particle_filter(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)
Basic particle filter.
Notes:
Can optionally use an auxiliary particle filter if model.pf_aux() is provided.
The score and Fisher information are estimated using the method described by Poyiadjis et al. (2011), Algorithm 1.
A possible extension is to allow new measurements to be added as they arrive. For example, an argument init=None could, if not None, be the carry from lax.scan(); the carry should then also be returned as an output…
- Parameters
model –
Object specifying the state-space model having the following methods:
pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.
pf_step : (key, x_prev, y_curr, theta) -> (x_particles, logw): For sampling and calculating log-weights for each subsequent latent variable.
pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.
state_lpdf : (x_curr, x_prev, theta) -> lpdf: Optional method specifying the log-density of the state model. Only required if score=True or fisher=True.
meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Optional method specifying the log-density of the measurement model. Only required if score=True or fisher=True.
key – PRNG key.
y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.
theta – Parameter value.
n_particles – Number of particles.
resampler – Function used at step t to obtain sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with mandatory element x_particles and optional elements that get carried to the next step t+1 via lax.scan().
score – Whether or not to return an estimate of the score function at theta. Only works if resampler has an output element named ancestors.
fisher – Whether or not to return an estimate of the Fisher information at theta. Only works if resampler has an output element named ancestors. If True returns score as well.
history – Whether to output the history of the filter or only the last step.
- Returns
x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or at all time points (leading dimensions (n_obs, n_particles)) if history=True.
logw - JAX array containing unnormalized log weights at the last time point (dimensions n_particles) or at all time points (dimensions (n_obs, n_particles)) if history=True.
loglik - The particle filter loglikelihood evaluated at theta.
score - Optional 1D JAX array of size n_theta = length(theta) containing the estimated score at theta.
fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.
resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.
- Return type
Dictionary
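To make the resampler contract and the loglikelihood estimate concrete, here is a minimal sketch of a bootstrap particle filter with multinomial resampling at every step; it is not the pfjax implementation. The names resample_multinomial_sketch, BootstrapRandomWalk and particle_filter_sketch are hypothetical, the model assumes a N(0, 1) prior on x_0, and the score, fisher, history and pf_aux options are omitted. The resampler follows the documented signature resampler(x_particles, logw, key) and returns a dictionary with x_particles and ancestors. The loglikelihood accumulates the log of the average unnormalized weight at each time point.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def resample_multinomial_sketch(x_particles, logw, key):
    # Draw n_particles ancestor indices with probabilities proportional to exp(logw).
    n_particles = logw.shape[0]
    ancestors = jax.random.categorical(key, logits=logw, shape=(n_particles,))
    return {"x_particles": x_particles[ancestors], "ancestors": ancestors}


class BootstrapRandomWalk:
    """Hypothetical model: x_0 ~ N(0, 1), x_t ~ N(x_{t-1}, sigma^2), y_t ~ N(x_t, tau^2)."""

    def pf_init(self, key, y_init, theta):
        # Propose from the prior, so the log-weight is log p(y_0 | x_0).
        x_init = jax.random.normal(key)
        return x_init, norm.logpdf(y_init, loc=x_init, scale=theta[1])

    def pf_step(self, key, x_prev, y_curr, theta):
        # Propose from the state model, so the log-weight is log p(y_t | x_t).
        x_curr = x_prev + theta[0] * jax.random.normal(key)
        return x_curr, norm.logpdf(y_curr, loc=x_curr, scale=theta[1])


def particle_filter_sketch(model, key, y_meas, theta, n_particles,
                           resampler=resample_multinomial_sketch):
    n_obs = y_meas.shape[0]
    # Initial particles and log-weights at t = 0.
    key, subkey = jax.random.split(key)
    x_particles, logw = jax.vmap(model.pf_init, in_axes=(0, None, None))(
        jax.random.split(subkey, n_particles), y_meas[0], theta)
    # log p(y_0 | theta) is estimated by the log of the average unnormalized weight.
    loglik = logsumexp(logw) - jnp.log(n_particles)

    def step(carry, inputs):
        x_prev, logw_prev, loglik = carry
        y_curr, key_t = inputs
        key_resample, key_step = jax.random.split(key_t)
        # Resample, then propagate each resampled particle through pf_step.
        x_resampled = resampler(x_prev, logw_prev, key_resample)["x_particles"]
        x_curr, logw_curr = jax.vmap(model.pf_step, in_axes=(0, 0, None, None))(
            jax.random.split(key_step, n_particles), x_resampled, y_curr, theta)
        # After resampling the weights are equal, so each increment is a log-mean.
        loglik = loglik + logsumexp(logw_curr) - jnp.log(n_particles)
        return (x_curr, logw_curr, loglik), None

    keys = jax.random.split(key, n_obs - 1)
    (x_particles, logw, loglik), _ = jax.lax.scan(
        step, (x_particles, logw, loglik), (y_meas[1:], keys))
    return {"x_particles": x_particles, "logw": logw, "loglik": loglik}


# Usage: filter four observations with 100 particles.
theta = jnp.array([0.5, 1.0])
y_meas = jnp.array([0.1, -0.2, 0.3, 0.0])
out = particle_filter_sketch(BootstrapRandomWalk(), jax.random.PRNGKey(1), y_meas, theta, 100)
```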
- pfjax.particle_filter_rb(model, key, y_meas, theta, n_particles, resampler=resample_multinomial, score=False, fisher=False, history=False)
Rao-Blackwellized particle filter.
Notes:
Implements Algorithm 2 of Poyiadjis et al. (2011).
Can optionally use an auxiliary particle filter if model.pf_aux() is provided.
- Parameters
model –
Object specifying the state-space model having the following methods:
pf_init : (key, y_init, theta) -> (x_particles, logw): For sampling and calculating log-weights for the initial latent variable.
step_sample : (key, x_prev, y_curr, theta) -> x_curr: Sampling from the proposal distribution for each subsequent latent variable.
step_lpdf : (x_curr, x_prev, y_curr, theta) -> logw: Calculate log-weights for each subsequent latent variable.
state_lpdf : (x_curr, x_prev, theta) -> lpdf: Calculates the log-density of the state model.
meas_lpdf : (y_curr, x_curr, theta) -> lpdf: Calculates the log-density of the measurement model.
pf_aux : (x_prev, y_curr, theta) -> logw: Optional method providing look-forward log-weights of the auxiliary particle filter.
key – PRNG key.
y_meas – JAX array with leading dimension n_obs containing the measurement variables y_meas = (y_0, …, y_T), where T = n_obs-1.
theta – Parameter value.
n_particles – Number of particles.
resampler – Function used at step t to obtain sample of particles from p(x_{t} | y_{0:t}, theta) out of a sample of particles from p(x_{t-1} | y_{0:t-1}, theta). The argument signature is resampler(x_particles, logw, key), and the return value is a dictionary with mandatory element x_particles and optional elements that get carried to the next step t+1 via lax.scan().
score – Whether or not to return an estimate of the score function at theta.
fisher – Whether or not to return an estimate of the Fisher information at theta. If True returns score as well.
history – Whether to output the history of the filter or only the last step.
- Returns
x_particles - JAX array containing the state variable particles at the last time point (leading dimension n_particles) or at all time points (leading dimensions (n_obs, n_particles)) if history=True.
logw_bar - JAX array containing unnormalized log weights at the last time point (dimensions n_particles) or at all time points (dimensions (n_obs, n_particles)) if history=True.
loglik - The particle filter loglikelihood evaluated at theta.
score - Optional 1D JAX array of size n_theta = length(theta) containing the estimated score at theta.
fisher - Optional JAX array of shape (n_theta, n_theta) containing the estimated observed Fisher information at theta.
resample_out - If history=True, a dictionary of additional outputs from the resampler function. Each element of the dictionary has leading dimension n_obs-1, since these additional outputs do not apply to the first time point t=0.
- Return type
Dictionary
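Assuming particle_filter_rb follows the marginal weighting of Poyiadjis et al. (2011), Algorithm 2, the key difference from the basic filter is that each new particle is weighted against a mixture over all previous particles rather than only its own parent. With logw the previous log-weights:
logw_bar[n] = meas_lpdf(y_t, x_t[n]) + logsumexp_m(logw[m] + state_lpdf(x_t[n], x_{t-1}[m])) - logsumexp_m(logw[m] + step_lpdf(x_t[n], x_{t-1}[m], y_t)).
The sketch below computes only this O(n_particles^2) weight update for a hypothetical Gaussian random walk with a bootstrap proposal; rb_logw_bar is a hypothetical name and pf_aux is not handled.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


# Hypothetical Gaussian random walk components; theta = (sigma, tau).
def state_lpdf(x_curr, x_prev, theta):
    return norm.logpdf(x_curr, loc=x_prev, scale=theta[0])


def meas_lpdf(y_curr, x_curr, theta):
    return norm.logpdf(y_curr, loc=x_curr, scale=theta[1])


def step_lpdf(x_curr, x_prev, y_curr, theta):
    # Bootstrap proposal: q(x_curr | x_prev, y_curr) = p(x_curr | x_prev).
    return state_lpdf(x_curr, x_prev, theta)


def rb_logw_bar(x_curr, x_prev, logw_prev, y_curr, theta):
    """Marginal log-weight for each new particle x_curr[n]."""
    # Pairwise matrices: row n indexes new particles, column m indexes old particles.
    lp_state = jax.vmap(
        lambda xc: jax.vmap(lambda xp: state_lpdf(xc, xp, theta))(x_prev))(x_curr)
    lp_prop = jax.vmap(
        lambda xc: jax.vmap(lambda xp: step_lpdf(xc, xp, y_curr, theta))(x_prev))(x_curr)
    # Mixtures over old particles, weighted by their previous log-weights.
    log_target = meas_lpdf(y_curr, x_curr, theta) + logsumexp(logw_prev + lp_state, axis=1)
    log_prop = logsumexp(logw_prev + lp_prop, axis=1)
    return log_target - log_prop


# Usage: three old particles, four new particles, observation y_t = 0.2.
theta = jnp.array([0.7, 0.3])
x_prev = jnp.array([-1.0, 0.0, 2.0])
logw_prev = jnp.array([-0.5, -1.0, -2.0])
x_curr = jnp.array([0.5, -0.3, 1.1, 0.0])
lw = rb_logw_bar(x_curr, x_prev, logw_prev, 0.2, theta)
```

With the bootstrap proposal used here, the two mixture sums cancel and logw_bar reduces to meas_lpdf(y_t, x_t[n]); the quadratic cost only changes the weights when step_lpdf differs from state_lpdf.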
- pfjax.particle_smooth(key, logw, x_particles, ancestors)
Draw a sample from p(x_state | y_meas, theta) from a particle filter with multinomial resampling.
- Parameters
key – PRNG key.
logw – Vector of n_particles unnormalized log-weights at the last time point t = n_obs-1.
x_particles – JAX array with leading dimensions (n_obs, n_particles) containing the state variable particles.
ancestors – JAX integer array of shape (n_obs-1, n_particles) where each element gives the index of the particle’s ancestor at the previous time point.
- Returns
An array with leading dimension n_obs sampled from p(x_{0:T} | y_{0:T}, theta).
- Return type
JAX Array
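A minimal sketch of drawing one trajectory by tracing a single ancestral lineage backward, assuming this is how the ancestors array is used (the pfjax implementation may differ). It picks the final particle with probability proportional to exp(logw) and then follows ancestors[t] from time t+1 back to time t; particle_smooth_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def particle_smooth_sketch(key, logw, x_particles, ancestors):
    n_obs = x_particles.shape[0]
    # Choose the index at the last time point with probability proportional to exp(logw).
    i_last = jax.random.categorical(key, logits=logw)

    def backtrack(i_next, t):
        # ancestors[t, i] is the index at time t of the ancestor of particle i at time t+1.
        i_curr = ancestors[t, i_next]
        return i_curr, i_curr

    # reverse=True visits t = n_obs-2, ..., 0 but stores the outputs in time order.
    _, i_path = jax.lax.scan(backtrack, i_last, jnp.arange(n_obs - 1), reverse=True)
    i_path = jnp.concatenate([i_path, i_last[None]])
    # Select one particle per time point, giving leading dimension n_obs.
    return x_particles[jnp.arange(n_obs), i_path]


# Usage: a deterministic case where all weight is on particle 0 at the last time point.
x_particles = jnp.arange(9.0).reshape(3, 3)  # row t holds the particles at time t
ancestors = jnp.array([[2, 0, 1], [1, 2, 0]])
logw = jnp.array([0.0, -jnp.inf, -jnp.inf])
path = particle_smooth_sketch(jax.random.PRNGKey(0), logw, x_particles, ancestors)
```

Here the lineage is particle 0 at t = 2, its ancestor 1 at t = 1, and that particle's ancestor 0 at t = 0.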
- class pfjax.BaseModel(bootstrap)
Bases: object
Base model for particle filters.
This class sets up a PF model from a small set of methods:
The derived class should provide methods state_lpdf(), meas_lpdf(), state_sample() and meas_sample() in order to calculate the complete data loglikelihood with pfjax.loglik_full() and to simulate data from the model via pfjax.simulate().
To use the “basic” particle filter pfjax.particle_filter(), the derived class must provide methods pf_init() and pf_step().
To use the Rao-Blackwellized particle filter pfjax.particle_filter_rb(), the derived class must provide methods pf_init(), step_sample(), and step_lpdf().
If pf_init() is missing, the base class will automatically construct it from prior_lpdf(), init_sample(), and init_lpdf().
If pf_step() is missing, the base class will automatically construct it from state_lpdf(), meas_lpdf(), step_sample(), and step_lpdf().
In either of the above cases, if step_sample() and step_lpdf() are missing, the base class assumes a bootstrap particle filter and sets step_sample = state_sample and step_lpdf = state_lpdf.
Likewise, if init_sample() and init_lpdf() are missing, the base class sets init_sample = prior_sample and init_lpdf = prior_lpdf.
Notes:
pf_step() and pf_init() could be computed more efficiently for bootstrap sampling. Perhaps this could be specified with an argument bootstrap to the constructor.
In general, nothing prevents the user from defining, e.g., a pf_step() that is inconsistent with step_sample().
- Parameters
bootstrap – Boolean for whether or not to create a bootstrap particle filter.
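The fallback rules above can be sketched as follows. SketchBaseModel is a hypothetical stand-in written for illustration, not pfjax.BaseModel, and the signatures prior_sample(key, theta) and prior_lpdf(x_init, theta) are assumptions since they are not documented here. The derived class RandomWalk (also hypothetical: x_0 ~ N(0, 1), x_t ~ N(x_{t-1}, sigma^2), y_t ~ N(x_t, tau^2)) supplies only the model components and inherits the bootstrap proposal methods.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class SketchBaseModel:
    """Hypothetical stand-in for the fallback rules (not pfjax.BaseModel)."""

    # Bootstrap defaults: step_sample = state_sample, step_lpdf = state_lpdf.
    def step_sample(self, key, x_prev, y_curr, theta):
        return self.state_sample(key, x_prev, theta)

    def step_lpdf(self, x_curr, x_prev, y_curr, theta):
        return self.state_lpdf(x_curr, x_prev, theta)

    # Initial defaults: init_sample = prior_sample, init_lpdf = prior_lpdf.
    def init_sample(self, key, y_init, theta):
        return self.prior_sample(key, theta)

    def init_lpdf(self, x_init, y_init, theta):
        return self.prior_lpdf(x_init, theta)


class RandomWalk(SketchBaseModel):
    """Hypothetical derived model; theta = (sigma, tau)."""

    def prior_sample(self, key, theta):
        return jax.random.normal(key)

    def prior_lpdf(self, x_init, theta):
        return norm.logpdf(x_init)

    def state_sample(self, key, x_prev, theta):
        return x_prev + theta[0] * jax.random.normal(key)

    def state_lpdf(self, x_curr, x_prev, theta):
        return norm.logpdf(x_curr, loc=x_prev, scale=theta[0])

    def meas_sample(self, key, x_curr, theta):
        return x_curr + theta[1] * jax.random.normal(key)

    def meas_lpdf(self, y_curr, x_curr, theta):
        return norm.logpdf(y_curr, loc=x_curr, scale=theta[1])


# Usage: the proposal methods ignore y and defer to the state model and prior.
m = RandomWalk()
theta = jnp.array([0.5, 0.2])
key = jax.random.PRNGKey(3)
x_next = m.step_sample(key, 1.0, 0.0, theta)
```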
- step_sample(key, x_prev, y_curr, theta)
Sample from the default proposal distribution
q(x_curr | x_prev, y_curr, theta) = p(x_curr | x_prev, theta)
- Parameters
key – PRNG key.
x_prev – State variable at previous time t-1.
y_curr – Measurement variable at current time t.
theta – Parameter value.
- Returns
Sample of the state variable x_curr at current time t.
- step_lpdf(x_curr, x_prev, y_curr, theta)
Calculate the log-density of the default proposal distribution
q(x_curr | x_prev, y_curr, theta) = p(x_curr | x_prev, theta)
- Parameters
x_curr – State variable at current time t.
x_prev – State variable at previous time t-1.
y_curr – Measurement variable at current time t.
theta – Parameter value.
- Returns
Log-density of the state variable x_curr at current time t.
- init_sample(key, y_init, theta)
Sample from the default initial proposal distribution
q(x_init | y_init, theta) = p(x_init | theta)
- Parameters
key – PRNG key.
y_init – Measurement variable at initial time t = 0.
theta – Parameter value.
- Returns
Sample of the state variable x_init at initial time t = 0.
- init_lpdf(x_init, y_init, theta)
Calculate the log-density of the default initial proposal distribution
q(x_init | y_init, theta) = p(x_init | theta)
- Parameters
x_init – State variable at initial time t = 0.
y_init – Measurement variable at initial time t = 0.
theta – Parameter value.
- Returns
Log-density of the state variable x_init at initial time t = 0.
- pf_step(key, x_prev, y_curr, theta)
Particle filter update.
Returns a draw from the proposal distribution
x_curr ~ q(x_curr) = q(x_curr | x_prev, y_curr, theta)
and the log weight
logw = log p(y_curr | x_curr, theta) + log p(x_curr | x_prev, theta) - log q(x_curr)
- Parameters
key – PRNG key.
x_prev – State variable at previous time t-1.
y_curr – Measurement variable at current time t.
theta – Parameter value.
- Returns
x_curr - A sample from the proposal distribution at current time t.
logw - The log-weight of x_curr.
- Return type
Tuple
- pf_init(key, y_init, theta)
Initial step of particle filter.
Returns a draw from the proposal distribution
x_init ~ q(x_init) = q(x_init | y_init, theta)
and calculates the log weight
logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)
- Parameters
key – PRNG key.
y_init – Measurement variable at initial time t = 0.
theta – Parameter value.
- Returns
x_init - A sample from the proposal distribution at initial time t = 0.
logw - The log-weight of x_init.
- Return type
Tuple
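The log-weight formulas for pf_step() and pf_init() can be checked numerically with the minimal sketch below. The Gaussian components, the proposals and the helpers pf_step_sketch and pf_init_sketch are hypothetical, and the proposal functions are passed explicitly instead of being looked up on a model object. With the bootstrap proposal, the state (or prior) term cancels the proposal term, so logw equals the measurement log-density; the guided proposal N(y_curr, 1) shows the general case.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


# Hypothetical Gaussian components; theta = (sigma, tau), prior x_0 ~ N(0, 1).
def prior_lpdf(x_init, theta):
    return norm.logpdf(x_init)


def state_lpdf(x_curr, x_prev, theta):
    return norm.logpdf(x_curr, loc=x_prev, scale=theta[0])


def meas_lpdf(y_curr, x_curr, theta):
    return norm.logpdf(y_curr, loc=x_curr, scale=theta[1])


def pf_init_sketch(key, y_init, theta, init_sample, init_lpdf):
    # logw = log p(y_0 | x_0) + log p(x_0) - log q(x_0 | y_0).
    x_init = init_sample(key, y_init, theta)
    logw = (meas_lpdf(y_init, x_init, theta) + prior_lpdf(x_init, theta)
            - init_lpdf(x_init, y_init, theta))
    return x_init, logw


def pf_step_sketch(key, x_prev, y_curr, theta, step_sample, step_lpdf):
    # logw = log p(y_t | x_t) + log p(x_t | x_{t-1}) - log q(x_t | x_{t-1}, y_t).
    x_curr = step_sample(key, x_prev, y_curr, theta)
    logw = (meas_lpdf(y_curr, x_curr, theta) + state_lpdf(x_curr, x_prev, theta)
            - step_lpdf(x_curr, x_prev, y_curr, theta))
    return x_curr, logw


# Bootstrap proposals: sample from the state model / prior.
def boot_sample(key, x_prev, y_curr, theta):
    return x_prev + theta[0] * jax.random.normal(key)


def boot_lpdf(x_curr, x_prev, y_curr, theta):
    return state_lpdf(x_curr, x_prev, theta)


def prior_init_sample(key, y_init, theta):
    return jax.random.normal(key)


def prior_init_lpdf(x_init, y_init, theta):
    return prior_lpdf(x_init, theta)


# Hypothetical guided proposal N(y_curr, 1), which ignores x_prev.
def guided_sample(key, x_prev, y_curr, theta):
    return y_curr + jax.random.normal(key)


def guided_lpdf(x_curr, x_prev, y_curr, theta):
    return norm.logpdf(x_curr, loc=y_curr)


# Usage: one bootstrap step, one guided step and one initial step.
theta = jnp.array([0.5, 0.2])
key = jax.random.PRNGKey(0)
x_b, lw_b = pf_step_sketch(key, 1.0, 0.8, theta, boot_sample, boot_lpdf)
x_g, lw_g = pf_step_sketch(key, 1.0, 0.8, theta, guided_sample, guided_lpdf)
x_0, lw_0 = pf_init_sketch(key, 0.1, theta, prior_init_sample, prior_init_lpdf)
```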
- pfjax.__version__ = 0.0.2
- pfjax.__author__ = Martin Lysy, Pranav Subramani, Jonathan Ramkissoon, Mohan Wu, Michelle Ko, Kanika Choptra