pfjax.utils
Utility functions.
Module Contents
Functions
- Calculate normalized probabilities from unnormalized log weights.
- Remove specified keys from a given dict.
- Convert a PyTree into a 2D JAX array.
- Add two pytrees leafwise.
- Weighted mean of each leaf of a pytree along the leading dimension.
- Shuffle the leading dimension of each leaf of a pytree by the values in index.
- Fill a pytree with zeros.
- Remove the last element of each leaf of a pytree.
- Remove the first element of each leaf of a pytree.
- Keep only the last element of each leaf of a pytree.
- Prepend first to the start of each leaf of x along the first dimension.
- Append last to the end of each leaf of x along the first dimension.
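Each summary above describes a simple leafwise pytree operation. The sketch below shows one way to express several of them with `jax.tree_util.tree_map`. The function names and argument orders are hypothetical (the real names are not shown in this listing), and two behaviours are assumptions: "keep only the last element" is taken to drop the leading dimension, and the `first`/`last` pytrees are taken to match a single element of `x`.

```python
import jax.numpy as jnp
from jax import tree_util

# NOTE: all function names below are hypothetical; only the behaviour
# follows the one-line summaries.


def drop_keys(d, keys):
    # Return a new dict without the specified keys.
    return {k: v for k, v in d.items() if k not in keys}


def add_trees(x, y):
    # Leafwise sum of two pytrees with identical structure.
    return tree_util.tree_map(jnp.add, x, y)


def zeros_like_tree(x):
    # Same structure, shapes and dtypes as x, filled with zeros.
    return tree_util.tree_map(jnp.zeros_like, x)


def remove_last(x):
    # Drop the last element along the leading dimension of every leaf.
    return tree_util.tree_map(lambda leaf: leaf[:-1], x)


def remove_first(x):
    # Drop the first element along the leading dimension of every leaf.
    return tree_util.tree_map(lambda leaf: leaf[1:], x)


def keep_last(x):
    # Assumption: leaf[-1] (leading dimension dropped), not leaf[-1:].
    return tree_util.tree_map(lambda leaf: leaf[-1], x)


def prepend_first(first, x):
    # Assumption: `first` has the structure of x with the leading dim removed.
    return tree_util.tree_map(
        lambda f, leaf: jnp.concatenate([jnp.asarray(f)[None], leaf]), first, x
    )


def append_last(x, last):
    # Assumption: `last` has the structure of x with the leading dim removed.
    return tree_util.tree_map(
        lambda leaf, lst: jnp.concatenate([leaf, jnp.asarray(lst)[None]]), x, last
    )
```

Because every helper is a `tree_map` over leaves, each works for any pytree (dicts, tuples, nested containers) whose leaves share a leading dimension.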
- pfjax.utils.logw_to_prob(logw)
Calculate normalized probabilities from unnormalized log weights.
- Parameters
logw – Vector of n_particles unnormalized log-weights.
- Returns
Vector of n_particles normalized weights that sum to 1.
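A minimal sketch of this normalization, assuming the standard max-shift (log-sum-exp) trick: subtracting the largest log-weight before exponentiating keeps `exp` from overflowing. The result is mathematically equivalent to `jax.nn.softmax(logw)`; this is an illustration, not necessarily pfjax's exact implementation.

```python
import jax
import jax.numpy as jnp


def logw_to_prob(logw):
    # Shift so the largest log-weight becomes 0; exp() then cannot overflow,
    # and the shift cancels out after normalization.
    wgt = jnp.exp(logw - jnp.max(logw))
    # Divide by the total so the weights sum to 1.
    return wgt / jnp.sum(wgt)


# Log-weights this large would overflow a naive exp(logw) to inf.
logw = 1000.0 + jnp.log(jnp.array([1.0, 3.0]))
prob = logw_to_prob(logw)  # approximately [0.25, 0.75]
```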
- pfjax.utils.tree_array2d(x, shape0=None)
Convert a PyTree into a 2D JAX array.
First reshapes each leaf to a 2D JAX array that keeps the leading dimension, then concatenates these arrays along axis=1. Assumes all leaves share the same leading dimension.
Notes:
This function returns a tuple containing a Callable, so it can't be jitted directly. It can, however, be called inside jitted code as long as the jitted function's output is a PyTree.
- Parameters
x – A Pytree.
shape0 – Optional value of the leading dimension. If None, it is deduced from x.
- Returns
array2d - A two-dimensional JAX array.
unravel_fn - A Callable to reconstruct the original PyTree.
- Return type
tuple
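A minimal sketch of the reshape-and-concatenate procedure described above, using only `jax.tree_util` and `jax.numpy`. It is an illustrative reimplementation, not pfjax's source; casting each block back to its original dtype in `unravel_fn`, and the helper `double_tree`, are assumptions added for the example.

```python
import itertools

import jax
import jax.numpy as jnp
from jax import tree_util


def tree_array2d(x, shape0=None):
    leaves, treedef = tree_util.tree_flatten(x)
    if shape0 is None:
        # Deduce the leading dimension from the first leaf.
        shape0 = leaves[0].shape[0]
    shapes = [leaf.shape for leaf in leaves]
    dtypes = [leaf.dtype for leaf in leaves]
    # Reshape each leaf to (shape0, k_i), keeping the leading dimension.
    cols = [jnp.reshape(leaf, (shape0, -1)) for leaf in leaves]
    # Column offsets at which array2d is split back into per-leaf blocks.
    splits = list(itertools.accumulate(c.shape[1] for c in cols))[:-1]
    array2d = jnp.concatenate(cols, axis=1)

    def unravel_fn(arr):
        blocks = jnp.split(arr, splits, axis=1)
        new_leaves = [
            jnp.reshape(b, s).astype(d) for b, s, d in zip(blocks, shapes, dtypes)
        ]
        return tree_util.tree_unflatten(treedef, new_leaves)

    return array2d, unravel_fn


@jax.jit
def double_tree(x):
    # Hypothetical helper: the (array, Callable) tuple stays inside the
    # jitted function and only a PyTree is returned, so tracing succeeds.
    arr, unravel_fn = tree_array2d(x)
    return unravel_fn(2.0 * arr)
```

`double_tree` illustrates the note on jitting: `tree_array2d` itself is never the jitted entry point, and the Callable it returns is consumed before the jitted function returns.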
- pfjax.utils.tree_mean(tree, logw)
Weighted mean of each leaf of a pytree along the leading dimension, with weights given by the unnormalized log-weights logw.
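A minimal sketch of a weighted leafwise mean, assuming the weights are the normalized probabilities obtained from `logw` (a softmax, as in `logw_to_prob`). This is an illustration under that assumption, not pfjax's verified implementation.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def tree_mean(tree, logw):
    # Assumption: normalize logw into probabilities with a softmax.
    prob = jax.nn.softmax(logw)
    # Contract the weight vector with the leading axis of each leaf:
    # result[...] = sum_i prob[i] * leaf[i, ...]
    return tree_util.tree_map(lambda leaf: jnp.tensordot(prob, leaf, axes=1), tree)
```

`jnp.tensordot(prob, leaf, axes=1)` sums over the last axis of `prob` and the first axis of `leaf`, so each result keeps the shape `leaf.shape[1:]`.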
- pfjax.utils.tree_shuffle(tree, index)
Shuffle the leading dimension of each leaf of a pytree according to the values in index.
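A minimal sketch, assuming `index` is an integer array and "shuffle by values in index" means gathering rows as `leaf[index]`. Repeated indices are allowed, which is what particle resampling with ancestor indices requires.

```python
import jax.numpy as jnp
from jax import tree_util


def tree_shuffle(tree, index):
    # Gather along the leading dimension: row j of each output leaf is
    # row index[j] of the corresponding input leaf.
    return tree_util.tree_map(lambda leaf: leaf[index], tree)
```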