pfjax.sde

Generic methods for SDEs.

Notes:

  • euler_{sim/lpdf}_{diag/var}() are designed to be usable independently of the SDEModel base class. If this requirement were dropped, much of the code duplication could be avoided by having these functions handle a single data point only and moving the vmap/scan constructs into the state_{sample/lpdf} methods.

Module Contents

Classes

SDEModel

Base class for SDE models.

Functions

euler_sim_diag(key, x, dt, drift, diff, theta)

Simulate SDE with diagonal diffusion using Euler-Maruyama discretization.

euler_lpdf_diag(x_curr, x_prev, dt, drift, diff, theta)

Calculate the log PDF of observations from an SDE with diagonal diffusion using the Euler-Maruyama discretization.

euler_sim_var(key, x, dt, drift, diff, theta)

Simulate SDE with dense diffusion using Euler-Maruyama discretization.

euler_lpdf_var(x_curr, x_prev, dt, drift, diff, theta)

Calculate the log PDF of observations from an SDE with dense diffusion using the Euler-Maruyama discretization.

pfjax.sde.euler_sim_diag(key, x, dt, drift, diff, theta)

Simulate SDE with diagonal diffusion using Euler-Maruyama discretization.

Parameters
  • key – PRNG key.

  • x – Initial value of the SDE. A vector of size n_dims.

  • dt – Interobservation time.

  • drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.

  • diff – Diffusion function having signature diff(x, theta) and returning a vector of size n_dims.

  • theta – Parameter value.

Returns

Simulated SDE values. A vector of size n_dims.
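A minimal sketch of one Euler-Maruyama step with diagonal diffusion, written for illustration rather than copied from the pfjax source. It assumes that diff(x, theta) returns per-component standard deviations, so the noise scale over a step of length dt is diff * sqrt(dt). The name euler_sim_diag_sketch and the Ornstein-Uhlenbeck drift and diff in the usage lines are hypothetical.

```python
import jax
import jax.numpy as jnp


def euler_sim_diag_sketch(key, x, dt, drift, diff, theta):
    """One Euler-Maruyama step with diagonal diffusion (illustrative sketch).

    Assumes diff(x, theta) returns per-component standard deviations, so
    component i of the increment has standard deviation diff_i * sqrt(dt).
    """
    mean = x + drift(x, theta) * dt            # deterministic Euler update
    scale = diff(x, theta) * jnp.sqrt(dt)      # per-component noise scale
    z = jax.random.normal(key, shape=x.shape)  # independent N(0, 1) draws
    return mean + scale * z


# Example: Ornstein-Uhlenbeck drift -theta[0] * x with constant scale theta[1].
drift = lambda x, theta: -theta[0] * x
diff = lambda x, theta: theta[1] * jnp.ones_like(x)

key = jax.random.PRNGKey(0)
x0 = jnp.array([1.0, -2.0])
theta = jnp.array([0.5, 0.3])
x1 = euler_sim_diag_sketch(key, x0, 0.1, drift, diff, theta)
```

Setting the diffusion to zero recovers the deterministic Euler update x + drift(x, theta) * dt, which is a quick sanity check for any implementation of this step.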

pfjax.sde.euler_lpdf_diag(x_curr, x_prev, dt, drift, diff, theta)

Calculate the log PDF of observations from an SDE with diagonal diffusion using the Euler-Maruyama discretization.

Parameters
  • x_curr – Current SDE observations. A vector of size n_dims.

  • x_prev – Previous SDE observations. A vector of size n_dims.

  • dt – Interobservation time.

  • drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.

  • diff – Diffusion function having signature diff(x, theta) and returning a vector of size n_dims.

  • theta – Parameter value.

Returns

The log-density of x_curr given x_prev under the Euler-Maruyama discretization.
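Under the Euler-Maruyama scheme the transition from x_prev to x_curr is Gaussian, and with diagonal diffusion its components are independent, so the log-density is a sum of univariate normal log-densities. The sketch below assumes that diff(x, theta) returns per-component standard deviations; euler_lpdf_diag_sketch is a hypothetical name, not the pfjax implementation.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def euler_lpdf_diag_sketch(x_curr, x_prev, dt, drift, diff, theta):
    """Log-density of one Euler-Maruyama transition, diagonal diffusion (sketch).

    x_curr | x_prev is normal with mean x_prev + drift(x_prev, theta) * dt and
    per-component standard deviation diff(x_prev, theta) * sqrt(dt) (assumed
    convention). Components are independent, so their log-densities are summed.
    """
    loc = x_prev + drift(x_prev, theta) * dt
    scale = diff(x_prev, theta) * jnp.sqrt(dt)
    return jnp.sum(norm.logpdf(x_curr, loc=loc, scale=scale))


# Example with a hypothetical Ornstein-Uhlenbeck model.
drift = lambda x, theta: -theta[0] * x
diff = lambda x, theta: theta[1] * jnp.ones_like(x)
theta = jnp.array([0.5, 0.3])
lp = euler_lpdf_diag_sketch(jnp.array([0.9, -1.8]), jnp.array([1.0, -2.0]),
                            0.1, drift, diff, theta)
```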

pfjax.sde.euler_sim_var(key, x, dt, drift, diff, theta)

Simulate SDE with dense diffusion using Euler-Maruyama discretization.

Parameters
  • key – PRNG key.

  • x – Initial value of the SDE. A vector of size n_dims.

  • dt – Interobservation time.

  • drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.

  • diff – Diffusion function having signature diff(x, theta) and returning a matrix of size n_dims x n_dims.

  • theta – Parameter value.

Returns

Simulated SDE values. A vector of size n_dims.
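With dense diffusion, the Gaussian noise must be correlated across components. A common way to do this, sketched below, is to multiply standard normal draws by a Cholesky factor. This assumes diff(x, theta) returns a positive-definite n_dims x n_dims variance matrix Sigma, so one step is x + drift * dt plus N(0, Sigma * dt) noise. The function name and the example Sigma are hypothetical.

```python
import jax
import jax.numpy as jnp


def euler_sim_var_sketch(key, x, dt, drift, diff, theta):
    """One Euler-Maruyama step with dense diffusion (illustrative sketch).

    Assumes diff(x, theta) returns an n_dims x n_dims positive-definite
    variance matrix Sigma, so that x_new ~ N(x + drift * dt, Sigma * dt).
    """
    mean = x + drift(x, theta) * dt
    # Lower-triangular L with L @ L.T == Sigma * dt.
    chol = jnp.linalg.cholesky(diff(x, theta) * dt)
    z = jax.random.normal(key, shape=x.shape)  # independent N(0, 1) draws
    return mean + chol @ z


# Example: zero drift and a fixed correlated variance matrix.
Sigma = jnp.array([[1.0, 0.5], [0.5, 2.0]])
drift = lambda x, theta: jnp.zeros_like(x)
diff = lambda x, theta: Sigma
x0 = jnp.zeros(2)

# Draw many independent steps to check the increment covariance empirically.
keys = jax.random.split(jax.random.PRNGKey(0), 20000)
samples = jax.vmap(lambda k: euler_sim_var_sketch(k, x0, 1.0, drift, diff, None))(keys)
```

The empirical covariance of samples should be close to Sigma * dt, which is what distinguishes a correct dense-diffusion step from one that only scales each component separately.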

pfjax.sde.euler_lpdf_var(x_curr, x_prev, dt, drift, diff, theta)

Calculate the log PDF of observations from an SDE with dense diffusion using the Euler-Maruyama discretization.

Parameters
  • x_curr – Current SDE observations. A vector of size n_dims.

  • x_prev – Previous SDE observations. A vector of size n_dims.

  • dt – Interobservation time.

  • drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.

  • diff – Diffusion function having signature diff(x, theta) and returning a matrix of size n_dims x n_dims.

  • theta – Parameter value.

Returns

The log-density of x_curr given x_prev under the Euler-Maruyama discretization.
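For dense diffusion, the Euler transition density is a single multivariate normal. The sketch below assumes diff(x, theta) returns the variance matrix Sigma, so x_curr given x_prev is N(x_prev + drift(x_prev, theta) * dt, Sigma * dt). It uses jax.scipy.stats.multivariate_normal.logpdf; the function name is hypothetical. With a diagonal Sigma it should agree with the sum of univariate normal log-densities, which the usage lines check.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm


def euler_lpdf_var_sketch(x_curr, x_prev, dt, drift, diff, theta):
    """Log-density of one Euler-Maruyama transition, dense diffusion (sketch).

    Assumes diff(x, theta) returns the n_dims x n_dims variance matrix Sigma,
    so x_curr | x_prev ~ N(x_prev + drift(x_prev, theta) * dt, Sigma * dt).
    """
    mean = x_prev + drift(x_prev, theta) * dt
    cov = diff(x_prev, theta) * dt
    return multivariate_normal.logpdf(x_curr, mean=mean, cov=cov)


# With a diagonal Sigma the result reduces to independent univariate normals.
variances = jnp.array([1.0, 4.0])
Sigma = jnp.diag(variances)
x_curr = jnp.array([0.3, -0.4])
x_prev = jnp.zeros(2)
dt = 0.5
zero_drift = lambda x, theta: jnp.zeros_like(x)
lp = euler_lpdf_var_sketch(x_curr, x_prev, dt, zero_drift, lambda x, theta: Sigma, None)
lp_diag = jnp.sum(norm.logpdf(x_curr, loc=x_prev, scale=jnp.sqrt(variances * dt)))
```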

class pfjax.sde.SDEModel(dt, n_res, diff_diag)

Bases: object

Base class for SDE models.

This class sets up a PF model with methods state_lpdf(), state_sample(), and pf_step() automatically determined from user-specified SDE drift and diffusion functions.

For computational efficiency, the user can also specify whether or not the diffusion is diagonal (diff_diag). This sets up the methods euler_sim() and euler_lpdf() at instantiation time from either euler_{sim/lpdf}_diag() or euler_{sim/lpdf}_var(). Their arguments are identical to those of the free functions, except that drift and diff are supplied by self.drift() and self.diff().

For pf_step(), a bootstrap filter is assumed by default, for which the user needs to specify meas_lpdf().
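The selection described above can be pictured with a toy class. This is a hypothetical sketch, not SDEModel itself: ToySDE, its Ornstein-Uhlenbeck drift and diffusion, and the stand-in bodies of the two free functions are all assumptions (diagonal diff returns standard deviations, dense diff returns a variance matrix). It shows diff_diag choosing the simulator once in __init__, and euler_sim() taking the free function's arguments minus drift and diff.

```python
import jax
import jax.numpy as jnp


# Stand-ins for the free functions (documented signatures; sketch bodies).
def euler_sim_diag(key, x, dt, drift, diff, theta):
    # Assumes diff returns per-component standard deviations.
    z = jax.random.normal(key, x.shape)
    return x + drift(x, theta) * dt + diff(x, theta) * jnp.sqrt(dt) * z


def euler_sim_var(key, x, dt, drift, diff, theta):
    # Assumes diff returns an n_dims x n_dims variance matrix.
    chol = jnp.linalg.cholesky(diff(x, theta) * dt)
    return x + drift(x, theta) * dt + chol @ jax.random.normal(key, x.shape)


class ToySDE:
    """Hypothetical model showing how diff_diag could select the Euler simulator."""

    def __init__(self, dt, n_res, diff_diag):
        self._dt, self._n_res, self._diff_diag = dt, n_res, diff_diag
        sim = euler_sim_diag if diff_diag else euler_sim_var
        # Same arguments as the free function, minus drift and diff.
        self.euler_sim = lambda key, x, dt, theta: sim(
            key, x, dt, self.drift, self.diff, theta)

    def drift(self, x, theta):
        return -theta[0] * x

    def diff(self, x, theta):
        sd = theta[1] * jnp.ones_like(x)
        return sd if self._diff_diag else jnp.diag(sd ** 2)


key = jax.random.PRNGKey(1)
x = jnp.array([1.0, 2.0])
theta = jnp.array([0.5, 0.3])
a = ToySDE(1.0, 4, diff_diag=True).euler_sim(key, x, 0.25, theta)
b = ToySDE(1.0, 4, diff_diag=False).euler_sim(key, x, 0.25, theta)
```

Because the dense variance matrix here is diagonal, both branches use the same noise draws and give the same result. The diagonal branch avoids the Cholesky factorization, which is the efficiency gain diff_diag is meant to capture.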

Notes:

  • Currently contains _state_sample_for() and _state_lpdf_for() for testing purposes. These may be moved elsewhere at some point to hide them from users…

  • Should derive from pf.BaseModel.

is_valid(x, theta)

Checks whether SDE observations are valid.

Parameters
  • x – SDE variables. A vector of size n_dims.

  • theta – Parameter value.

Returns

Whether or not x is valid SDE data.

is_valid_state(x, theta)

Checks whether SDE latent variables are valid.

Applies is_valid() to each of the n_res SDE variables, and also checks for NaNs.

Parameters
  • x – State variable of size n_res x n_dims.

  • theta – Parameter value.

Returns

Whether or not all n_res latent variables are valid.
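The row-wise check can be expressed with jax.vmap over the n_res rows of the state. This is a sketch under assumptions: toy_is_valid (all components strictly positive) stands in for a model's own is_valid(), and is_valid_state_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def toy_is_valid(x, theta):
    # Hypothetical check for one time point: all components strictly positive.
    return jnp.all(x > 0)


def is_valid_state_sketch(x, theta, is_valid=toy_is_valid):
    """Sketch: x has shape (n_res, n_dims). Apply is_valid to each row, reject NaNs."""
    rows_ok = jax.vmap(is_valid, in_axes=(0, None))(x, theta)  # shape (n_res,)
    return jnp.all(rows_ok) & ~jnp.any(jnp.isnan(x))


theta = jnp.array([1.0])
ok = is_valid_state_sketch(jnp.ones((4, 2)), theta)
```

Because NaN > 0 is False, this particular toy check already rejects NaNs. The separate isnan test matters for is_valid() functions that would otherwise let NaNs through.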

state_lpdf(x_curr, x_prev, theta)

Calculates the log-density of p(x_curr | x_prev, theta).

Parameters
  • x_curr – State variable at current time t.

  • x_prev – State variable at previous time t-1.

  • theta – Parameter value.

Returns

The log-density of p(x_curr | x_prev, theta).
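One way to compute this density for an n_res x n_dims state is to treat the state as n_res consecutive Euler steps and sum their log-densities. The sketch below rests on assumptions not stated in this reference: the path continues from the last row of x_prev, each Euler step has length dt / n_res, and the diffusion is diagonal with diff returning standard deviations. Following the design note at the top of this page, the per-step density is written for a single pair of points and jax.vmap does the batching. The function name and extra arguments are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def state_lpdf_sketch(x_curr, x_prev, theta, dt, n_res, drift, diff):
    """Sketch of log p(x_curr | x_prev, theta) for an n_res x n_dims state.

    Assumptions: the path continues from the last row of x_prev, each Euler
    step has length dt / n_res, and diff returns per-component standard
    deviations (diagonal diffusion).
    """
    path = jnp.concatenate([x_prev[-1:], x_curr], axis=0)  # (n_res + 1, n_dims)
    step = dt / n_res

    def one_step(x0, x1):
        # Log-density of a single Euler step from x0 to x1.
        loc = x0 + drift(x0, theta) * step
        scale = diff(x0, theta) * jnp.sqrt(step)
        return jnp.sum(norm.logpdf(x1, loc=loc, scale=scale))

    # vmap over the n_res consecutive (start, end) pairs, then sum.
    return jnp.sum(jax.vmap(one_step)(path[:-1], path[1:]))


zero_drift = lambda x, theta: jnp.zeros_like(x)
unit_diff = lambda x, theta: jnp.ones_like(x)
lp = state_lpdf_sketch(jnp.zeros((2, 2)), jnp.zeros((2, 2)), None, 2.0, 2,
                       zero_drift, unit_diff)
```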

state_sample(key, x_prev, theta)

Samples from x_curr ~ p(x_curr | x_prev, theta).

Parameters
  • key – PRNG key.

  • x_prev – State variable at previous time t-1.

  • theta – Parameter value.

Returns

Sample of the state variable at current time t: x_curr ~ p(x_curr | x_prev, theta).
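Sampling the next state amounts to running n_res Euler steps forward. In JAX that is naturally written with jax.lax.scan, which carries the current value and stacks each new one. This is a sketch under assumptions not confirmed by this reference: the path starts from the last row of x_prev, each step has length dt / n_res, and the diffusion is diagonal with diff returning standard deviations. The function name and extra arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def state_sample_sketch(key, x_prev, theta, dt, n_res, drift, diff):
    """Sketch of x_curr ~ p(x_curr | x_prev, theta) for an n_res x n_dims state.

    Assumptions: start from the last row of x_prev, take n_res Euler steps
    of length dt / n_res, diagonal diffusion given as standard deviations.
    """
    step = dt / n_res
    keys = jax.random.split(key, n_res)  # one key per Euler step

    def euler_step(x, k):
        z = jax.random.normal(k, x.shape)
        x_new = x + drift(x, theta) * step + diff(x, theta) * jnp.sqrt(step) * z
        return x_new, x_new  # carry the new value and also record it

    _, path = jax.lax.scan(euler_step, x_prev[-1], keys)
    return path  # shape (n_res, n_dims)


key = jax.random.PRNGKey(0)
x_prev = jnp.ones((4, 2))
path = state_sample_sketch(key, x_prev, None, 1.0, 4,
                           lambda x, th: -x, lambda x, th: jnp.zeros_like(x))
```

With zero diffusion and drift -x, each step multiplies the value by 1 - dt / n_res = 0.75, so the rows of path are 0.75, 0.75², 0.75³ and 0.75⁴.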

pf_step(key, x_prev, y_curr, theta)

Update a particle and calculate its log-weight for a bootstrap particle filter.

FIXME: This method is completely generic, i.e., it is not specific to SDEs. It may belong elsewhere…

Parameters
  • key – PRNG key.

  • x_prev – State variable at previous time t-1.

  • y_curr – Measurement variable at current time t.

  • theta – Parameter value.

Returns

  • x_curr - Sample of the state variable at current time t: x_curr ~ q(x_curr).

  • logw - The log-weight of x_curr.

Return type

Tuple
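In a bootstrap filter the proposal is the state transition itself, q(x_curr) = p(x_curr | x_prev, theta). The transition density therefore cancels in the importance weight, leaving logw = log p(y_curr | x_curr, theta). A generic sketch follows. The meas_lpdf(y_curr, x_curr, theta) argument order, the toy random-walk state model and the Gaussian measurement of the last row are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def bootstrap_pf_step(key, x_prev, y_curr, theta, state_sample, meas_lpdf):
    """Generic bootstrap update (sketch).

    Proposing from p(x_curr | x_prev, theta) makes the weight
    p(y | x) p(x | x_prev) / q(x) reduce to p(y_curr | x_curr, theta).
    """
    x_curr = state_sample(key, x_prev, theta)
    logw = meas_lpdf(y_curr, x_curr, theta)
    return x_curr, logw


# Hypothetical toy model: Gaussian random walk, noisy measurement of the last row.
def toy_state_sample(key, x_prev, theta):
    return x_prev + jax.random.normal(key, x_prev.shape)


def toy_meas_lpdf(y_curr, x_curr, theta):
    return jnp.sum(norm.logpdf(y_curr, loc=x_curr[-1], scale=theta))


key = jax.random.PRNGKey(0)
x_prev = jnp.zeros((3, 2))
y_curr = jnp.array([0.1, -0.2])
x_curr, logw = bootstrap_pf_step(key, x_prev, y_curr, 0.5,
                                 toy_state_sample, toy_meas_lpdf)
```

Nothing in this update depends on the state being an SDE path, which is the point of the FIXME above.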

bridge_step(key, x_prev, y_curr, theta, Y, A, Omega)

Update a particle and calculate its log-weight for a particle filter with multivariate normal (MVN) bridge proposals.

Notes:

  • The measurement input is Y rather than y_meas. This allows the proposal to be used even when the measurement error model is not y_meas ~ N(A x_state, Omega), via a carefully chosen Y = g(y_meas).

  • Only implements the general version with arbitrary A and Omega. Specific cases could be computed faster, but it's not clear where these specialized methods should go: inside the class, or as free functions?

  • Duplicates a lot of code from mvn_bridge_pars(), because Sigma_Y and mu_Y inside the latter can be computed very easily here.

  • Computes the Euler part of the log-weights with vmap after the bridge part, which uses lax.scan(). On a single core it is definitely faster to do both inside lax.scan(). On multiple cores this may or may not hold, but quite a few cores would probably be needed to see a speedup. In any case, explicit parallelization across cores is unlikely here, since parallelization would typically be over particles.

  • The drift and diffusion functions are each evaluated twice, once for the proposal and once for the Euler log-density. This is somewhat inefficient, but avoiding it would require redesigning euler_lpdf().

Parameters
  • key – PRNG key.

  • x_prev – State variable at previous time t-1.

  • y_curr – Measurement variable at current time t.

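  • theta – Parameter value.

  • Y – Measurement input to the bridge proposal, typically Y = g(y_curr) for a carefully chosen transformation g (see Notes).

  • A – Matrix of the Gaussian measurement model Y ~ N(A x_state, Omega) used to construct the proposal.

  • Omega – Variance matrix of the Gaussian measurement model Y ~ N(A x_state, Omega).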

Returns

  • x_curr - Sample of the state variable at current time t: x_curr ~ q(x_curr).

  • logw - The log-weight of x_curr.

Return type

Tuple
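One standard ingredient of an MVN bridge proposal is Gaussian conditioning. If x ~ N(mu, Sigma) and Y | x ~ N(A x, Omega), then mu_Y = A mu, Sigma_Y = A Sigma A' + Omega, and x | Y is normal with mean mu + Sigma A' Sigma_Y⁻¹ (Y - mu_Y) and variance Sigma - Sigma A' Sigma_Y⁻¹ A Sigma. The helper mvn_condition below is hypothetical. It is neither pfjax's mvn_bridge_pars() nor the full bridge_step(), only a sketch of that single conditioning step.

```python
import jax.numpy as jnp


def mvn_condition(mu, Sigma, Y, A, Omega):
    """Condition x ~ N(mu, Sigma) on Y | x ~ N(A x, Omega) (sketch).

    Returns the mean and variance matrix of x | Y.
    """
    mu_Y = A @ mu
    Sigma_Y = A @ Sigma @ A.T + Omega  # marginal variance of Y
    Sigma_xY = Sigma @ A.T             # Cov(x, Y)
    # K = Sigma_xY Sigma_Y^{-1}, computed with a solve instead of an inverse.
    K = jnp.linalg.solve(Sigma_Y, Sigma_xY.T).T
    mu_post = mu + K @ (Y - mu_Y)
    Sigma_post = Sigma - K @ Sigma_xY.T
    return mu_post, Sigma_post


# Observe only the first of two independent components, with unit noise.
mu = jnp.zeros(2)
Sigma = jnp.eye(2)
A = jnp.array([[1.0, 0.0]])
Omega = jnp.array([[1.0]])
m, S = mvn_condition(mu, Sigma, jnp.array([2.0]), A, Omega)
```

In the example the observed component's mean moves halfway to Y and its variance halves, while the unobserved component is unchanged.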