"""
MCMC algorithms for state space models.
"""
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
from jax import lax
[docs]class AdaptiveMWG:
r"""
Adaptive Metropolis-within-Gibbs.
**Notes:**
Design heavily inspired by |BlackJAX|_. Perhaps to be fully integrated with it some day.
Args:
adapt_max: Scalar or vector of maximum adaptation amounts.
adapt_rate: Scalar or vector of adaptation rates.
.. _BlackJAX: https://blackjax-devs.github.io/blackjax/index.html
.. |BlackJAX| replace:: **BlackJAX**
"""
def __init__(self, adapt_max=.01, adapt_rate=.5):
# self._n_params = n_params
self._adapt_max = adapt_max
self._adapt_rate = adapt_rate
self._targ_acc = .44
[docs] def adapt(self, pars, accept):
r"""
Update random walk standard deviations.
Args:
pars: Adaptation parameters. A dictionary with elements:
- `rw_sd`: Vector of standard deviations for the random walk proposal on each component of `position`.
- `n_iter`: Number of MWG iterations (cycles) so far.
- `n_accept`: The number of accepted draws per component so far.
accept: Boolean vector indicating whether or not the latest proposal was accepted.
Returns:
Dictionary:
- **rw_sd** - Vector of updated random walk standard deviations.
- **n_iter** - Updated number of iterations, i.e., ``n_iter += 1``.
- **n_accept** - Updated number of accepted draws, i.e., ``n_accept += accept``.
"""
rw_sd = pars["rw_sd"]
n_iter = pars["n_iter"] + 1.
n_accept = pars["n_accept"] + accept
accept_rate = (1. * n_accept) / n_iter
delta = jnp.power(n_iter, -self._adapt_rate)
delta = jnp.minimum(delta, self._adapt_max)
low_acc = jnp.sign(self._targ_acc - accept_rate)
return {
"rw_sd": jnp.exp(jnp.log(rw_sd) - delta * low_acc),
"n_iter": n_iter,
"n_accept": n_accept
}
[docs] def init(self, rw_sd):
r"""
Initialize the adaptation parameters.
Args:
rw_sd: A vector of initial standard deviations for the componentwise random walk proposal.
Returns:
Dictionary:
- **rw_sd** - The vector of standard deviations for the componentwise random walk proposal.
- **n_iter** - The number of MWG steps taken so far, which is zero.
- **n_accept** - The number of draws of each component accepted so far, which is ``jnp.zeros_like(rw_sd)``.
"""
return {
"rw_sd": rw_sd,
"n_iter": 0.,
"n_accept": jnp.zeros_like(rw_sd)
}
[docs] def step(self, key, position, logprob_fn, rw_sd, order=None):
r"""
Update parameters via adaptive Metropolis-within-Gibbs.
Args:
key: PRNG key.
position: The current position of the sampler.
logprob_fn: Function which takes a JAX array input `position` and returns a scalar corresponding to the log of the probability density at that input.
rw_sd: Vector of standard deviations for the componentwise random walk proposal.
order: Optional vector of integers between 0 and ``position.size`` indicating the order in which to update the components of `position`. Can use this to keep certain components fixed, randomize update order, etc.
Returns:
Tuple:
- **position** - The updated position.
- **accept** - Boolean vector of length ``position.size`` indicating whether or not each proposal was accepted.
"""
n_pos = position.size # number of components to update
if order is None:
order = jnp.arange(n_pos)
# lax.scan setup
def fun(carry, i):
lp_curr = carry["lp"]
position_curr = carry["position"]
key = carry["key"]
# 2 subkeys for each position: rw_jump and mh_accept
key, *subkeys = random.split(key, num=3)
# proposal
position_prop = position_curr.at[i].set(
position_curr[i] + rw_sd[i] * random.normal(key=subkeys[0])
)
# acceptance rate
lp_prop = logprob_fn(position_prop)
lrate = lp_prop - lp_curr
# update parameter draw
acc = random.bernoulli(key=subkeys[1],
p=jnp.minimum(1.0, jnp.exp(lrate)))
res = {
"position": position_curr.at[i].set(
position_prop[i] * acc + position_curr[i] * (1-acc)
),
"lp": lp_prop * acc + lp_curr * (1-acc),
# "accept": acc,
"key": key
}
stack = {"accept": acc}
return res, stack
# scan initial value
init = {
"position": position,
"lp": logprob_fn(position),
"key": key
}
# scan itself
last, full = jax.lax.scan(fun, init, order)
position = last["position"]
accept = full["accept"]
return position, accept