pfjax.sde
Generic methods for SDEs.
Notes:
euler_{sim/lpdf}_{diag/var}() are designed to be used independently of the SDE base class. If we abandon this requirement then we can avoid quite a bit of code duplication by having these functions do one data point only and putting the vmap/scan constructs into the state_{sample/lpdf} methods.
Module Contents
Classes
SDEModel | Base class for SDE models.
Functions
euler_sim_diag | Simulate SDE with diagonal diffusion using Euler-Maruyama discretization.
euler_lpdf_diag | Calculate the log PDF of observations from an SDE with diagonal diffusion using the Euler-Maruyama discretization.
euler_sim_var | Simulate SDE with dense diffusion using Euler-Maruyama discretization.
euler_lpdf_var | Calculate the log PDF of observations from an SDE with dense diffusion using the Euler-Maruyama discretization.
- pfjax.sde.euler_sim_diag(key, x, dt, drift, diff, theta)[source]
Simulate SDE with diagonal diffusion using Euler-Maruyama discretization.
- Parameters
key – PRNG key.
x – Initial value of the SDE. A vector of size n_dims.
dt – Interobservation time.
drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.
diff – Diffusion function having signature diff(x, theta) and returning a vector of size n_dims.
theta – Parameter value.
- Returns
Simulated SDE values. A vector of size n_dims.
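A minimal sketch of a single Euler-Maruyama step with diagonal diffusion, for illustration only. It is not the pfjax source. The name em_sim_diag, the toy drift/diff functions, and the assumption that diff(x, theta) returns per-coordinate standard deviations are all hypothetical.

```python
import jax
import jax.numpy as jnp


def em_sim_diag(key, x, dt, drift, diff, theta):
    """One Euler-Maruyama step with diagonal diffusion (illustrative only).

    Assumes diff(x, theta) returns per-coordinate standard deviations.
    """
    mean = x + drift(x, theta) * dt        # deterministic drift update
    scale = diff(x, theta) * jnp.sqrt(dt)  # noise scale over a step of length dt
    return mean + scale * jax.random.normal(key, x.shape)


# Hypothetical mean-reverting model: theta = (reversion rate, noise scale).
def drift(x, theta):
    return -theta[0] * x


def diff(x, theta):
    return theta[1] * jnp.ones_like(x)


key = jax.random.PRNGKey(0)
x0 = jnp.array([1.0, -0.5])
theta = jnp.array([0.5, 0.2])
x1 = em_sim_diag(key, x0, 0.1, drift, diff, theta)  # vector of size n_dims = 2
```

Setting the noise scale to zero reduces the step to the deterministic update x + drift(x, theta) * dt, which is a convenient sanity check.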
- pfjax.sde.euler_lpdf_diag(x_curr, x_prev, dt, drift, diff, theta)[source]
Calculate the log PDF of observations from an SDE with diagonal diffusion using the Euler-Maruyama discretization.
- Parameters
x_curr – Current SDE observations. A vector of size n_dims.
x_prev – Previous SDE observations. A vector of size n_dims.
dt – Interobservation time.
drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.
diff – Diffusion function having signature diff(x, theta) and returning a vector of size n_dims.
theta – Parameter value.
- Returns
The log-density of the SDE observations.
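As a rough illustration, the Euler-Maruyama transition density with diagonal diffusion is a product of independent normals, so its log is a sum of univariate normal log-densities. The sketch below is not the pfjax source. The name em_lpdf_diag, the toy model, the sum over coordinates, and the assumption that diff(x, theta) returns standard deviations are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def em_lpdf_diag(x_curr, x_prev, dt, drift, diff, theta):
    """Log-density of one Euler-Maruyama transition, diagonal diffusion (illustrative)."""
    mean = x_prev + drift(x_prev, theta) * dt
    scale = diff(x_prev, theta) * jnp.sqrt(dt)  # assumed standard deviations
    # Independent coordinates: sum the univariate normal log-densities.
    return jnp.sum(norm.logpdf(x_curr, loc=mean, scale=scale))


# Hypothetical mean-reverting model: theta = (reversion rate, noise scale).
def drift(x, theta):
    return -theta[0] * x


def diff(x, theta):
    return theta[1] * jnp.ones_like(x)


theta = jnp.array([0.5, 0.2])
x_prev = jnp.array([1.0, -0.5])
x_curr = jnp.array([0.9, -0.4])
lp = em_lpdf_diag(x_curr, x_prev, 0.1, drift, diff, theta)  # scalar log-density
```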
- pfjax.sde.euler_sim_var(key, x, dt, drift, diff, theta)[source]
Simulate SDE with dense diffusion using Euler-Maruyama discretization.
- Parameters
key – PRNG key.
x – Initial value of the SDE. A vector of size n_dims.
dt – Interobservation time.
drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.
diff – Diffusion function having signature diff(x, theta) and returning a matrix of size n_dims x n_dims.
theta – Parameter value.
- Returns
Simulated SDE values. A vector of size n_dims.
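A minimal sketch of one Euler-Maruyama step with dense diffusion, for illustration only. It is not the pfjax source. It assumes that diff(x, theta) returns a symmetric positive-definite n_dims x n_dims variance matrix, so that the noise is generated from its Cholesky factor. The name em_sim_dense and the toy model are hypothetical.

```python
import jax
import jax.numpy as jnp


def em_sim_dense(key, x, dt, drift, diff, theta):
    """One Euler-Maruyama step with dense diffusion (illustrative only).

    Assumes diff(x, theta) returns an n_dims x n_dims variance matrix.
    """
    mean = x + drift(x, theta) * dt
    chol = jnp.linalg.cholesky(diff(x, theta))  # lower-triangular factor
    z = jax.random.normal(key, x.shape)         # standard normal draws
    return mean + jnp.sqrt(dt) * (chol @ z)


# Hypothetical model: theta = (reversion rate, correlation between coordinates).
def drift(x, theta):
    return -theta[0] * x


def diff(x, theta):
    return jnp.array([[1.0, theta[1]], [theta[1], 1.0]])


key = jax.random.PRNGKey(0)
x0 = jnp.array([1.0, -0.5])
theta = jnp.array([0.5, 0.3])
x1 = em_sim_dense(key, x0, 0.1, drift, diff, theta)  # vector of size n_dims = 2
```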
- pfjax.sde.euler_lpdf_var(x_curr, x_prev, dt, drift, diff, theta)[source]
Calculate the log PDF of observations from an SDE with dense diffusion using the Euler-Maruyama discretization.
- Parameters
x_curr – Current SDE observations. A vector of size n_dims.
x_prev – Previous SDE observations. A vector of size n_dims.
dt – Interobservation time.
drift – Drift function having signature drift(x, theta) and returning a vector of size n_dims.
diff – Diffusion function having signature diff(x, theta) and returning a matrix of size n_dims x n_dims.
theta – Parameter value.
- Returns
The log-density of the SDE observations.
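With dense diffusion, the Euler-Maruyama transition is multivariate normal. The sketch below is an illustration rather than the pfjax source. It assumes that diff(x, theta) returns an n_dims x n_dims variance matrix, and the name em_lpdf_dense and the toy model are hypothetical. As a sanity check, a diagonal variance matrix should reproduce the sum of univariate normal log-densities.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm


def em_lpdf_dense(x_curr, x_prev, dt, drift, diff, theta):
    """Log-density of one Euler-Maruyama transition, dense diffusion (illustrative)."""
    mean = x_prev + drift(x_prev, theta) * dt
    cov = diff(x_prev, theta) * dt  # assumed variance matrix, scaled by the step size
    return multivariate_normal.logpdf(x_curr, mean, cov)


# Hypothetical model with a diagonal variance matrix diag(theta[1]^2).
def drift(x, theta):
    return -theta[0] * x


def diff(x, theta):
    return theta[1] ** 2 * jnp.eye(x.shape[0])


theta = jnp.array([0.5, 0.2])
x_prev = jnp.array([1.0, -0.5])
x_curr = jnp.array([0.9, -0.4])
dt = 0.1
lp_dense = em_lpdf_dense(x_curr, x_prev, dt, drift, diff, theta)

# Same density computed coordinate by coordinate.
mean = x_prev + drift(x_prev, theta) * dt
lp_diag = jnp.sum(norm.logpdf(x_curr, loc=mean, scale=theta[1] * jnp.sqrt(dt)))
```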
- class pfjax.sde.SDEModel(dt, n_res, diff_diag)[source]
Bases: object
Base class for SDE models.
This class sets up a PF model with methods state_lpdf(), state_sim(), and pf_step() automatically determined from user-specified SDE drift and diffusion functions.
For computational efficiency, the user can also specify whether or not the diffusion is diagonal. This will set up methods euler_sim() and euler_lpdf() supplied at instantiation time from either euler_{sim/lpdf}_diag() or euler_{sim/lpdf}_var(), with arguments identical to those of the free functions except drift and diff, which are supplied by self.drift() and self.diff().
For pf_step(), a bootstrap filter is assumed by default, for which the user needs to specify meas_lpdf().
Notes:
Currently contains _state_sample_for() and _state_lpdf_for() for testing purposes. May want to move these elsewhere at some point to hide them from users.
Should derive from pf.BaseModel.
- is_valid(x, theta)[source]
Checks whether SDE observations are valid.
- Parameters
x – SDE variables. A vector of size n_dims.
theta – Parameter value.
- Returns
Whether or not x is valid SDE data.
- is_valid_state(x, theta)[source]
Checks whether SDE latent variables are valid.
Applies is_valid() to each of the n_res SDE variables, and also checks for nans.
- Parameters
x – State variable of size n_res x n_dims.
theta – Parameter value.
- Returns
Whether or not all n_res latent variables are valid.
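The check described above can be sketched as follows. This is an illustration and not the pfjax source. The positivity constraint in is_valid() is a hypothetical example of a model-specific validity rule, and both functions are written as free functions rather than methods.

```python
import jax
import jax.numpy as jnp


def is_valid(x, theta):
    """Hypothetical validity rule: every SDE variable must be positive."""
    return jnp.all(x > 0.0)


def is_valid_state(x, theta):
    """Apply is_valid() to each of the n_res rows of x and also reject NaNs."""
    valid_each = jax.vmap(is_valid, in_axes=(0, None))(x, theta)
    return jnp.all(valid_each) & ~jnp.any(jnp.isnan(x))


theta = jnp.array([1.0])
x_state = jnp.array([[1.0, 2.0], [0.5, 0.1]])  # n_res = 2, n_dims = 2
ok = is_valid_state(x_state, theta)
```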
- state_lpdf(x_curr, x_prev, theta)[source]
Calculates the log-density of p(x_curr | x_prev, theta).
- Parameters
x_curr – State variable at current time t.
x_prev – State variable at previous time t-1.
theta – Parameter value.
- Returns
The log-density of p(x_curr | x_prev, theta).
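One plausible way to compute this log-density, sketched under stated assumptions rather than taken from the pfjax source: the state holds n_res intermediate SDE values per observation interval, each sub-step has length dt / n_res, and the first sub-step starts from the last row of x_prev. The function names, the free-function signature, and the toy model are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def em_lpdf_diag(x_curr, x_prev, dt, drift, diff, theta):
    """Euler-Maruyama transition log-density, diagonal diffusion (illustrative)."""
    mean = x_prev + drift(x_prev, theta) * dt
    scale = diff(x_prev, theta) * jnp.sqrt(dt)
    return jnp.sum(norm.logpdf(x_curr, loc=mean, scale=scale))


def state_lpdf(x_curr, x_prev, theta, dt, n_res, drift, diff):
    """Sum the Euler log-densities of the n_res sub-steps (assumed state layout)."""
    dt_res = dt / n_res
    # Start point of each sub-step: last row of x_prev, then rows 0..n_res-2 of x_curr.
    x_start = jnp.concatenate([x_prev[-1:], x_curr[:-1]])
    lp = jax.vmap(
        lambda xs, xe: em_lpdf_diag(xe, xs, dt_res, drift, diff, theta)
    )(x_start, x_curr)
    return jnp.sum(lp)


# Hypothetical mean-reverting model: theta = (reversion rate, noise scale).
def drift(x, theta):
    return -theta[0] * x


def diff(x, theta):
    return theta[1] * jnp.ones_like(x)


theta = jnp.array([0.5, 0.2])
x_prev = jnp.ones((4, 2))  # n_res = 4, n_dims = 2
x_curr = jnp.array([[0.98, 1.01], [0.95, 1.00], [0.97, 0.99], [0.93, 0.96]])
lp = state_lpdf(x_curr, x_prev, theta, 0.4, 4, drift, diff)
```

Because the sub-step densities depend only on known start and end points, they can be evaluated in parallel with vmap and need no sequential scan.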
- state_sample(key, x_prev, theta)[source]
Samples from x_curr ~ p(x_curr | x_prev, theta).
- Parameters
key – PRNG key.
x_prev – State variable at previous time t-1.
theta – Parameter value.
- Returns
Sample of the state variable at current time t: x_curr ~ p(x_curr | x_prev, theta).
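A sketch of how such a sample could be drawn, under the assumption that the state stores n_res Euler-Maruyama sub-steps of length dt / n_res, starting from the last row of x_prev. This is illustrative and not the pfjax source. The function names, the free-function signature, and the toy model are hypothetical.

```python
import jax
import jax.numpy as jnp


def em_sim_diag(key, x, dt, drift, diff, theta):
    """One Euler-Maruyama step with diagonal diffusion (illustrative)."""
    mean = x + drift(x, theta) * dt
    scale = diff(x, theta) * jnp.sqrt(dt)
    return mean + scale * jax.random.normal(key, x.shape)


def state_sample(key, x_prev, theta, dt, n_res, drift, diff):
    """Simulate n_res sequential sub-steps with lax.scan (assumed state layout)."""
    dt_res = dt / n_res

    def step(x, subkey):
        x_new = em_sim_diag(subkey, x, dt_res, drift, diff, theta)
        return x_new, x_new  # carry the new value and also record it

    subkeys = jax.random.split(key, n_res)  # one PRNG key per sub-step
    _, x_curr = jax.lax.scan(step, x_prev[-1], subkeys)
    return x_curr  # shape (n_res, n_dims)


# Hypothetical mean-reverting model: theta = (reversion rate, noise scale).
def drift(x, theta):
    return -theta[0] * x


def diff(x, theta):
    return theta[1] * jnp.ones_like(x)


key = jax.random.PRNGKey(0)
theta = jnp.array([0.5, 0.2])
x_prev = jnp.tile(jnp.array([1.0, -0.5]), (4, 1))  # n_res = 4, n_dims = 2
x_curr = state_sample(key, x_prev, theta, 0.4, 4, drift, diff)
```

Unlike the log-density, simulation is inherently sequential because each sub-step starts from the previous simulated value, which is why lax.scan is used here instead of vmap.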
- pf_step(key, x_prev, y_curr, theta)[source]
Update particle and calculate log-weight for a bootstrap particle filter.
FIXME: This method is completely generic, i.e., it is not specific to SDEs. May wish to put it elsewhere.
- Parameters
key – PRNG key.
x_prev – State variable at previous time t-1.
y_curr – Measurement variable at current time t.
theta – Parameter value.
- Returns
x_curr - Sample of the state variable at current time t: x_curr ~ q(x_curr).
logw - The log-weight of x_curr.
- Return type
Tuple
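In a bootstrap filter the proposal is the state transition itself, so the log-weight reduces to the measurement log-density. A generic sketch follows. It is not the pfjax source. The name bootstrap_step, the meas_lpdf(y_curr, x_curr, theta) argument order, and the toy random-walk model are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def bootstrap_step(key, x_prev, y_curr, theta, state_sample, meas_lpdf):
    """Generic bootstrap step: propose from the state model, weight by the measurement model."""
    x_curr = state_sample(key, x_prev, theta)  # proposal q = p(x_curr | x_prev, theta)
    logw = meas_lpdf(y_curr, x_curr, theta)    # log-weight = log p(y_curr | x_curr, theta)
    return x_curr, logw


# Hypothetical toy model: Gaussian random walk observed with Gaussian noise.
def state_sample(key, x_prev, theta):
    return x_prev + theta[0] * jax.random.normal(key, x_prev.shape)


def meas_lpdf(y_curr, x_curr, theta):
    return jnp.sum(norm.logpdf(y_curr, loc=x_curr, scale=theta[1]))


key = jax.random.PRNGKey(0)
theta = jnp.array([0.1, 0.5])
x_prev = jnp.array([0.0, 1.0])
y_curr = jnp.array([0.2, 0.9])
x_curr, logw = bootstrap_step(key, x_prev, y_curr, theta, state_sample, meas_lpdf)
```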
- bridge_step(key, x_prev, y_curr, theta, Y, A, Omega)[source]
Update particle and calculate log-weight for a particle filter with MVN bridge proposals.
Notes:
The measurement input is Y instead of y_meas. The reason is that the proposal can be used with carefully chosen Y = g(y_meas) when the measurement error model is not y_meas ~ N(A x_state, Omega).
Only implements the general version with arbitrary A and Omega. Specific cases can be done more rapidly, but it’s not clear where to put these different methods. Inside the class? As free functions?
Duplicates a lot of code from mvn_bridge_pars(), because Sigma_Y and mu_Y inside the latter can be computed very easily here.
Computes the Euler part of the log-weights using vmap after the bridge part, which uses lax.scan(). On one core, it’s definitely faster to do both inside lax.scan(). On multiple cores that may or may not be the case, but it would probably take quite a few cores to see the speed increase. However, it’s unlikely that we’ll explicitly parallelize across cores for this, since the parallelization would typically be over particles.
The drift and diffusion functions are each calculated twice, once for the proposal and once for Euler. This is somewhat inefficient, but avoiding it would require redesigning euler_lpdf().
- Parameters
key – PRNG key.
x_prev – State variable at previous time t-1.
y_curr – Measurement variable at current time t.
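theta – Parameter value.
Y – Measurement input used by the bridge proposal. May be a transformation Y = g(y_meas) of the measurement (see Notes).
A – Matrix in the assumed measurement model Y ~ N(A x_state, Omega).
Omega – Variance matrix in the assumed measurement model Y ~ N(A x_state, Omega).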
- Returns
x_curr - Sample of the state variable at current time t: x_curr ~ q(x_curr).
logw - The log-weight of x_curr.
- Return type
Tuple
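The Gaussian conditioning identity that underlies MVN bridge proposals can be sketched as follows. This is the generic textbook formula, not the body of bridge_step() or mvn_bridge_pars(). The function name mvn_condition and the example values are hypothetical. Given W ~ N(mu_W, Sigma_W) and Y | W ~ N(A W, Omega), it returns the mean and variance of W | Y.

```python
import jax.numpy as jnp


def mvn_condition(mu_W, Sigma_W, Y, A, Omega):
    """Mean and variance of W | Y for W ~ N(mu_W, Sigma_W), Y | W ~ N(A W, Omega)."""
    mu_Y = A @ mu_W                       # marginal mean of Y
    Sigma_Y = A @ Sigma_W @ A.T + Omega   # marginal variance of Y
    # gain = Sigma_W A^T Sigma_Y^{-1}, computed with a linear solve.
    gain = jnp.linalg.solve(Sigma_Y, A @ Sigma_W).T
    mean = mu_W + gain @ (Y - mu_Y)
    cov = Sigma_W - gain @ A @ Sigma_W
    return mean, cov


# Hypothetical example: only the first of two coordinates is measured.
mu_W = jnp.array([1.0, 2.0])
Sigma_W = 0.1 * jnp.eye(2)
A = jnp.array([[1.0, 0.0]])
Omega = jnp.array([[0.05]])
Y = jnp.array([1.3])
mean, cov = mvn_condition(mu_W, Sigma_W, Y, A, Omega)
```

In this example the measured coordinate is pulled toward Y and its variance shrinks, while the unmeasured coordinate is unchanged because it is uncorrelated with the first.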