pfjax.particle_resamplers
Module Contents
Functions
resample_multinomial | Particle resampler.
resample_mvn | Particle resampler with Multivariate Normal approximation.
resample_ot | Particle resampler using optimal transport.
- pfjax.particle_resamplers.resample_multinomial(key, x_particles_prev, logw)
Particle resampler.
This basic resampler performs multinomial sampling, i.e., it samples particles with replacement with probability proportional to their weights.
- Parameters
key – PRNG key.
x_particles_prev – An ndarray with leading dimension n_particles consisting of the particles from the previous time step.
logw – Vector of corresponding n_particles unnormalized log-weights.
- Returns
A dictionary with elements:
x_particles: An ndarray with leading dimension n_particles consisting of the particles from the current time step. These are sampled with replacement from x_particles_prev with probability vector exp(logw) / sum(exp(logw)).
ancestors: Vector of n_particles integers between 0 and n_particles-1 giving, for each element of x_particles, the index of the element of x_particles_prev it was copied from.
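The multinomial scheme described above can be sketched in JAX as follows. This is a minimal illustration under stated assumptions, not necessarily how pfjax implements it: the function name resample_multinomial_sketch is hypothetical, and the weights are normalized with jax.nn.softmax, which computes exp(logw) / sum(exp(logw)) in a numerically stable way.

```python
import jax
import jax.numpy as jnp


def resample_multinomial_sketch(key, x_particles_prev, logw):
    """Hypothetical sketch of multinomial resampling (not pfjax's own code)."""
    n_particles = logw.shape[0]
    # Normalized weights: exp(logw) / sum(exp(logw)), computed stably.
    prob = jax.nn.softmax(logw)
    # Draw ancestor indices with replacement, proportional to the weights.
    ancestors = jax.random.choice(
        key, a=jnp.arange(n_particles), shape=(n_particles,), p=prob
    )
    # Index along the leading (particle) dimension only.
    return {
        "x_particles": x_particles_prev[ancestors, ...],
        "ancestors": ancestors,
    }
```

Because only the leading dimension is indexed, the same sketch works for particles of any trailing shape.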
- pfjax.particle_resamplers.resample_mvn(key, x_particles_prev, logw)
Particle resampler with Multivariate Normal approximation.
- Parameters
key – PRNG key.
x_particles_prev – An ndarray with leading dimension n_particles consisting of the particles from the previous time step.
logw – Vector of corresponding n_particles unnormalized log-weights.
- Returns
A dictionary with elements:
x_particles: An ndarray with leading dimension n_particles consisting of the particles from the current time step.
mvn_mean: Vector of length n_state = prod(x_particles.shape[1:]) representing the mean of the MVN.
mvn_cov: Matrix of size n_state x n_state representing the covariance matrix of the MVN.
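One way to read the MVN approximation is: flatten each particle to a vector of length n_state, compute the weighted mean and covariance, and draw fresh particles from the resulting normal distribution. The sketch below follows that reading. It is an assumption-laden illustration rather than pfjax's code: the name resample_mvn_sketch is hypothetical, and the covariance uses the plain weighted (biased) estimator, which may differ from the normalization pfjax actually applies.

```python
import jax
import jax.numpy as jnp


def resample_mvn_sketch(key, x_particles_prev, logw):
    """Hypothetical sketch of MVN-approximation resampling."""
    n_particles = logw.shape[0]
    prob = jax.nn.softmax(logw)  # normalized weights
    # Flatten each particle into a vector of length n_state.
    x_flat = x_particles_prev.reshape((n_particles, -1))
    # Weighted mean and (biased) weighted covariance; the estimator
    # choice is an assumption of this sketch.
    mvn_mean = jnp.sum(prob[:, None] * x_flat, axis=0)
    centered = x_flat - mvn_mean
    mvn_cov = (centered * prob[:, None]).T @ centered
    # Draw n_particles new particles from the fitted MVN.
    x_new = jax.random.multivariate_normal(
        key, mean=mvn_mean, cov=mvn_cov, shape=(n_particles,)
    )
    return {
        "x_particles": x_new.reshape(x_particles_prev.shape),
        "mvn_mean": mvn_mean,
        "mvn_cov": mvn_cov,
    }
```

Unlike multinomial resampling, the new particles are not copies of old ones, so there is no ancestors vector to return. Sampling relies on a Cholesky factorization by default, so the sketch assumes the weighted covariance is positive definite.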
- pfjax.particle_resamplers.resample_ot(key, x_particles_prev, logw, scaled=True, pointcloud_kwargs={}, sinkhorn_kwargs={})
Particle resampler using optimal transport.
Based on Algorithms 2 and 3 of Corenflos et al. (2021), https://arxiv.org/abs/2102.07850.
Notes:
Argument jit to ott.solvers.linear.sinkhorn.sinkhorn() is ignored, i.e., always set to False.
Both sinkhorn_kwargs and pointcloud_kwargs are shallow-copied inside the function to avoid impure function effects.
- Parameters
key – PRNG key.
x_particles_prev – An ndarray with leading dimension n_particles consisting of the particles from the previous time step.
logw – Vector of corresponding n_particles unnormalized log-weights.
scaled – Whether or not to divide x_particles_prev by sqrt(x_particles_prev.shape[1]) * max(std(x_particles_prev, axis=0)), after reshaping x_particles_prev into a 2D array with leading dimension of size n_particles. If True, overrides any value of pointcloud_kwargs["scale_cost"].
pointcloud_kwargs – Dictionary of additional arguments to ott.pointcloud.PointCloud().
sinkhorn_kwargs – Dictionary of additional arguments to ott.solvers.linear.sinkhorn.Sinkhorn().
- Returns
A dictionary with elements:
x_particles: An ndarray with leading dimension n_particles consisting of the particles from the current time step.
sink: An object of type ott.solvers.linear.sinkhorn.SinkhornOutput, as returned by ott.solvers.linear.sinkhorn.Sinkhorn().
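To illustrate the idea behind optimal-transport resampling, the sketch below replaces the OTT solver with a hand-written log-domain Sinkhorn iteration in plain JAX. It is a swapped-in, simplified technique, not pfjax's implementation and not the OTT API. The name resample_ot_sketch, the parameters epsilon and n_iter, and the returned "plan" element are all hypothetical. The sketch computes an entropy-regularized transport plan between the uniform distribution and the weighted particles, then maps each new particle to a weighted average of the old ones.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def resample_ot_sketch(key, x_particles_prev, logw, epsilon=0.1, n_iter=200):
    """Hypothetical OT resampler using a hand-written Sinkhorn loop.

    `key` is unused because this transform is deterministic; it is kept
    only to mirror the signature of the other resamplers.
    """
    n_particles = logw.shape[0]
    x_flat = x_particles_prev.reshape((n_particles, -1))
    # Scaling as described for `scaled=True` (assumes non-constant particles).
    scale = jnp.sqrt(x_flat.shape[1]) * jnp.max(jnp.std(x_flat, axis=0))
    z = x_flat / scale
    # Squared Euclidean cost between all pairs of scaled particles.
    cost = jnp.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    log_a = jnp.full((n_particles,), -jnp.log(n_particles))  # uniform rows
    log_b = jax.nn.log_softmax(logw)  # normalized weights on columns

    def body(_, potentials):
        f, g = potentials
        # Enforce column marginals, then row marginals.
        g = -epsilon * logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
        f = -epsilon * logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        return f, g

    init = (jnp.zeros(n_particles), jnp.zeros(n_particles))
    f, g = jax.lax.fori_loop(0, n_iter, body, init)
    # Transport plan; each row sums to 1 / n_particles after the f update.
    plan = jnp.exp(
        log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / epsilon
    )
    # Each new particle is a convex combination of the old particles.
    x_new = n_particles * plan @ x_flat
    return {"x_particles": x_new.reshape(x_particles_prev.shape), "plan": plan}
```

Because the output is a smooth function of the particles and weights, this kind of resampler is differentiable, at the price of O(n_particles^2) memory for the cost matrix. Smaller values of the hypothetical epsilon bring the plan closer to unregularized optimal transport but typically need more iterations.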