Source code for pfjax.particle_smooth

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from jax import random
from jax import lax
# from jax.experimental.maps import xmap
from .utils import logw_to_prob


def particle_smooth(key, logw, x_particles, ancestors):
    r"""
    Draw a sample from `p(x_state | x_meas, theta)` from a particle filter with multinomial resampling.

    A single particle is selected at the last time point with probabilities
    `logw_to_prob(logw)`. Its lineage is then traced backward through
    `ancestors` to recover one complete state trajectory.

    Args:
        key: PRNG key.
        logw: Vector of `n_particles` unnormalized log-weights at the last
            time point `t = n_obs-1`.
        x_particles: JAX array with leading dimensions `(n_obs, n_particles)`
            containing the state variable particles.
        ancestors: JAX integer array of shape `(n_obs-1, n_particles)` where
            each element gives the index of the particle's ancestor at the
            previous time point. A NumPy integer array is also accepted.

    Returns:
        JAX Array: An array with leading dimension `n_obs` sampled from
        `p(x_{0:T} | y_{0:T}, theta)`.
    """
    # Convert up front so that indexing with a traced index inside lax.scan
    # also works when a NumPy array is passed in.
    ancestors = jnp.asarray(ancestors)
    n_particles = logw.size
    n_obs = x_particles.shape[0]
    prob = logw_to_prob(logw)

    # Index of the particle drawn at the last time point. Cast it to the dtype
    # of `ancestors`: lax.scan requires the carry to keep the same dtype on
    # every step, and `ancestors[t, i]` has `ancestors.dtype`, which need not
    # match the default integer dtype returned by `random.choice`.
    i_last = random.choice(key, a=n_particles, p=prob).astype(ancestors.dtype)

    def step(i_next, t):
        # Index at time t of the ancestor of particle `i_next` at time t+1.
        i_curr = ancestors[t, i_next]
        return i_curr, i_curr

    # With reverse=True the scan visits t = n_obs-2, ..., 0 but stores each
    # output at its original position, so `i_path[t]` is the index at time t.
    _, i_path = lax.scan(step, i_last, jnp.arange(n_obs - 1), reverse=True)

    # Full lineage in forward time order: times 0, ..., n_obs-1.
    i_part = jnp.append(i_path, i_last)
    return x_particles[jnp.arange(n_obs), i_part, ...]