# Source code for pfjax.utils

"""
Utility functions.
"""

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import jax.flatten_util as jfu


def logw_to_prob(logw):
    r"""
    Convert unnormalized log-weights into normalized probabilities.

    The largest log-weight is subtracted before exponentiating so that
    large log-weights do not overflow.

    Args:
        logw: Vector of `n_particles` unnormalized log-weights.

    Returns:
        Vector of `n_particles` normalized weights that sum to 1.
    """
    log_shift = jnp.max(logw)
    unnorm = jnp.exp(logw - log_shift)
    return unnorm / jnp.sum(unnorm)
def rm_keys(x, keys):
    r"""
    Return a copy of dict `x` without the entries whose key is in `keys`.

    Args:
        x: A dict.
        keys: Container of keys to drop; membership is tested with `in`.

    Returns:
        A new dict holding the remaining key/value pairs in their original order.
    """
    return {key: val for key, val in x.items() if key not in keys}
# --- tree helper functions ----------------------------------------------------
def tree_array2d(x, shape0=None):
    r"""
    Convert a PyTree into a 2D JAX array.

    Each leaf is reshaped to a 2D array of shape ``(shape0, -1)``, so all
    leaves share the same leading dimension. These arrays are then
    concatenated along ``axis=1``. Row ``i`` of the result therefore holds
    the entries of every leaf at leading index ``i``.

    **Notes:**

    - This function returns a tuple containing a Callable, so can't be jitted
      directly. Can however be called in jitted code so long as the output is
      a PyTree.
    - Assumes the leading dimension of each leaf is the same (``shape0``).
    - Leaves of different dtypes are promoted to a common dtype in
      ``array2d``. ``unravel_fn`` casts each leaf back to its original dtype.

    Args:
        x: A Pytree.
        shape0: Optional value of the leading dimension. If `None` it is
            deduced from the first leaf of `x`.

    Returns:
        tuple:
        - **array2d** - A two dimensional JAX array of shape
          ``(shape0, n_cols)``, where ``n_cols`` is the total number of
          entries per leading index across all leaves.
        - **unravel_fn** - A Callable to reconstruct the original PyTree from
          ``array2d``. It also accepts the raveled (1D) version of ``array2d``.
    """
    leaves, treedef = jtu.tree_flatten(x)
    leaves = [jnp.asarray(leaf) for leaf in leaves]
    if shape0 is None:
        shape0 = leaves[0].shape[0]  # leading dimension
    # Reshape each leaf separately so that rows stay aligned with the leading
    # dimension. Raveling the whole tree and then reshaping would interleave
    # entries from different leading indices whenever there is more than one
    # leaf.
    blocks = [jnp.reshape(leaf, (shape0, -1)) for leaf in leaves]
    if blocks:
        flat2d = jnp.concatenate(blocks, axis=1)
    else:
        # Empty tree: keep a well-formed (shape0, 0) array.
        flat2d = jnp.zeros((shape0, 0))
    shapes = [leaf.shape for leaf in leaves]
    dtypes = [leaf.dtype for leaf in leaves]
    ncols = [block.shape[1] for block in blocks]

    def unravel_fn(array2d):
        # Accept either the 2D array or its raveled 1D version.
        array2d = jnp.reshape(array2d, (shape0, -1))
        out = []
        start = 0
        for shape, dtype, ncol in zip(shapes, dtypes, ncols):
            cols = array2d[:, start:start + ncol]
            out.append(jnp.reshape(cols, shape).astype(dtype))
            start += ncol
        return jtu.tree_unflatten(treedef, out)

    return flat2d, unravel_fn
def tree_add(tree1, tree2):
    r"""
    Add two pytrees leafwise.

    Args:
        tree1: A PyTree.
        tree2: A PyTree whose structure matches `tree1`.

    Returns:
        A PyTree whose leaves are the sums of the corresponding leaves of
        `tree1` and `tree2`.
    """
    def _leaf_sum(leaf1, leaf2):
        return leaf1 + leaf2

    return jtu.tree_map(_leaf_sum, tree1, tree2)
def tree_mean(tree, logw):
    r"""
    Weighted mean of each leaf of a pytree along its leading dimension.

    Args:
        tree: A PyTree whose leaves all have a leading dimension of length
            `n_particles`.
        logw: Vector of `n_particles` unnormalized log-weights.

    Returns:
        A PyTree with the same structure as `tree`. Each leaf is the average
        of the corresponding input leaf over axis 0, weighted by the
        normalized `logw` weights.
    """
    prob = logw_to_prob(logw)
    row_scale = jax.vmap(jnp.multiply)

    def _weighted_sum(leaf):
        # Scale row i of the leaf by prob[i], then sum over rows.
        return jnp.sum(row_scale(prob, leaf), axis=0)

    return jtu.tree_map(_weighted_sum, tree)
def tree_shuffle(tree, index):
    """
    Reorder each leaf of a pytree along its leading dimension.

    Args:
        tree: A PyTree of arrays.
        index: Index (e.g. an integer array) applied along the leading
            dimension of every leaf.

    Returns:
        A PyTree with the same structure as `tree`, where each leaf is
        `leaf[index, ...]`.
    """
    def _take_rows(leaf):
        return leaf[index, ...]

    return jtu.tree_map(_take_rows, tree)
def tree_zeros(tree):
    r"""
    Return a pytree of zeros with the same structure as `tree`.

    Each leaf is replaced by `jnp.zeros_like(leaf)`, so its shape and dtype
    are preserved.
    """
    return jtu.tree_map(jnp.zeros_like, tree)
def tree_remove_last(x):
    """
    Drop the final entry along the leading dimension of every leaf of `x`.

    Args:
        x: A PyTree of arrays.

    Returns:
        A PyTree where each leaf is `leaf[:-1]`.
    """
    all_but_last = slice(None, -1)
    return jtu.tree_map(lambda leaf: leaf[all_but_last], x)
def tree_remove_first(x):
    """
    Drop the first entry along the leading dimension of every leaf of `x`.

    Args:
        x: A PyTree of arrays.

    Returns:
        A PyTree where each leaf is `leaf[1:]`.
    """
    all_but_first = slice(1, None)
    return jtu.tree_map(lambda leaf: leaf[all_but_first], x)
def tree_keep_last(x):
    """
    Keep only the last element of each leaf of a pytree.

    Args:
        x: A PyTree of arrays.

    Returns:
        A PyTree where each leaf is `leaf[-1]`, so the leading dimension is
        removed.
    """
    def _last(leaf):
        return leaf[-1]

    return jtu.tree_map(_last, x)
def tree_append_first(x, first):
    """
    Prepend `first` to each leaf of `x` along the leading dimension.

    Args:
        x: A PyTree of arrays.
        first: A PyTree matching the structure of `x`. Each leaf has the
            shape of a single slice `x_leaf[0]` of the corresponding leaf.

    Returns:
        A PyTree where each leaf is `concatenate([first_leaf[None], x_leaf])`,
        i.e. one element longer along axis 0.
    """
    def _prepend(head, body):
        return jnp.concatenate([head[None], body])

    return jtu.tree_map(_prepend, first, x)
def tree_append_last(x, last):
    """
    Append `last` to the end of each leaf of `x` along the leading dimension.

    Args:
        x: A PyTree of arrays.
        last: A PyTree matching the structure of `x`. Each leaf has the
            shape of a single slice `x_leaf[0]` of the corresponding leaf.

    Returns:
        A PyTree where each leaf is `concatenate([x_leaf, last_leaf[None]])`,
        i.e. one element longer along axis 0.
    """
    def _postpend(tail, body):
        return jnp.concatenate([body, tail[None]])

    return jtu.tree_map(_postpend, last, x)