import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from jax import random
from jax import lax
# from jax.experimental.maps import xmap
from .utils import logw_to_prob
def particle_smooth(key, logw, x_particles, ancestors):
    r"""
    Draw a sample from `p(x_state | x_meas, theta)` from a particle filter with multinomial resampling.

    A particle index is drawn at the last time point according to the
    normalized weights. Its ancestry is then traced backwards through
    `ancestors` to recover one full particle trajectory.

    Args:
        key: PRNG key.
        logw: Vector of `n_particles` unnormalized log-weights at the last time point `t = n_obs-1`.
        x_particles: JAX array with leading dimensions `(n_obs, n_particles)` containing the state variable particles.
        ancestors: JAX integer array of shape `(n_obs-1, n_particles)` where each element gives the index of the particle's ancestor at the previous time point, i.e., `ancestors[t, i]` is the index at time `t` of the ancestor of particle `i` at time `t+1`. If more than `n_obs-1` rows are supplied, only the first `n_obs-1` are used.

    Returns:
        JAX Array: An array with leading dimension `n_obs` sampled from `p(x_{0:T} | y_{0:T}, theta)`.

    Raises:
        ValueError: If `x_particles` has no time points, or if the shapes of `logw`, `x_particles`, and `ancestors` are inconsistent.
    """
    n_particles = logw.size
    # Array shapes are static under `jit`, so these checks run once at trace
    # time. They matter because JAX does not raise on out-of-bounds gather
    # indices (they are clipped in practice), so a shape mismatch would
    # otherwise silently produce a wrong trajectory instead of an error.
    if jnp.ndim(x_particles) < 2 or x_particles.shape[1] != n_particles:
        raise ValueError(
            f"x_particles must have leading dimensions (n_obs, {n_particles}); "
            f"got shape {jnp.shape(x_particles)}."
        )
    n_obs = x_particles.shape[0]
    if n_obs < 1:
        raise ValueError("x_particles must contain at least one time point.")
    if (
        jnp.ndim(ancestors) != 2
        or ancestors.shape[0] < n_obs - 1
        or ancestors.shape[1] != n_particles
    ):
        raise ValueError(
            f"ancestors must have shape ({n_obs - 1}, {n_particles}); "
            f"got shape {jnp.shape(ancestors)}."
        )

    prob = logw_to_prob(logw)
    # Particle index at the last time point, drawn from the normalized weights.
    i_last = random.choice(key, a=jnp.arange(n_particles), p=prob)

    def step(i_next, t):
        # Index at time t of the ancestor of the particle chosen at time t+1.
        # Cast to the carry dtype so lax.scan accepts any integer dtype of
        # `ancestors` (the carry's input and output types must match).
        i_curr = ancestors[t, i_next].astype(i_next.dtype)
        return i_curr, i_curr

    # Walk the ancestry backwards in time: t = n_obs-2, ..., 0.
    _, i_back = lax.scan(step, i_last, jnp.flip(jnp.arange(n_obs - 1)))
    # i_back holds the indices for t = n_obs-2, ..., 0. Put them in forward
    # order and append the index sampled at the final time point.
    i_part = jnp.append(jnp.flip(i_back), i_last)
    return x_particles[jnp.arange(n_obs), i_part, ...]