Add fast SDE solver backends and filter-simulator coverage.#178
Add fast SDE solver backends and filter-simulator coverage.#178
Conversation
Introduce solver modules for ODE and SDE paths, add an Euler-Maruyama scan source for SDESimulator, and update filter+simulator tests and continuous-time tutorial outputs to cover the new backend behavior. Made-with: Cursor
|
This PR seems ready, and I think its tests are actually passing because |
|
Let's finish the hierarchical PR first #173 |
|
With some Codex help, merged main into here. Changes seem reasonable + seems to work still. n.b. unlike the NumPyro calls, we can actually |
|
@DanWaxman looks great thanks!
|
Ah, you're right, thanks! I thought we didn't do this for ODE for some reason, I guess we're just slow for discrete-time systems. :) |
There was a problem hiding this comment.
Pull request overview
Adds dedicated ODE/SDE solver backends and expands simulator+filter test coverage to exercise both Diffrax and a faster Euler–Maruyama scan backend for SDE simulation.
Changes:
- Introduces
dynestyx.solvers.solve_odeanddynestyx.solvers.solve_sde(with"diffrax"and"em_scan"sources for SDEs). - Updates
SDESimulatorto dispatch SDE integration via the new solver backend and exposes asourceselector. - Expands test parametrization to run key simulator/filter shape and smoke tests against both SDE backends.
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
dynestyx/simulators.py |
Routes ODE/SDE integration through new solver backends; adds SDESimulator(source=...). |
dynestyx/solvers/sde.py |
New SDE backend module with Diffrax and Euler–Maruyama (lax.scan) implementations. |
dynestyx/solvers/odes.py |
New ODE backend module wrapping Diffrax ODE solves. |
dynestyx/solvers/__init__.py |
Exposes solve_ode / solve_sde public API. |
tests/test_predictive_filter_simulator_shapes.py |
Parametrizes SDE shape tests across "diffrax" and "em_scan". |
tests/test_hierarchical_simulator_discretizer_smokes.py |
Parametrizes plated SDE smoke test across both SDE sources. |
tests/test_filter_simulator.py |
Parametrizes SDE simulator+filter integration tests across both SDE sources. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if key is None: | ||
| raise ValueError("PRNG key is required for em_scan SDE solves.") | ||
| if dt0 <= 0: | ||
| raise ValueError(f"EM scan requires dt0 > 0, got dt0={dt0}.") |
There was a problem hiding this comment.
_solve_sde_scan raises when key is None even if no integration would be performed (e.g., all saveat_times <= t0, where the output is just x0 repeated). For parity with the diffrax backend (which early-returns without requiring a key), consider moving the key is None check to only the code path that actually advances time, or adding an early-return branch before validating the key.
| assert self.tol_vbt < dt0, ( | ||
| "tol_vbt must be smaller than dt0 for statistically correct simulation." | ||
| ) | ||
| if self.source not in {"diffrax", "em_scan"}: | ||
| raise ValueError( |
There was a problem hiding this comment.
SDESimulator enforces tol_vbt < dt0 unconditionally, but tol_vbt is only relevant for the source="diffrax" backend. With source="em_scan", this assertion can fail even though the value is unused. Consider guarding the assertion (and any tol_vbt-specific validation) behind if self.source == "diffrax", and optionally ignoring/forbidding non-None tol_vbt when source="em_scan" to avoid confusing configuration errors.
| if not isinstance(dt0_setting, (int, float)): | ||
| raise ValueError( | ||
| "solve_sde(source='em_scan') requires a numeric dt0 in " | ||
| "diffeqsolve_settings." | ||
| ) | ||
| return _solve_sde_scan( | ||
| dynamics, | ||
| t0, | ||
| saveat_times, | ||
| x0, | ||
| control_path_eval, | ||
| float(dt0_setting), |
There was a problem hiding this comment.
The em_scan branch rejects dt0 unless it is a Python int/float. This will raise for common scalar types like numpy.float64 (and potentially 0-d JAX arrays), even though float(dt0_setting) would work. Consider accepting any real scalar by attempting float(dt0_setting) in a try/except (or checking against numbers.Real / np.floating) so solve_sde(source="em_scan") is robust to scalar-like inputs.
| if not isinstance(dt0_setting, (int, float)): | |
| raise ValueError( | |
| "solve_sde(source='em_scan') requires a numeric dt0 in " | |
| "diffeqsolve_settings." | |
| ) | |
| return _solve_sde_scan( | |
| dynamics, | |
| t0, | |
| saveat_times, | |
| x0, | |
| control_path_eval, | |
| float(dt0_setting), | |
| try: | |
| dt0 = float(dt0_setting) | |
| except (TypeError, ValueError) as exc: | |
| raise ValueError( | |
| "solve_sde(source='em_scan') requires a numeric dt0 in " | |
| "diffeqsolve_settings." | |
| ) from exc | |
| return _solve_sde_scan( | |
| dynamics, | |
| t0, | |
| saveat_times, | |
| x0, | |
| control_path_eval, | |
| dt0, |
There was a problem hiding this comment.
Done---introduced a _coerce_dt utility.
DanWaxman
left a comment
There was a problem hiding this comment.
Minor comments only. Other remaining thing, we should probably have a test that compares statistics of our hand-written Euler-Maruyama, on a non-trivial SDE with controls, to the results of diffrax using the Euler solver and a sufficiently small tol_vbt parameter.
| def _coerce_dt(dt: object, *, name: str = "dt") -> float: | ||
| """Coerce a numeric scalar timestep-like value to Python float.""" | ||
| if isinstance(dt, bool): | ||
| raise ValueError( | ||
| f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), " | ||
| f"got {type(dt)}." | ||
| ) | ||
| if isinstance(dt, Real): | ||
| return float(dt) | ||
| if ( | ||
| isinstance(dt, Array) | ||
| and dt.ndim == 0 | ||
| and not jnp.issubdtype(dt.dtype, jnp.bool_) | ||
| and ( | ||
| jnp.issubdtype(dt.dtype, jnp.integer) | ||
| or jnp.issubdtype(dt.dtype, jnp.floating) | ||
| ) | ||
| ): | ||
| return float(dt) | ||
| raise ValueError( | ||
| f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), " | ||
| f"got {type(dt)}." | ||
| ) |
There was a problem hiding this comment.
I kinda don't understand why we need this. If we need a float, we probably need a float everywhere and should require that from the beginning.
There was a problem hiding this comment.
I think the main use case is when it is given as a jax.Array (perhaps due to some batching needs).
The bool/Real checks are definitely a bit much, want me to remove them?
There was a problem hiding this comment.
As discussed, I guess my main objection is that we are concretizing a jax array in the first place. I think this will break things if traced. If we can at all (i) store as a float from the beginning, and only convert to an array as-needed, or (ii) always store as a jax array and deal with the consequences of that, I think the code will be more robust.
There was a problem hiding this comment.
I looked more closely, and I think the main reason we have floats anywhere is for friendly user API (no one wants to write dt0=jnp.array(0.01) ).
What do you think about doing float -> Array at a high-level early on and then do our internals all on Arrays?
Apply suggestion from @DanWaxman Co-authored-by: Dan Waxman <dan.waxman1@gmail.com>
mattlevine22
left a comment
There was a problem hiding this comment.
Fixed the 1 minor correction and responded to use case of _coerce_dt
| def _coerce_dt(dt: object, *, name: str = "dt") -> float: | ||
| """Coerce a numeric scalar timestep-like value to Python float.""" | ||
| if isinstance(dt, bool): | ||
| raise ValueError( | ||
| f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), " | ||
| f"got {type(dt)}." | ||
| ) | ||
| if isinstance(dt, Real): | ||
| return float(dt) | ||
| if ( | ||
| isinstance(dt, Array) | ||
| and dt.ndim == 0 | ||
| and not jnp.issubdtype(dt.dtype, jnp.bool_) | ||
| and ( | ||
| jnp.issubdtype(dt.dtype, jnp.integer) | ||
| or jnp.issubdtype(dt.dtype, jnp.floating) | ||
| ) | ||
| ): | ||
| return float(dt) | ||
| raise ValueError( | ||
| f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), " | ||
| f"got {type(dt)}." | ||
| ) |
There was a problem hiding this comment.
I think the main use case is when it is given as a jax.Array (perhaps due to some batching needs).
The bool/Real checks are definitely a bit much, want me to remove them?
|
I'd be happy with conversion at init :).
…On Tue, Apr 21, 2026 at 3:46 PM Matt Levine ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In dynestyx/utils.py
<#178?email_source=notifications&email_token=AC2DPON4AGQNCRSQSWCXY2D4W7FYRA5CNFSNUABKM5UWIORPF5TWS5BNNB2WEL2QOVWGYUTFOF2WK43UKJSXM2LFO4XTIMJVGAZTQMBWGQ22M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJL3QOJPXEZLWNFSXOX3DNRUWG2Y#discussion_r3119935099>
:
> +def _coerce_dt(dt: object, *, name: str = "dt") -> float:
+ """Coerce a numeric scalar timestep-like value to Python float."""
+ if isinstance(dt, bool):
+ raise ValueError(
+ f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), "
+ f"got {type(dt)}."
+ )
+ if isinstance(dt, Real):
+ return float(dt)
+ if (
+ isinstance(dt, Array)
+ and dt.ndim == 0
+ and not jnp.issubdtype(dt.dtype, jnp.bool_)
+ and (
+ jnp.issubdtype(dt.dtype, jnp.integer)
+ or jnp.issubdtype(dt.dtype, jnp.floating)
+ )
+ ):
+ return float(dt)
+ raise ValueError(
+ f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), "
+ f"got {type(dt)}."
+ )
I looked more closely, and I think the main reason we have floats
*anywhere* is for friendly user API (no one wants to write
dt0=jnp.array(0.01) ).
What do you think about doing float -> Array at a high-level early on and
then do our internals all on Arrays?
—
Reply to this email directly, view it on GitHub
<#178?email_source=notifications&email_token=AC2DPOKMCTZPSIISZU5BTY34W7FYRA5CNFSNUABKM5UWIORPF5TWS5BNNB2WEL2QOVWGYUTFOF2WK43UKJSXM2LFO4XTIMJVGAZTQMBWGQ22M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJPLQOJPXEZLWNFSXOX3ON52GSZTJMNQXI2LPNZZV6Y3MNFRWW#discussion_r3119935099>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AC2DPOJ2FOSM4VWCJTNND534W7FYRAVCNFSM6AAAAACXBE7HN2VHI2DSMVQWIX3LMV43YUDVNRWFEZLROVSXG5CSMV3GSZLXHM2DCNJQGM4DANRUGU>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Introduce solver modules for ODE and SDE paths, add an Euler-Maruyama scan source for SDESimulator, and update filter+simulator tests and continuous-time tutorial outputs to cover the new backend behavior.
Made-with: Cursor