pfjax.particle_resamplers

Module Contents

Functions

  • resample_multinomial(key, x_particles_prev, logw) – Particle resampler using multinomial sampling.

  • resample_mvn(key, x_particles_prev, logw) – Particle resampler with Multivariate Normal approximation.

  • resample_ot(key, x_particles_prev, logw[, scaled, ...]) – Particle resampler using optimal transport.

pfjax.particle_resamplers.resample_multinomial(key, x_particles_prev, logw)

Particle resampler.

This basic resampler performs multinomial resampling, i.e., it samples with replacement with probabilities proportional to the weights.
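For illustration, multinomial resampling can be sketched in a few lines of JAX. This is a minimal sketch, not pfjax's implementation: it assumes that drawing ancestor indices with `jax.random.categorical` (which treats its input as unnormalized log-probabilities) is an acceptable stand-in, and the name `multinomial_resample_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def multinomial_resample_sketch(key, x_particles_prev, logw):
    """Hypothetical sketch of multinomial resampling (not pfjax's source)."""
    n_particles = x_particles_prev.shape[0]
    # Draw n_particles ancestor indices with replacement; categorical() treats
    # logw as unnormalized log-probabilities, so no explicit normalization is needed.
    ancestors = jax.random.categorical(key, logw, shape=(n_particles,))
    # Gather whole particles (all trailing dimensions) by ancestor index.
    x_particles = x_particles_prev[ancestors]
    return {"x_particles": x_particles, "ancestors": ancestors}


key = jax.random.PRNGKey(0)
x_prev = jnp.arange(8.0).reshape(4, 2)             # 4 particles of shape (2,)
logw = jnp.array([0.0, -jnp.inf, 0.0, -jnp.inf])   # particles 1 and 3 have zero weight
out = multinomial_resample_sketch(key, x_prev, logw)
```

Particles whose log-weight is `-inf` are never selected, so in this example every ancestor index is 0 or 2.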

Parameters
  • key – PRNG key.

  • x_particles_prev – An ndarray with leading dimension n_particles consisting of the particles from the previous time step.

  • logw – Vector of corresponding n_particles unnormalized log-weights.

Returns

  • x_particles: An ndarray with leading dimension n_particles consisting of the particles from the current time step. These are sampled with replacement from x_particles_prev with probability vector exp(logw) / sum(exp(logw)).

  • ancestors: Vector of n_particles integers between 0 and n_particles-1; element i is the index of the element of x_particles_prev from which x_particles[i] was sampled.
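The probability vector exp(logw) / sum(exp(logw)) should not be evaluated literally when the log-weights are large, because `exp` overflows. The snippet below is an illustration, not taken from pfjax: it contrasts the literal formula with `jax.nn.softmax`, which computes the same vector in a numerically stable way.

```python
import jax
import jax.numpy as jnp

logw = jnp.array([1000.0, 1000.0, 999.0])

# Literal translation of exp(logw) / sum(exp(logw)): exp overflows to inf,
# and inf / inf evaluates to nan.
prob_naive = jnp.exp(logw) / jnp.sum(jnp.exp(logw))

# Same probability vector, computed stably (max(logw) is subtracted before exp).
prob = jax.nn.softmax(logw)
```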

Return type

A dictionary with the elements listed under Returns.

pfjax.particle_resamplers.resample_mvn(key, x_particles_prev, logw)

Particle resampler with Multivariate Normal approximation.

Parameters
  • key – PRNG key.

  • x_particles_prev – An ndarray with leading dimension n_particles consisting of the particles from the previous time step.

  • logw – Vector of corresponding n_particles unnormalized log-weights.

Returns

  • x_particles: An ndarray with leading dimension n_particles consisting of the particles from the current time step.

  • mvn_mean: Vector of length n_state = prod(x_particles.shape[1:]) representing the mean of the MVN.

  • mvn_cov: Matrix of size n_state x n_state representing the covariance matrix of the MVN.
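A minimal sketch of MVN resampling follows. The documentation does not say how mvn_mean and mvn_cov are computed; this sketch assumes they are the weighted mean and weighted covariance of the flattened particles (weights exp(logw), normalized), and the name `mvn_resample_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def mvn_resample_sketch(key, x_particles_prev, logw):
    """Hypothetical sketch of MVN resampling (not pfjax's source).

    Assumes the MVN moments are the weighted mean and covariance of the particles.
    """
    p_shape = x_particles_prev.shape
    n_particles = p_shape[0]
    # Flatten each particle to a vector of length n_state = prod(p_shape[1:]).
    x_flat = x_particles_prev.reshape((n_particles, -1))
    prob = jax.nn.softmax(logw)  # exp(logw) / sum(exp(logw)), computed stably
    mvn_mean = jnp.average(x_flat, axis=0, weights=prob)
    # Weighted covariance via aweights; atleast_2d keeps a matrix when n_state == 1.
    mvn_cov = jnp.atleast_2d(jnp.cov(x_flat, rowvar=False, aweights=prob))
    x_new = jax.random.multivariate_normal(key, mvn_mean, mvn_cov, shape=(n_particles,))
    return {
        "x_particles": x_new.reshape(p_shape),  # restore the original particle shape
        "mvn_mean": mvn_mean,
        "mvn_cov": mvn_cov,
    }


key = jax.random.PRNGKey(1)
k1, k2 = jax.random.split(key)
x_prev = jax.random.normal(k1, (50, 2, 3))  # 50 particles, each of shape (2, 3)
logw = jnp.zeros(50)                        # equal weights
out = mvn_resample_sketch(k2, x_prev, logw)
```

The example uses 50 particles so that the 6 x 6 weighted covariance is (almost surely) nonsingular; with fewer than n_state + 1 particles it would be rank-deficient.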

Return type

A dictionary with the elements listed under Returns.

pfjax.particle_resamplers.resample_ot(key, x_particles_prev, logw, scaled=True, pointcloud_kwargs={}, sinkhorn_kwargs={})

Particle resampler using optimal transport.

Based on Algorithms 2 and 3 of Corenflos et al. (2021) <https://arxiv.org/abs/2102.07850>.
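The idea behind OT resampling in Corenflos et al. (2021) is to replace random ancestor sampling with a deterministic, differentiable map: compute an entropy-regularized transport plan P between the weighted particles and uniformly weighted copies of the same particles, then set x_new = N * P^T x. The sketch below illustrates that idea with a hand-written log-domain Sinkhorn loop instead of ott-jax. The squared Euclidean cost, the fixed iteration count, the `epsilon` default, and the name `ot_resample_sketch` are all assumptions, and the sketch does not reproduce pfjax's handling of scaled, pointcloud_kwargs, or sinkhorn_kwargs.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def ot_resample_sketch(x_particles_prev, logw, epsilon=0.1, n_iter=200):
    """Hypothetical entropy-regularized OT resampler (not pfjax's source).

    Uses a hand-written log-domain Sinkhorn loop in place of ott-jax and
    returns the new particles together with the transport plan.
    """
    p_shape = x_particles_prev.shape
    n_particles = p_shape[0]
    x = x_particles_prev.reshape((n_particles, -1))
    # Squared Euclidean cost between every pair of particles.
    cost = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    dtype = cost.dtype
    # Source marginal: normalized weights; target marginal: uniform 1/N.
    log_a = jax.nn.log_softmax(logw).astype(dtype)
    log_b = jnp.full((n_particles,), -jnp.log(n_particles), dtype=dtype)

    def sinkhorn_step(_, fg):
        f, g = fg
        # Update f so the rows of the plan sum to a, then g so the columns sum to b.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    zeros = jnp.zeros(n_particles, dtype=dtype)
    f, g = jax.lax.fori_loop(0, n_iter, sinkhorn_step, (zeros, zeros))
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Column j of N * plan holds convex weights, so each new particle is a
    # weighted average of the old ones: x_new = N * plan^T x.
    x_new = n_particles * plan.T @ x
    return x_new.reshape(p_shape), plan


x_prev = jnp.arange(4.0).reshape(4, 1)  # 4 one-dimensional particles
logw = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
x_new, plan = ot_resample_sketch(x_prev, logw, epsilon=1.0, n_iter=500)
```

Unlike multinomial resampling, the new particles are weighted averages of the old ones rather than copies, which is what makes the map differentiable with respect to the particles and the weights. Smaller `epsilon` moves the plan closer to unregularized optimal transport but typically needs more Sinkhorn iterations.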

Notes:

  • The jit argument to ott.solvers.linear.sinkhorn.sinkhorn() is ignored, i.e., it is always set to False.

  • Both sinkhorn_kwargs and pointcloud_kwargs are shallow-copied inside the function to avoid side effects on the dictionaries passed in, which would make the function impure.
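The shallow-copy note matters because pointcloud_kwargs={} and sinkhorn_kwargs={} are mutable default arguments: if the function wrote into them directly, the change would persist across calls and leak into the caller's dictionaries. The hypothetical helper below only illustrates the copy-then-override pattern for the jit override described above; it does not call ott, and the "threshold" key is just sample data.

```python
def prepare_sinkhorn_kwargs_sketch(sinkhorn_kwargs={}):
    """Hypothetical helper illustrating the shallow-copy note (not pfjax's source)."""
    # Shallow copy: a new top-level dict, so setting keys below touches neither
    # the caller's dict nor the shared default {}.
    sinkhorn_kwargs = dict(sinkhorn_kwargs)
    # Mirror the documented behaviour: jit is always forced to False.
    sinkhorn_kwargs["jit"] = False
    return sinkhorn_kwargs


user_kwargs = {"jit": True, "threshold": 1e-3}
prepared = prepare_sinkhorn_kwargs_sketch(user_kwargs)
prepared_default = prepare_sinkhorn_kwargs_sketch()
```

Because the copy is shallow, nested mutable values (for example a dict stored under one of the keys) are still shared with the caller.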

Parameters
  • key – PRNG key.

  • x_particles_prev – An ndarray with leading dimension n_particles consisting of the particles from the previous time step.

  • logw – Vector of corresponding n_particles unnormalized log-weights.

  • scaled – Whether or not to divide x_particles_prev by sqrt(x_particles_prev.shape[1]) * max(std(x_particles_prev, axis=0)), after reshaping x_particles_prev into a 2D array with leading dimension of size n_particles. If True, this overrides any value of pointcloud_kwargs["scale_cost"].

  • pointcloud_kwargs – Dictionary of additional arguments to ott.geometry.pointcloud.PointCloud().

  • sinkhorn_kwargs – Dictionary of additional arguments to ott.solvers.linear.sinkhorn.Sinkhorn().
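The scaled=True preprocessing can be written out as follows. This is a sketch of the formula stated for scaled, assuming std means the population standard deviation (`jnp.std` with its default `ddof=0`); the name `scale_particles_sketch` is hypothetical, and the sketch does not show how pfjax sets scale_cost.

```python
import jax.numpy as jnp


def scale_particles_sketch(x_particles_prev):
    """Hypothetical sketch of the scaled=True preprocessing (not pfjax's source)."""
    n_particles = x_particles_prev.shape[0]
    # Reshape to 2D: one row per particle, n_state = prod(shape[1:]) columns.
    x = x_particles_prev.reshape((n_particles, -1))
    # Divisor: sqrt(n_state) times the largest per-coordinate standard deviation.
    scale = jnp.sqrt(x.shape[1]) * jnp.max(jnp.std(x, axis=0))
    return x / scale, scale


x_prev = jnp.array([[[0.0, 0.0], [0.0, 0.0]],
                    [[2.0, 4.0], [0.0, 0.0]]])  # 2 particles of shape (2, 2)
x_scaled, scale = scale_particles_sketch(x_prev)
```

Here the per-coordinate standard deviations are 1, 2, 0 and 0, so the divisor is sqrt(4) * 2 = 4. If every coordinate is constant across particles the divisor is zero, so a guard may be needed in practice.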

Returns

  • x_particles: An ndarray with leading dimension n_particles consisting of the particles from the current time step.

  • sink: An object of type ott.solvers.linear.sinkhorn.SinkhornOutput, as returned by ott.solvers.linear.sinkhorn.Sinkhorn().

Return type

A dictionary with the elements listed under Returns.