Make Discretizer use GaussianStateEvolution for increased filter compatibility#199
Make Discretizer use GaussianStateEvolution for increased filter compatibility#199mattlevine22 wants to merge 14 commits intomainfrom
Conversation
Introduce solver modules for ODE and SDE paths, add an Euler-Maruyama scan source for SDESimulator, and update filter+simulator tests and continuous-time tutorial outputs to cover the new backend behavior. Made-with: Cursor
…ration protections
DanWaxman
left a comment
There was a problem hiding this comment.
I think this looks good in principle, maybe with some style comments. This should've run the hierarchical discretized smoke tests, but it would also be nice to run the 08_hierarchical_inference.ipynb tutorial notebook and ensure there's no qualitative change there (I don't expect there is).
| Supports batched time (and optional control) matching the previous | ||
| `EulerMaruyamaGaussianStateEvolution` behavior. |
| "loc": ndarray | ||
| Mean of next state(s): shape (dim_state,) or (num_timepoints, dim_state) | ||
| "cov": ndarray | ||
| Covariance of next state(s): shape (dim_state, dim_state) or (num_timepoints, dim_state, dim_state) |
There was a problem hiding this comment.
if this allows batched shapes, this should be reflected in the docstring array shapes
| The parent `GaussianStateEvolution.__call__` would evaluate the | ||
| Euler–Maruyama drift and diffusion twice. Under `jax.vmap` (e.g. | ||
| plate-batched cuthbert EKF), that split can change tracing/shapes. This | ||
| override matches the original one-step implementation. |
| em_result = _euler_maruyama_loc_cov(self.cte, x, u, t_now, t_next) | ||
| return dist.MultivariateNormal( | ||
| loc=em_result["loc"], covariance_matrix=em_result["cov"] | ||
| ) |
There was a problem hiding this comment.
I'm not so clear on the benefit of making this a helper function, to be honest. It seems like it can be reasonably in-lined here.
There was a problem hiding this comment.
We want it to construct F and cov in the __init__ method, as well as use it in the __call__ method
| Holds ``cte`` as an explicit field so `DynamicalModel` pytrees under | ||
| `numpyro.plate` still expose batched continuous-time parameters for | ||
| simulator slicing (closures over ``cte`` alone would hide those arrays). | ||
|
|
There was a problem hiding this comment.
This seems a bit in-the-weeds. Would prefer to just write that we hold cte for pytree compatibility.
|
This PR only adds support for time-varying |
| def _bm_dim_or_default(state_evolution: ContinuousTimeStateEvolution) -> int: | ||
| """Return Brownian dimension, defaulting to 1 when unspecified. | ||
|
|
||
| Args: | ||
| state_evolution: Continuous-time state evolution. | ||
|
|
||
| Returns: | ||
| Brownian motion dimension used by EM sampling. | ||
| """ | ||
| return int(state_evolution.bm_dim) if state_evolution.bm_dim is not None else 1 |
There was a problem hiding this comment.
Why don't we just set the corresponding state_evolution.bm_dim? Seems like needless indirection.
There was a problem hiding this comment.
Also I don't really see how 1 is a reasonable default. If anything, we should raise or set up a probe state pattern to get the right shape.
| def euler_maruyama_step_loc_cov( | ||
| state_evolution: ContinuousTimeStateEvolution, | ||
| x: Array, | ||
| u: Array | None, | ||
| t_now: Array, | ||
| dt: Array, | ||
| ) -> tuple[Array, Array]: | ||
| """Compute one EM moment step over a fixed `dt`. |
There was a problem hiding this comment.
This feels like a particularly simple function.
| Batched mode maps across the time axis, pairing | ||
| `x[:, i], u[:, i], t_now[i], t_next[i]` for each `i`. | ||
| Scalar inputs are promoted to a batch of size 1 internally and squeezed | ||
| back to single-transition outputs. |
There was a problem hiding this comment.
Isn't this counter to other places?
| f"got {type(dt)}." | ||
| ) | ||
|
|
||
|
|
| ``F`` and ``cov`` are optional constructor args so Equinox/dataclass-style | ||
| but we don't use them. |
Addresses issue #115
GaussianStateEvolutiona time-varying covariance field (or jax.array)GaussianStateEvolutionerrors if covariance is a callablediscretizers.pyto build aclass EulerMaruyamaGaussianStateEvolution(GaussianStateEvolution)tests_discretizers.pywhich runs Cuthbert EKF with a discretized SDE model.N.B. CD_Dynamax/Dynamax discrete-time integrations don't support time-varying
F(, .. t)anywayN.B.This PR only adds support for time-varying
GaussianStateEvolutionto enable discretized + discrete-time-filter patterns. It does NOT modify: [LinearGaussianStateEvolution,GaussianObservation,LinearGaussianObservation],