Skip to content

Add fast SDE solver backends and filter-simulator coverage.#178

Merged
DanWaxman merged 10 commits intomainfrom
ml-fast-sde-solver2
Apr 22, 2026
Merged

Add fast SDE solver backends and filter-simulator coverage.#178
DanWaxman merged 10 commits intomainfrom
ml-fast-sde-solver2

Conversation

@mattlevine22
Copy link
Copy Markdown
Collaborator

Introduce solver modules for ODE and SDE paths, add an Euler-Maruyama scan source for SDESimulator, and update filter+simulator tests and continuous-time tutorial outputs to cover the new backend behavior.

Made-with: Cursor

Introduce solver modules for ODE and SDE paths, add an Euler-Maruyama scan source for SDESimulator, and update filter+simulator tests and continuous-time tutorial outputs to cover the new backend behavior.

Made-with: Cursor
@mattlevine22
Copy link
Copy Markdown
Collaborator Author

This PR seems ready, and I think its tests are actually passing because main is failing right now.

@mattlevine22 mattlevine22 marked this pull request as ready for review March 27, 2026 15:44
@mattlevine22
Copy link
Copy Markdown
Collaborator Author

Let's finish the hierarchical PR first #173

@DanWaxman
Copy link
Copy Markdown
Collaborator

With some Codex help, merged main into here. Changes seem reasonable + seems to work still.

n.b. unlike the NumPyro calls, we can actually vmap over these scan calls instead of doing everything in a big for loop. Worth considering.

@mattlevine22
Copy link
Copy Markdown
Collaborator Author

@DanWaxman looks great thanks!

  • I added a speed tip to the SDESimulator docs
  • Regarding vmap vs for loop over the scans...I believe we already vmap; SDESimulator line 838 has states, emissions = jax.vmap(_sim_one_trajectory)(...)

@DanWaxman
Copy link
Copy Markdown
Collaborator

  • Regarding vmap vs for loop over the scans...I believe we already vmap; SDESimulator line 838 has states, emissions = jax.vmap(_sim_one_trajectory)(...)

Ah, you're right, thanks! I thought we didn't do this for ODE for some reason, I guess we're just slow for discrete-time systems. :)

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds dedicated ODE/SDE solver backends and expands simulator+filter test coverage to exercise both Diffrax and a faster Euler–Maruyama scan backend for SDE simulation.

Changes:

  • Introduces dynestyx.solvers.solve_ode and dynestyx.solvers.solve_sde (with "diffrax" and "em_scan" sources for SDEs).
  • Updates SDESimulator to dispatch SDE integration via the new solver backend and exposes a source selector.
  • Expands test parametrization to run key simulator/filter shape and smoke tests against both SDE backends.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
dynestyx/simulators.py Routes ODE/SDE integration through new solver backends; adds SDESimulator(source=...).
dynestyx/solvers/sde.py New SDE backend module with Diffrax and Euler–Maruyama (lax.scan) implementations.
dynestyx/solvers/odes.py New ODE backend module wrapping Diffrax ODE solves.
dynestyx/solvers/__init__.py Exposes solve_ode / solve_sde public API.
tests/test_predictive_filter_simulator_shapes.py Parametrizes SDE shape tests across "diffrax" and "em_scan".
tests/test_hierarchical_simulator_discretizer_smokes.py Parametrizes plated SDE smoke test across both SDE sources.
tests/test_filter_simulator.py Parametrizes SDE simulator+filter integration tests across both SDE sources.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread dynestyx/solvers/sde.py Outdated
Comment on lines +38 to +41
if key is None:
raise ValueError("PRNG key is required for em_scan SDE solves.")
if dt0 <= 0:
raise ValueError(f"EM scan requires dt0 > 0, got dt0={dt0}.")
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_solve_sde_scan raises when key is None even if no integration would be performed (e.g., all saveat_times <= t0, where the output is just x0 repeated). For parity with the diffrax backend (which early-returns without requiring a key), consider moving the key is None check to only the code path that actually advances time, or adding an early-return branch before validating the key.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Resulted in a larger refactor to not duplicate early_return. This will need further refactor to support #199, but I suggest we complete this PR as is and shift the subsequent load onto #199.

Comment thread dynestyx/simulators.py Outdated
Comment on lines +720 to +724
assert self.tol_vbt < dt0, (
"tol_vbt must be smaller than dt0 for statistically correct simulation."
)
if self.source not in {"diffrax", "em_scan"}:
raise ValueError(
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SDESimulator enforces tol_vbt < dt0 unconditionally, but tol_vbt is only relevant for the source="diffrax" backend. With source="em_scan", this assertion can fail even though the value is unused. Consider guarding the assertion (and any tol_vbt-specific validation) behind if self.source == "diffrax", and optionally ignoring/forbidding non-None tol_vbt when source="em_scan" to avoid confusing configuration errors.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment thread dynestyx/solvers/sde.py Outdated
Comment on lines +177 to +188
if not isinstance(dt0_setting, (int, float)):
raise ValueError(
"solve_sde(source='em_scan') requires a numeric dt0 in "
"diffeqsolve_settings."
)
return _solve_sde_scan(
dynamics,
t0,
saveat_times,
x0,
control_path_eval,
float(dt0_setting),
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The em_scan branch rejects dt0 unless it is a Python int/float. This will raise for common scalar types like numpy.float64 (and potentially 0-d JAX arrays), even though float(dt0_setting) would work. Consider accepting any real scalar by attempting float(dt0_setting) in a try/except (or checking against numbers.Real / np.floating) so solve_sde(source="em_scan") is robust to scalar-like inputs.

Suggested change
if not isinstance(dt0_setting, (int, float)):
raise ValueError(
"solve_sde(source='em_scan') requires a numeric dt0 in "
"diffeqsolve_settings."
)
return _solve_sde_scan(
dynamics,
t0,
saveat_times,
x0,
control_path_eval,
float(dt0_setting),
try:
dt0 = float(dt0_setting)
except (TypeError, ValueError) as exc:
raise ValueError(
"solve_sde(source='em_scan') requires a numeric dt0 in "
"diffeqsolve_settings."
) from exc
return _solve_sde_scan(
dynamics,
t0,
saveat_times,
x0,
control_path_eval,
dt0,

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done---introduced a _coerce_dt utility.

Comment thread dynestyx/solvers/sde.py Outdated
Comment thread dynestyx/solvers/sde.py Outdated
Comment thread dynestyx/solvers/sde.py Outdated
Comment thread dynestyx/solvers/sde.py
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments only. Other remaining thing, we should probably have a test that compares statistics of our hand-written Euler-Maruyama, on a non-trivial SDE with controls, to the results of diffrax using the Euler solver and a sufficiently small tol_vbt parameter.

Comment thread dynestyx/solvers/sde.py Outdated
Comment thread dynestyx/utils.py Outdated
Comment on lines +164 to +186
def _coerce_dt(dt: object, *, name: str = "dt") -> float:
"""Coerce a numeric scalar timestep-like value to Python float."""
if isinstance(dt, bool):
raise ValueError(
f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), "
f"got {type(dt)}."
)
if isinstance(dt, Real):
return float(dt)
if (
isinstance(dt, Array)
and dt.ndim == 0
and not jnp.issubdtype(dt.dtype, jnp.bool_)
and (
jnp.issubdtype(dt.dtype, jnp.integer)
or jnp.issubdtype(dt.dtype, jnp.floating)
)
):
return float(dt)
raise ValueError(
f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), "
f"got {type(dt)}."
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda don't understand why we need this. If we need a float, we probably need a float everywhere and should require that from the beginning.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the main use case is when it is given as a jax.Array (perhaps due to some batching needs).

The bool/Real checks are definitely a bit much, want me to remove them?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, I guess my main objection is that we are concretizing a jax array in the first place. I think this will break things if traced. If we can at all (i) store as a float from the beginning, and only convert to an array as-needed, or (ii) always store as a jax array and deal with the consequences of that, I think the code will be more robust.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked more closely, and I think the main reason we have floats anywhere is for friendly user API (no one wants to write dt0=jnp.array(0.01) ).

What do you think about doing float -> Array at a high-level early on and then do our internals all on Arrays?

mattlevine22 and others added 2 commits April 21, 2026 14:44
Apply suggestion from @DanWaxman

Co-authored-by: Dan Waxman <dan.waxman1@gmail.com>
Copy link
Copy Markdown
Collaborator Author

@mattlevine22 mattlevine22 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed the 1 minor correction and responded to use case of _coerce_dt

Comment thread dynestyx/utils.py Outdated
Comment on lines +164 to +186
def _coerce_dt(dt: object, *, name: str = "dt") -> float:
"""Coerce a numeric scalar timestep-like value to Python float."""
if isinstance(dt, bool):
raise ValueError(
f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), "
f"got {type(dt)}."
)
if isinstance(dt, Real):
return float(dt)
if (
isinstance(dt, Array)
and dt.ndim == 0
and not jnp.issubdtype(dt.dtype, jnp.bool_)
and (
jnp.issubdtype(dt.dtype, jnp.integer)
or jnp.issubdtype(dt.dtype, jnp.floating)
)
):
return float(dt)
raise ValueError(
f"{name} must be a numeric scalar (Python/NumPy real or scalar JAX array), "
f"got {type(dt)}."
)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the main use case is when it is given as a jax.Array (perhaps due to some batching needs).

The bool/Real checks are definitely a bit much, want me to remove them?

@DanWaxman
Copy link
Copy Markdown
Collaborator

DanWaxman commented Apr 21, 2026 via email

@DanWaxman DanWaxman self-requested a review April 22, 2026 14:14
@DanWaxman DanWaxman merged commit 1ef0782 into main Apr 22, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants