Source code for pfjax.particle_filter

"""
Particle filters which approximate the score and fisher information.
"""

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from jax import random
from jax import lax
# from jax.experimental.host_callback import id_print
from .utils import *
from .particle_resamplers import resample_multinomial


def particle_filter(model, key, y_meas, theta, n_particles,
                    resampler=resample_multinomial,
                    score=False, fisher=False, history=False):
    r"""
    Basic particle filter.

    **Notes:**

    - Can optionally use an auxiliary particle filter if `model.pf_aux()` is provided.

    - The score and fisher information are estimated using the method described by Poyiadjis et al 2011, Algorithm 1.

    - Should have the option of adding new measurements as we go along.  So for example, could have an argument `init=None`, which if not `None` is the carry from `lax.scan()`.  Should then also return the carry as an output...

    Args:
        model: Object specifying the state-space model having the following methods:

            - `pf_init : (key, y_init, theta) -> (x_particles, logw)`: For sampling and calculating log-weights for the initial latent variable.
            - `pf_step : (key, x_prev, y_curr, theta) -> (x_particles, logw)`: For sampling and calculating log-weights for each subsequent latent variable.
            - `pf_aux : (x_prev, y_curr, theta) -> logw`: Optional method providing look-forward log-weights of the auxiliary particle filter.
            - `state_lpdf : (x_curr, x_prev, theta) -> lpdf`: Optional method specifying the log-density of the state model.  Only required if `score or fisher == True`.
            - `meas_lpdf : (y_curr, x_curr, theta) -> lpdf`: Optional method specifying the log-density of the measurement model.  Only required if `score or fisher == True`.

        key: PRNG key.
        y_meas: JAX array with leading dimension `n_obs` containing the measurement variables `y_meas = (y_0, ..., y_T)`, where `T = n_obs-1`.
        theta: Parameter value.  Must expose a `size` attribute when `score` or `fisher` is requested, since it is used to shape the derivative accumulators.
        n_particles: Number of particles.
        resampler: Function used at step `t` to obtain sample of particles from `p(x_{t} | y_{0:t}, theta)` out of a sample of particles from `p(x_{t-1} | y_{0:t-1}, theta)`.  It is called with keyword arguments as `resampler(key=key, x_particles_prev=x_particles, logw=logw)`, and the return value is a dictionary with mandatory element `x_particles` and optional elements that get stacked by `lax.scan()` when `history=True`.
        score: Whether or not to return an estimate of the score function at `theta`.  Only works if `resampler` has an output element named `ancestors`.
        fisher: Whether or not to return an estimate of the Fisher information at `theta`.  Only works if `resampler` has an output element named `ancestors`.  If `True` returns score as well.
        history: Whether to output the history of the filter or only the last step.

    Returns:
        Dictionary:

        - **x_particles** - JAX array containing the state variable particles at the last time point (leading dimension `n_particles`) or at all time points (leading dimensions `(n_obs, n_particles)`) if `history=True`.
        - **logw** - JAX array containing unnormalized log weights at the last time point (dimensions `n_particles`) or at all time points (dimensions `(n_obs, n_particles)`) if `history=True`.
        - **loglik** - The particle filter loglikelihood evaluated at `theta`.
        - **score** - Optional 1D JAX array of size `n_theta = length(theta)` containing the estimated score at `theta`.
        - **fisher** - Optional JAX array of shape `(n_theta, n_theta)` containing the estimated observed fisher information at `theta`.
        - **resample_out** - If `history=True` and `resampler` returns elements other than `x_particles`, a dictionary of those additional outputs.  Each element has leading dimension `n_obs-1`, since these additional outputs do not apply to the first time point `t=0`.

        When `history=False` the returned dictionary is the final `lax.scan()` carry, so it additionally contains the PRNG `key` and, when derivatives are requested, the per-particle accumulators `alpha` (and `beta` if `fisher=True`).
    """
    n_obs = y_meas.shape[0]
    # derivative accumulators are needed for either the score or the fisher
    has_acc = score or fisher

    def accumulator(x_prev, x_curr, y_curr, acc_prev):
        """
        Accumulator for derivative calculations (single particle).

        Args:
            x_prev: Resampled particle at the previous time point.
            x_curr: Particle at the current time point.
            y_curr: Measurement at the current time point.
            acc_prev: Dictionary with elements alpha and optionally beta.

        Returns:
            Dictionary with elements:
            - alpha: current value of the accumulated gradient.
            - beta: current value of the accumulated hessian
              (only if `fisher=True`).
        """
        grad_meas = jax.grad(model.meas_lpdf, argnums=2)
        grad_state = jax.grad(model.state_lpdf, argnums=2)
        alpha = (
            grad_meas(y_curr, x_curr, theta)
            + grad_state(x_curr, x_prev, theta)
            + acc_prev["alpha"]
        )
        acc_curr = {"alpha": alpha}
        if fisher:
            hess_meas = jax.jacfwd(
                jax.jacrev(model.meas_lpdf, argnums=2), argnums=2
            )
            hess_state = jax.jacfwd(
                jax.jacrev(model.state_lpdf, argnums=2), argnums=2
            )
            beta = (
                hess_meas(y_curr, x_curr, theta)
                + hess_state(x_curr, x_prev, theta)
                + acc_prev["beta"]
            )
            acc_curr["beta"] = beta
        return acc_curr

    # internal functions for vmap (close over theta and y_meas)
    def pf_step(key, x_prev, y_curr):
        return model.pf_step(key=key, x_prev=x_prev, y_curr=y_curr,
                             theta=theta)

    def pf_init(key):
        return model.pf_init(key=key, y_init=y_meas[0], theta=theta)

    def pf_aux(logw_prev, x_prev, y_curr):
        """
        Add the log-density for auxiliary sampling `logw_aux` to the
        log-weights from the previous step `logw_prev`.

        The auxiliary log-density is given by model.pf_aux().  If this
        method is missing, `logw_prev` is returned unchanged.
        """
        if callable(getattr(model, "pf_aux", None)):
            logw_aux = model.pf_aux(
                x_prev=x_prev,
                y_curr=y_curr,
                theta=theta
            )
            return logw_aux + logw_prev
        else:
            return logw_prev

    # lax.scan stepping function
    def filter_step(carry, t):
        # 1. sample particles from previous time point
        key, subkey = random.split(carry["key"])
        # augment previous weights with auxiliary weights
        logw_aux = jax.vmap(
            pf_aux,
            in_axes=(0, 0, None)
        )(carry["logw"], carry["x_particles"], y_meas[t])
        # resampled particles
        resample_out = resampler(
            key=subkey,
            x_particles_prev=carry["x_particles"],
            logw=logw_aux
        )
        # 2. sample particles for current time point
        key, *subkeys = random.split(key, num=n_particles+1)
        x_particles, logw = jax.vmap(
            pf_step,
            in_axes=(0, 0, None)
        )(jnp.array(subkeys), resample_out["x_particles"], y_meas[t])
        if has_acc:
            # 3. accumulate values for score and/or fisher information
            acc_prev = {"alpha": carry["alpha"]}
            if fisher:
                acc_prev["beta"] = carry["beta"]
            # resample acc_prev so it lines up with the resampled particles
            acc_prev = tree_shuffle(
                tree=acc_prev,
                index=resample_out["ancestors"]
            )
            # add new values to get acc_curr
            acc_curr = jax.vmap(
                accumulator,
                in_axes=(0, 0, None, 0)
            )(resample_out["x_particles"], x_particles, y_meas[t], acc_prev)
        # output
        res_carry = {
            "x_particles": x_particles,
            "logw": logw,
            "key": key,
            "loglik": carry["loglik"] + jsp.special.logsumexp(logw)
        }
        if has_acc:
            res_carry.update(acc_curr)
        if history:
            res_stack = {k: res_carry[k] for k in ["x_particles", "logw"]}
            # only stack resampler outputs if there is something besides
            # the mandatory "x_particles" element
            if {"x_particles"} < resample_out.keys():
                # pass the key as a list so that removal does not hinge on
                # how `rm_keys` treats a bare string (e.g. substring match)
                res_stack["resample_out"] = rm_keys(
                    x=resample_out,
                    keys=["x_particles"]
                )
        else:
            res_stack = None
        return res_carry, res_stack

    # lax.scan initial value
    key, *subkeys = random.split(key, num=n_particles+1)
    # initial particles and weights
    x_particles, logw = jax.vmap(
        pf_init
    )(jnp.array(subkeys))
    filter_init = {
        "x_particles": x_particles,
        "logw": logw,
        "loglik": jsp.special.logsumexp(logw),
        "key": key,
    }
    if has_acc:
        # dummy (zero) initialization for derivatives
        filter_init["alpha"] = jnp.zeros((n_particles, theta.size))
        if fisher:
            filter_init["beta"] = jnp.zeros(
                (n_particles, theta.size, theta.size)
            )

    # lax.scan itself; t = 0 is handled by filter_init above
    last, full = lax.scan(filter_step, filter_init, jnp.arange(1, n_obs))

    # format output
    if history:
        # append initial values of x_particles and logw
        full["x_particles"] = jnp.concatenate([
            filter_init["x_particles"][None],
            full["x_particles"]
        ])
        full["logw"] = jnp.concatenate([
            filter_init["logw"][None],
            full["logw"]
        ])
    else:
        full = last
    # calculate loglikelihood: the carry holds a sum of n_obs logsumexp
    # terms, each of which is converted to a log-mean here
    full["loglik"] = last["loglik"] - n_obs * jnp.log(n_particles)
    if has_acc:
        if not fisher:
            # calculate score only
            full["score"] = tree_mean(last["alpha"], last["logw"])
        else:
            # calculate score and fisher information
            prob = logw_to_prob(last["logw"])
            alpha = last["alpha"]
            beta = last["beta"]
            alpha, gamma = jax.vmap(
                lambda p, a, b: (p * a, p * (jnp.outer(a, a) + b))
            )(prob, alpha, beta)
            alpha = jnp.sum(alpha, axis=0)
            hess = jnp.sum(gamma, axis=0) - jnp.outer(alpha, alpha)
            full["score"] = alpha
            full["fisher"] = -hess

    return full
def particle_filter_rb(model, key, y_meas, theta, n_particles,
                       resampler=resample_multinomial,
                       score=False, fisher=False, history=False):
    r"""
    Rao-Blackwellized particle filter.

    Notes:

    - Algorithm 2 of Poyiadjis et al 2011.

    - Can optionally use an auxiliary particle filter if `model.pf_aux()` is provided.

    Args:
        model: Object specifying the state-space model having the following methods:

            - `pf_init : (key, y_init, theta) -> (x_particles, logw)`: For sampling and calculating log-weights for the initial latent variable.
            - `step_sample : (key, x_prev, y_curr, theta) -> x_curr`: Sampling from the proposal distribution for each subsequent latent variable.
            - `step_lpdf : (x_curr, x_prev, y_curr, theta) -> logw`: Calculate log-weights for each subsequent latent variable.
            - `state_lpdf : (x_curr, x_prev, theta) -> lpdf`: Calculates the log-density of the state model.
            - `meas_lpdf : (y_curr, x_curr, theta) -> lpdf`: Calculates the log-density of the measurement model.
            - `pf_aux : (x_prev, y_curr, theta) -> logw`: Optional method providing look-forward log-weights of the auxiliary particle filter.

        key: PRNG key.
        y_meas: JAX array with leading dimension `n_obs` containing the measurement variables `y_meas = (y_0, ..., y_T)`, where `T = n_obs-1`.
        theta: Parameter value.  Must expose a `size` attribute when `score` or `fisher` is requested, since it is used to shape the derivative accumulators.
        n_particles: Number of particles.
        resampler: Function used at step `t` to obtain sample of particles from `p(x_{t} | y_{0:t}, theta)` out of a sample of particles from `p(x_{t-1} | y_{0:t-1}, theta)`.  It is called with keyword arguments as `resampler(key=key, x_particles_prev=x_particles, logw=logw)`, and the return value is a dictionary with mandatory element `x_particles` and optional elements that get stacked by `lax.scan()` when `history=True`.
        score: Whether or not to return an estimate of the score function at `theta`.
        fisher: Whether or not to return an estimate of the Fisher information at `theta`.  If `True` returns score as well.
        history: Whether to output the history of the filter or only the last step.

    Returns:
        Dictionary:

        - **x_particles** - JAX array containing the state variable particles at the last time point (leading dimension `n_particles`) or at all time points (leading dimensions `(n_obs, n_particles)`) if `history=True`.
        - **logw_bar** - JAX array containing unnormalized log weights at the last time point (dimensions `n_particles`) or at all time points (dimensions `(n_obs, n_particles)`) if `history=True`.
        - **loglik** - The particle filter loglikelihood evaluated at `theta`.
        - **score** - Optional 1D JAX array of size `n_theta = length(theta)` containing the estimated score at `theta`.
        - **fisher** - Optional JAX array of shape `(n_theta, n_theta)` containing the estimated observed fisher information at `theta`.
        - **resample_out** - If `history=True` and `resampler` returns elements other than `x_particles`, a dictionary of those additional outputs.  Each element has leading dimension `n_obs-1`, since these additional outputs do not apply to the first time point `t=0`.

        When `history=False` the returned dictionary is the final `lax.scan()` carry, so it additionally contains the PRNG `key` and, when derivatives are requested, the per-particle accumulators `alpha` (and `beta` if `fisher=True`).
    """
    n_obs = y_meas.shape[0]
    # derivative accumulators are needed for either the score or the fisher
    has_acc = score or fisher

    # internal functions for vmap
    def accumulator(x_prev, x_curr, y_curr, acc_prev, logw_aux):
        """
        Accumulator for weights and possibly derivative calculations,
        for a single `(x_prev, x_curr)` pair.

        Args:
            x_prev: Particle at the previous time point.
            x_curr: Particle at the current time point.
            y_curr: Measurement at the current time point.
            acc_prev: Dictionary with elements logw_bar, and optionally
                alpha and beta.
            logw_aux: Auxiliary-augmented log-weight of `x_prev`.

        Returns:
            Dictionary with elements:
            - logw_targ: logw_bar + state_lpdf.
            - logw_prop: logw_aux + prop_lpdf.
            - alpha: optional current value of alpha_bar.
            - beta: optional current value of beta_bar.
        """
        logw_bar = acc_prev["logw_bar"]
        logw_targ = (
            model.state_lpdf(x_curr=x_curr, x_prev=x_prev, theta=theta)
            + logw_bar
        )
        logw_prop = (
            model.step_lpdf(x_curr=x_curr, x_prev=x_prev,
                            y_curr=y_curr, theta=theta)
            + logw_aux
        )
        acc_full = {"logw_targ": logw_targ, "logw_prop": logw_prop}
        if has_acc:
            grad_meas = jax.grad(model.meas_lpdf, argnums=2)
            grad_state = jax.grad(model.state_lpdf, argnums=2)
            alpha = (
                grad_meas(y_curr, x_curr, theta)
                + grad_state(x_curr, x_prev, theta)
                + acc_prev["alpha"]
            )
            acc_full["alpha"] = alpha
            if fisher:
                hess_meas = jax.jacfwd(
                    jax.jacrev(model.meas_lpdf, argnums=2), argnums=2
                )
                hess_state = jax.jacfwd(
                    jax.jacrev(model.state_lpdf, argnums=2), argnums=2
                )
                beta = (
                    jnp.outer(alpha, alpha)
                    + hess_meas(y_curr, x_curr, theta)
                    + hess_state(x_curr, x_prev, theta)
                    + acc_prev["beta"]
                )
                acc_full["beta"] = beta
        return acc_full

    def pf_step(key, x_prev, y_curr):
        """
        Sample from the proposal distribution
        `x_curr ~ q(x_t | x_t-1, y_t, theta)`.
        """
        return model.step_sample(key=key, x_prev=x_prev, y_curr=y_curr,
                                 theta=theta)

    def pf_init(key):
        return model.pf_init(key=key, y_init=y_meas[0], theta=theta)

    def pf_aux(logw_prev, x_prev, y_curr):
        """
        Add the log-density for auxiliary sampling `logw_aux` to the
        log-weights from the previous step `logw_prev`.

        The auxiliary log-density is given by model.pf_aux().  If this
        method is missing, `logw_prev` is returned unchanged.
        """
        if callable(getattr(model, "pf_aux", None)):
            logw_aux = model.pf_aux(
                x_prev=x_prev,
                y_curr=y_curr,
                theta=theta
            )
            return logw_aux + logw_prev
        else:
            return logw_prev

    def pf_bar(x_prev, x_curr, y_curr, acc_prev, logw_aux):
        """
        Update calculations relating to logw_bar.

        This is the vmap version, which does the loop over x_prev here
        and the loop over x_curr in filter_step.

        Returns:
            Dictionary with elements:
            - logw_bar: rao-blackwellized weights.
            - alpha: optional current value of alpha.
            - beta: optional current value of beta.
        """
        # Calculate the accumulated values looping over x_prev
        acc_full = jax.vmap(
            accumulator,
            in_axes=(0, None, None, 0, 0)
        )(x_prev, x_curr, y_curr, acc_prev, logw_aux)
        # Calculate logw_tilde and logw_bar
        logw_targ, logw_prop = acc_full["logw_targ"], acc_full["logw_prop"]
        logw_tilde = (
            jsp.special.logsumexp(logw_targ)
            + model.meas_lpdf(
                y_curr=y_curr,
                x_curr=x_curr,
                theta=theta
            )
        )
        logw_bar = logw_tilde - jsp.special.logsumexp(logw_prop)
        acc_curr = {"logw_bar": logw_bar}
        if has_acc:
            # weight the derivatives by the target weights
            grad_curr = tree_mean(
                tree=rm_keys(acc_full, ["logw_targ", "logw_prop"]),
                logw=logw_targ
            )
            if fisher:
                grad_curr["beta"] = (
                    grad_curr["beta"]
                    - jnp.outer(grad_curr["alpha"], grad_curr["alpha"])
                )
            acc_curr.update(grad_curr)
        return acc_curr

    # lax.scan stepping function
    def filter_step(carry, t):
        # 1. sample particles from previous time point
        key, subkey = random.split(carry["key"])
        # augment previous weights with auxiliary weights
        logw_aux = jax.vmap(
            pf_aux,
            in_axes=(0, 0, None)
        )(carry["logw_bar"], carry["x_particles"], y_meas[t])
        # resampled particles
        resample_out = resampler(
            key=subkey,
            x_particles_prev=carry["x_particles"],
            logw=logw_aux
        )
        # 2. sample particles for current timepoint
        key, *subkeys = random.split(key, num=n_particles+1)
        x_particles = jax.vmap(
            pf_step,
            in_axes=(0, 0, None)
        )(jnp.array(subkeys), resample_out["x_particles"], y_meas[t])
        # 3. compute all double summations
        acc_prev = {"logw_bar": carry["logw_bar"]}
        if has_acc:
            acc_prev["alpha"] = carry["alpha"]
            if fisher:
                acc_prev["beta"] = carry["beta"]
        # outer loop over x_curr; pf_bar loops over the (un-resampled)
        # previous particles internally
        acc_curr = jax.vmap(
            pf_bar,
            in_axes=(None, 0, None, None, None)
        )(carry["x_particles"], x_particles, y_meas[t], acc_prev, logw_aux)
        # output
        res_carry = {
            "x_particles": x_particles,
            "loglik": carry["loglik"] +
            jsp.special.logsumexp(acc_curr["logw_bar"]),
            "key": key
        }
        res_carry.update(acc_curr)
        if history:
            res_stack = {k: res_carry[k]
                         for k in ["x_particles", "logw_bar"]}
            # only stack resampler outputs if there is something besides
            # the mandatory "x_particles" element
            if {"x_particles"} < resample_out.keys():
                # keys given as a list, matching the other rm_keys call in
                # this function, so that removal does not hinge on how
                # `rm_keys` treats a bare string
                res_stack["resample_out"] = rm_keys(
                    x=resample_out,
                    keys=["x_particles"]
                )
        else:
            res_stack = None
        return res_carry, res_stack

    # lax.scan initial value
    key, *subkeys = random.split(key, num=n_particles+1)
    # initial particles and weights
    x_particles, logw = jax.vmap(
        pf_init
    )(jnp.array(subkeys))
    filter_init = {
        "x_particles": x_particles,
        "loglik": jsp.special.logsumexp(logw),
        "logw_bar": logw,
        "key": key
    }
    if has_acc:
        # dummy (zero) initialization for derivatives
        filter_init["alpha"] = jnp.zeros((n_particles, theta.size))
        if fisher:
            filter_init["beta"] = jnp.zeros(
                (n_particles, theta.size, theta.size)
            )

    # lax.scan itself; t = 0 is handled by filter_init above
    last, full = lax.scan(filter_step, filter_init, jnp.arange(1, n_obs))

    # format output
    if history:
        # append initial values of x_particles and logw_bar
        full["x_particles"] = jnp.concatenate([
            filter_init["x_particles"][None],
            full["x_particles"]
        ])
        full["logw_bar"] = jnp.concatenate([
            filter_init["logw_bar"][None],
            full["logw_bar"]
        ])
    else:
        full = last
    # calculate loglikelihood: the carry holds a sum of n_obs logsumexp
    # terms, each of which is converted to a log-mean here
    full["loglik"] = last["loglik"] - n_obs * jnp.log(n_particles)
    if has_acc:
        if not fisher:
            # calculate score only
            full["score"] = tree_mean(last["alpha"], last["logw_bar"])
        else:
            # calculate score and fisher information
            prob = logw_to_prob(last["logw_bar"])
            alpha = last["alpha"]
            beta = last["beta"]
            alpha, gamma = jax.vmap(
                lambda p, a, b: (p * a, p * (jnp.outer(a, a) + b))
            )(prob, alpha, beta)
            alpha = jnp.sum(alpha, axis=0)
            hess = jnp.sum(gamma, axis=0) - jnp.outer(alpha, alpha)
            full["score"] = alpha
            full["fisher"] = -hess

    return full