pfjax.mcmc

MCMC algorithms for state space models.

Module Contents

Classes

AdaptiveMWG

Adaptive Metropolis-within-Gibbs.

class pfjax.mcmc.AdaptiveMWG(adapt_max=0.01, adapt_rate=0.5)[source]

Adaptive Metropolis-within-Gibbs.

Notes:

Design heavily inspired by BlackJAX, with which it may be fully integrated in the future.

Parameters
  • adapt_max – Scalar or vector of maximum adaptation amounts.

  • adapt_rate – Scalar or vector of adaptation rates.

adapt(pars, accept)[source]

Update random walk standard deviations.

Parameters
  • pars

    Adaptation parameters. A dictionary with elements:

    • rw_sd: Vector of standard deviations for the random walk proposal on each component of position.

    • n_iter: Number of MWG iterations (cycles) so far.

    • n_accept: The number of accepted draws per component so far.

  • accept – Boolean vector indicating, for each component of position, whether the latest proposal was accepted.

Returns

  • rw_sd - Vector of updated random walk standard deviations.

  • n_iter - Updated number of iterations, i.e., n_iter += 1.

  • n_accept - Updated number of accepted draws, i.e., n_accept += accept.

Return type

Dictionary
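The update above can be illustrated with a minimal sketch. It is not taken from the pfjax source: the function name adapt_sketch, the target acceptance rate of 0.44, and the way adapt_max and adapt_rate enter the step size are all assumptions, chosen to resemble a common adaptive Metropolis-within-Gibbs scheme. Only the dictionary keys and the n_iter and n_accept updates follow the description above.

```python
import jax.numpy as jnp


def adapt_sketch(pars, accept, adapt_max=0.01, adapt_rate=0.5, target=0.44):
    """Hypothetical adaptation rule; not the pfjax implementation."""
    n_iter = pars["n_iter"] + 1
    n_accept = pars["n_accept"] + accept  # booleans are promoted to numbers
    accept_rate = n_accept / n_iter
    # Assumed schedule: the adaptation size decays with the iteration
    # count and is capped at adapt_max.
    delta = jnp.minimum(adapt_max, n_iter ** (-adapt_rate))
    # Raise log(rw_sd) when accepting more often than the target rate,
    # lower it when accepting less often.
    direction = jnp.sign(accept_rate - target)
    rw_sd = jnp.exp(jnp.log(pars["rw_sd"]) + delta * direction)
    return {"rw_sd": rw_sd, "n_iter": n_iter, "n_accept": n_accept}


pars = {"rw_sd": jnp.array([1.0, 1.0]), "n_iter": 0, "n_accept": jnp.zeros(2)}
pars = adapt_sketch(pars, accept=jnp.array([True, False]))
```

In this sketch the first component was accepted, so its standard deviation grows slightly, while the second was rejected, so its standard deviation shrinks.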

init(rw_sd)[source]

Initialize the adaptation parameters.

Parameters

rw_sd – A vector of initial standard deviations for the componentwise random walk proposal.

Returns

  • rw_sd - The vector of standard deviations for the componentwise random walk proposal.

  • n_iter - The number of MWG steps taken so far, which is zero.

  • n_accept - The number of draws of each component accepted so far, which is jnp.zeros_like(rw_sd).

Return type

Dictionary
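As a rough illustration of the returned dictionary, the following sketch builds the same three entries by hand. The helper name init_sketch is hypothetical and is not part of pfjax. The integer type used for n_iter is also an assumption.

```python
import jax.numpy as jnp


def init_sketch(rw_sd):
    """Build the adaptation parameters described above (hypothetical helper)."""
    rw_sd = jnp.asarray(rw_sd)
    return {
        "rw_sd": rw_sd,                      # initial proposal standard deviations
        "n_iter": 0,                         # no MWG steps taken yet
        "n_accept": jnp.zeros_like(rw_sd),   # no accepted draws yet
    }


state = init_sketch(jnp.array([0.1, 0.5, 2.0]))
```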

step(key, position, logprob_fn, rw_sd, order=None)[source]

Update parameters via adaptive Metropolis-within-Gibbs.

Parameters
  • key – PRNG key.

  • position – The current position of the sampler.

  • logprob_fn – Function that takes a JAX array position as input and returns a scalar: the log of the probability density at that position.

  • rw_sd – Vector of standard deviations for the componentwise random walk proposal.

  • order – Optional vector of integer indices, each between 0 and position.size - 1, giving the order in which to update the components of position. This can be used to keep certain components fixed, randomize the update order, etc.

Returns

  • position - The updated position.

  • accept - Boolean vector of length position.size indicating whether or not each proposal was accepted.

Return type

Tuple
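One cycle of componentwise random-walk Metropolis, as described above, might look like the sketch below. It is illustrative only and is not the pfjax implementation: the function name mwg_step_sketch is hypothetical, position is assumed to be a flat one-dimensional array, and a plain Python loop is used for readability rather than a compiled loop.

```python
import jax
import jax.numpy as jnp


def mwg_step_sketch(key, position, logprob_fn, rw_sd, order=None):
    """One cycle of componentwise random-walk Metropolis (illustrative only)."""
    if order is None:
        order = jnp.arange(position.size)
    accept = jnp.zeros(position.size, dtype=bool)
    for i, subkey in zip(order, jax.random.split(key, len(order))):
        z_key, u_key = jax.random.split(subkey)
        # Propose a Gaussian random-walk move for component i only.
        proposal = position.at[i].add(rw_sd[i] * jax.random.normal(z_key))
        # The proposal is symmetric, so the log acceptance ratio is just
        # the difference in log densities.
        log_ratio = logprob_fn(proposal) - logprob_fn(position)
        acc = jnp.log(jax.random.uniform(u_key)) < log_ratio
        position = jnp.where(acc, proposal, position)
        accept = accept.at[i].set(acc)
    return position, accept


logprob_fn = lambda x: -0.5 * jnp.sum(x**2)  # standard normal target
key = jax.random.PRNGKey(0)
# Update only components 0 and 2; component 1 is left out of order, so it stays fixed.
position, accept = mwg_step_sketch(
    key, jnp.zeros(3), logprob_fn, rw_sd=jnp.ones(3), order=[0, 2]
)
```

Components omitted from order are never proposed, so in this sketch their entries in accept remain False and their values in position are unchanged.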