pfjax.utils

Utility functions.

Module Contents

Functions

logw_to_prob(logw)

Calculate normalized probabilities from unnormalized log weights.

rm_keys(x, keys)

Remove specified keys from given dict.

tree_array2d(x[, shape0])

Convert a PyTree into a 2D JAX array.

tree_add(tree1, tree2)

Add two pytrees leafwise.

tree_mean(tree, logw)

Weighted mean of each leaf of a pytree along the leading dimension.

tree_shuffle(tree, index)

Shuffle the leading dimension of each leaf of a pytree by values in index.

tree_zeros(tree)

Fill pytree with zeros.

tree_remove_last(x)

Remove last element of each leaf of pytree.

tree_remove_first(x)

Remove first element of each leaf of pytree.

tree_keep_last(x)

Keep only the last element of each leaf of pytree.

tree_append_first(x, first)

Append first to the start of each leaf of x along the first dimension.

tree_append_last(x, last)

Append last to the end of each leaf of x along the first dimension.

pfjax.utils.logw_to_prob(logw)

Calculate normalized probabilities from unnormalized log weights.

Parameters

logw – Vector of n_particles unnormalized log-weights.

Returns

Vector of n_particles normalized weights that sum to 1.
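As a rough illustration, the normalization can be done stably by subtracting the largest log-weight before exponentiating. This is a minimal sketch assuming logw is a 1D JAX array; it is not necessarily how pfjax implements logw_to_prob.

```python
import jax.numpy as jnp

def logw_to_prob(logw):
    """Sketch: normalized probabilities from unnormalized log-weights."""
    # Shift by the max so the largest weight becomes exp(0) = 1 (avoids overflow).
    w = jnp.exp(logw - jnp.max(logw))
    return w / jnp.sum(w)

prob = logw_to_prob(jnp.array([0.0, jnp.log(3.0)]))  # weights 1 and 3
```

The shift by max(logw) cancels in the ratio, so the result is unchanged, but very large log-weights (e.g. 1000) no longer overflow to inf. The computation is equivalent to jax.nn.softmax(logw).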

pfjax.utils.rm_keys(x, keys)

Remove specified keys from given dict.
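A minimal sketch of the described behavior, assuming keys is an iterable of key names and that a new dict is returned rather than x being modified in place (whether pfjax mutates x or errors on missing keys is not documented here):

```python
def rm_keys(x, keys):
    """Sketch: return a copy of dict x without the given keys."""
    # Keys listed in `keys` but absent from x are silently ignored in this sketch.
    return {k: v for k, v in x.items() if k not in keys}

params = {"mu": 0.0, "sigma": 1.0, "tau": 0.5}
trimmed = rm_keys(params, ["tau"])
```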

pfjax.utils.tree_array2d(x, shape0=None)

Convert a PyTree into a 2D JAX array.

Each leaf is first reshaped into a 2D JAX array with the same leading dimension; these arrays are then concatenated along axis=1. Assumes that every leaf has the same leading dimension.

Notes:

  • This function returns a tuple containing a Callable, so it can’t be jitted directly. It can, however, be called inside jitted code so long as the output is a PyTree.

Parameters
  • x – A PyTree.

  • shape0 – Optional value of the leading dimension. If None, it is deduced from x.

Returns

  • array2d - A two-dimensional JAX array.

  • unravel_fn - A Callable to reconstruct the original PyTree.

Return type

tuple
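The reshape-and-concatenate procedure described above can be sketched as follows. This is an illustrative reimplementation, not pfjax's source; it assumes every leaf has at least one dimension, that all leaves share the same leading dimension, and that they share a floating-point dtype (mixed dtypes would be promoted by the concatenation and not restored).

```python
import jax
import jax.numpy as jnp
import numpy as np

def tree_array2d(x, shape0=None):
    """Sketch: flatten a PyTree into a (shape0, n_cols) array plus an inverse."""
    leaves, treedef = jax.tree_util.tree_flatten(x)
    if shape0 is None:
        # Deduce the leading dimension from the first leaf.
        shape0 = leaves[0].shape[0]
    shapes = [leaf.shape for leaf in leaves]
    # Reshape every leaf to 2D with the common leading dimension.
    cols = [jnp.reshape(leaf, (shape0, -1)) for leaf in leaves]
    # Column offsets at which each leaf's block ends (all but the last).
    splits = [int(i) for i in np.cumsum([c.shape[1] for c in cols])[:-1]]
    array2d = jnp.concatenate(cols, axis=1)

    def unravel_fn(a):
        # Split the columns back into per-leaf blocks and restore the shapes.
        blocks = jnp.split(a, splits, axis=1)
        new_leaves = [jnp.reshape(b, s) for b, s in zip(blocks, shapes)]
        return jax.tree_util.tree_unflatten(treedef, new_leaves)

    return array2d, unravel_fn

x = {"a": jnp.ones((3, 2, 2)), "b": jnp.arange(3.0)}
array2d, unravel_fn = tree_array2d(x)  # "a" fills 4 columns, "b" fills 1
```

Because unravel_fn is a Python closure, the returned tuple is not a valid output of a jitted function, which is consistent with the note above.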

pfjax.utils.tree_add(tree1, tree2)

Add two pytrees leafwise.
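Leafwise addition maps naturally onto jax.tree_util.tree_map applied to two trees at once. A minimal sketch, assuming both pytrees have the same structure and broadcast-compatible leaves:

```python
import jax
import jax.numpy as jnp

def tree_add(tree1, tree2):
    """Sketch: add two pytrees of identical structure leaf by leaf."""
    return jax.tree_util.tree_map(jnp.add, tree1, tree2)

total = tree_add({"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)},
                 {"a": jnp.array([10.0, 20.0]), "b": jnp.array(4.0)})
```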

pfjax.utils.tree_mean(tree, logw)

Weighted mean of each leaf of a pytree along the leading dimension.
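A sketch of a weighted leafwise mean, assuming (as for logw_to_prob) that logw holds unnormalized log-weights whose length matches the leading dimension of every leaf. jax.nn.softmax is used here to normalize them; this is an assumption about how the weights are formed, not a statement of pfjax's code.

```python
import jax
import jax.numpy as jnp

def tree_mean(tree, logw):
    """Sketch: weighted mean over the leading dimension of every leaf."""
    prob = jax.nn.softmax(logw)  # normalized weights summing to 1
    # Contract the weight vector against the leading axis of each leaf.
    return jax.tree_util.tree_map(lambda leaf: jnp.tensordot(prob, leaf, axes=1), tree)

tree = {"x": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "y": jnp.array([0.0, 8.0])}
mean = tree_mean(tree, jnp.log(jnp.array([1.0, 3.0])))  # weights 0.25, 0.75
```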

pfjax.utils.tree_shuffle(tree, index)

Shuffle the leading dimension of each leaf of a pytree by values in index.
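Reordering every leaf by the same index (for instance, to reorder particles after resampling) can be sketched with integer indexing inside tree_map. This assumes index is a 1D integer array of positions along the leading dimension; repeated positions are allowed.

```python
import jax
import jax.numpy as jnp

def tree_shuffle(tree, index):
    """Sketch: reorder the leading dimension of every leaf by `index`."""
    return jax.tree_util.tree_map(lambda leaf: leaf[index], tree)

particles = {"x": jnp.arange(3.0), "z": jnp.arange(6.0).reshape(3, 2)}
shuffled = tree_shuffle(particles, jnp.array([2, 0, 0]))
```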

pfjax.utils.tree_zeros(tree)

Fill pytree with zeros.
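Since JAX arrays are immutable, "filling" a pytree with zeros amounts to building a new pytree of the same structure. A minimal sketch, assuming shapes and dtypes should be preserved:

```python
import jax
import jax.numpy as jnp

def tree_zeros(tree):
    """Sketch: pytree of the same structure, shapes and dtypes, filled with zeros."""
    return jax.tree_util.tree_map(jnp.zeros_like, tree)

zeros = tree_zeros({"a": jnp.ones((2, 3)), "b": jnp.array([1, 2], dtype=jnp.int32)})
```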

pfjax.utils.tree_remove_last(x)

Remove last element of each leaf of pytree.

pfjax.utils.tree_remove_first(x)

Remove first element of each leaf of pytree.

pfjax.utils.tree_keep_last(x)

Keep only the last element of each leaf of pytree.
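The three slicing helpers tree_remove_last, tree_remove_first and tree_keep_last can be sketched with tree_map over the leading dimension. This assumes "element" means a slice along the leading axis, and that tree_keep_last drops that axis (returning leaf[-1]) rather than keeping a length-1 axis; the actual pfjax convention may differ.

```python
import jax
import jax.numpy as jnp

def tree_remove_last(x):
    """Sketch: drop the last slice along the leading dimension of each leaf."""
    return jax.tree_util.tree_map(lambda leaf: leaf[:-1], x)

def tree_remove_first(x):
    """Sketch: drop the first slice along the leading dimension of each leaf."""
    return jax.tree_util.tree_map(lambda leaf: leaf[1:], x)

def tree_keep_last(x):
    """Sketch: keep only the last slice; the leading dimension is dropped."""
    return jax.tree_util.tree_map(lambda leaf: leaf[-1], x)

x = {"a": jnp.arange(4.0), "b": jnp.arange(8.0).reshape(4, 2)}
```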

pfjax.utils.tree_append_first(x, first)

Append first to the start of each leaf of x along the first dimension.

pfjax.utils.tree_append_last(x, last)

Append last to the end of each leaf of x along the first dimension.
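A sketch of tree_append_first and tree_append_last, assuming first and last have the same structure as x and that each of their leaves has the shape of a single slice of the corresponding leaf of x (that is, no leading dimension). A new leading axis is added before concatenating; pfjax may expect a different shape convention.

```python
import jax
import jax.numpy as jnp

def tree_append_first(x, first):
    """Sketch: put `first` before the existing slices of each leaf of x."""
    return jax.tree_util.tree_map(
        lambda leaf, head: jnp.concatenate([head[None], leaf], axis=0), x, first)

def tree_append_last(x, last):
    """Sketch: put `last` after the existing slices of each leaf of x."""
    return jax.tree_util.tree_map(
        lambda leaf, tail: jnp.concatenate([leaf, tail[None]], axis=0), x, last)

x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([[1.0, 1.0], [2.0, 2.0]])}
first = {"a": jnp.array(0.0), "b": jnp.array([0.0, 0.0])}
last = {"a": jnp.array(3.0), "b": jnp.array([3.0, 3.0])}
```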