Skip to content
Merged
18 changes: 9 additions & 9 deletions docs/tutorials/gentle_intro/06_continuous_time.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions dynestyx/models/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _infer_observation_dim_in_plate_context(
observation_model: Callable[[State, Control | None, Time], Any],
inferred_state_dim: int,
control_dim: int,
t0: float | None,
t0: Time | None,
observation_dim: int | None,
) -> int:
"""Infer observation dimension in plate context, falling back to explicit value."""
Expand All @@ -250,7 +250,7 @@ def _infer_observation_dim_in_plate_context(
state_dim=inferred_state_dim,
)
u0 = None if control_dim == 0 else jnp.zeros((control_dim,))
dummy_t0 = jnp.array(0.0) if t0 is None else jnp.array(t0)
dummy_t0 = jnp.array(0.0) if t0 is None else t0
try:
obs_dist = observation_model(x0, u0, dummy_t0)
return int(
Expand Down
16 changes: 8 additions & 8 deletions dynestyx/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
_validate_state_dim,
_validate_state_evolution_output_shape,
)
from dynestyx.types import Control, State, Time, dState
from dynestyx.types import Control, State, Time, TimeLike, as_scalar_time_array, dState


class DynamicalModel(eqx.Module):
Expand Down Expand Up @@ -62,7 +62,7 @@ class DynamicalModel(eqx.Module):
A callable is accepted (e.g., `lambda x, u, t: ...`) as long as it returns a NumPyro-compatible
distribution, while subclassing `ObservationModel` is recommended for richer reuse and consistency.
control_model (Any): Optional model for control inputs (e.g., exogenous process). Not currently supported.
t0 (float | None): Optional declared start time of the model. If ``None`` (default), the start time
t0 (float | Array | None): Optional declared start time of the model. If ``None`` (default), the start time
is auto-inferred as ``obs_times[0]`` when the simulator runs and recorded as a
``numpyro.deterministic("t0", ...)`` site. If provided, it must match ``obs_times[0]``
exactly; a mismatch raises a ``ValueError`` at simulation time.
Expand All @@ -84,7 +84,7 @@ class DynamicalModel(eqx.Module):
observation_model: Callable[[State, Control, Time], Distribution]
control_dim: int
control_model: Any
t0: float | None
t0: Time | None
state_dim: int
observation_dim: int
categorical_state: bool
Expand All @@ -98,7 +98,7 @@ def __init__(
control_dim: int | None = None,
control_model=None,
*,
t0: float | None = None,
t0: TimeLike | None = None,
state_dim: int | None = None,
observation_dim: int | None = None,
categorical_state: bool | None = None,
Expand All @@ -113,7 +113,7 @@ def __init__(
self.state_evolution = state_evolution
self.observation_model = observation_model
self.control_model = control_model
self.t0 = t0
self.t0 = None if t0 is None else as_scalar_time_array(t0, name="t0")

inferred_state_dim = _infer_vector_dim_from_distribution(
initial_condition, "initial_condition"
Expand All @@ -136,7 +136,7 @@ def __init__(
observation_model=observation_model,
inferred_state_dim=inferred_state_dim,
control_dim=control_dim,
t0=t0,
t0=self.t0,
observation_dim=observation_dim,
)
self.state_dim = int(inferred_state_dim)
Expand All @@ -150,7 +150,7 @@ def __init__(
):
x0 = jnp.zeros((inferred_state_dim,))
u0 = None if control_dim == 0 else jnp.zeros((control_dim,))
dummy_t0 = jnp.array(0.0) if t0 is None else jnp.array(t0)
dummy_t0 = jnp.array(0.0) if self.t0 is None else self.t0
inferred_bm_dim = _infer_bm_dim(
state_evolution, inferred_state_dim, x0, u0, dummy_t0
)
Expand All @@ -170,7 +170,7 @@ def __init__(
initial_condition=initial_condition, state_dim=inferred_state_dim
)
u0 = None if control_dim == 0 else jnp.zeros((control_dim,))
dummy_t0 = jnp.array(0.0) if t0 is None else jnp.array(t0)
dummy_t0 = jnp.array(0.0) if self.t0 is None else self.t0

inferred_bm_dim = _validate_state_evolution_output_shape(
state_evolution=state_evolution,
Expand Down
137 changes: 50 additions & 87 deletions dynestyx/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
from collections.abc import Callable
from contextlib import contextmanager
from typing import cast
from typing import Literal, cast

import diffrax as dfx
import equinox as eqx
Expand All @@ -15,7 +15,7 @@
import numpyro
from effectful.ops.semantics import fwd
from effectful.ops.syntax import ObjectInterpretation, implements
from jax import Array, lax
from jax import Array
from numpyro.contrib.control_flow import scan as nscan

from dynestyx.handlers import HandlesSelf, _sample_intp
Expand All @@ -26,7 +26,8 @@
DiscreteTimeStateEvolution,
DynamicalModel,
)
from dynestyx.types import FunctionOfTime, State
from dynestyx.solvers import solve_ode, solve_sde
from dynestyx.types import FunctionOfTime, State, Time, TimeLike, as_scalar_time_array
from dynestyx.utils import (
_array_has_plate_dims,
_build_control_path,
Expand Down Expand Up @@ -587,69 +588,6 @@ def _simulate(
raise NotImplementedError()


def _solve_de(
    dynamics,
    t0: float,
    saveat_times: Array,
    x0: State,
    control_path_eval: Callable[[Array], Array | None],
    diffeqsolve_settings: dict,
    *,
    key=None,
    tol_vbt: float | None = None,
) -> Array:
    """Integrate a single ODE or SDE trajectory with diffrax.

    The problem is treated as an ODE when
    ``dynamics.state_evolution.diffusion_coefficient`` is ``None`` and as an
    SDE (drift term plus a Brownian-driven control term) otherwise. The start
    time is passed explicitly so that a rollout segment may begin at a
    filtered time.

    Args:
        dynamics: Object exposing ``state_evolution`` with ``total_drift``,
            ``diffusion_coefficient`` and (SDE case) ``bm_dim``.
        t0: Start time of the integration.
        saveat_times: Times at which the state is saved; the last entry is
            used as the end time of the solve.
        x0: Initial state.
        control_path_eval: Passed to diffrax as ``args``; when not ``None`` it
            is called with the current time to obtain the control input.
        diffeqsolve_settings: Extra keyword arguments forwarded verbatim to
            ``dfx.diffeqsolve``.
        key: PRNG key, split to seed the Brownian tree (SDE case only).
        tol_vbt: Tolerance handed to ``dfx.VirtualBrownianTree`` (SDE case
            only).

    Returns:
        The saved states (``sol.ys``), or ``x0`` broadcast over the save
        times when ``t0 >= saveat_times[-1]``.
    """
    final_time = saveat_times[-1]

    def _hold_initial_state():
        # Degenerate interval: repeat x0 once per requested save time.
        n_saves = len(saveat_times)
        return jnp.broadcast_to(x0, (n_saves, *jnp.shape(x0)))

    def _integrate():
        def _control_at(t, path):
            # `path` is the diffrax `args` object, i.e. control_path_eval.
            return None if path is None else path(t)

        def _drift_field(t, y, args):
            return dynamics.state_evolution.total_drift(
                x=y, u=_control_at(t, args), t=t
            )

        def _diffusion_field(t, y, args):
            return dynamics.state_evolution.diffusion_coefficient(
                x=y, u=_control_at(t, args), t=t
            )

        drift_term = dfx.ODETerm(_drift_field)
        if dynamics.state_evolution.diffusion_coefficient is None:
            terms = drift_term
        else:
            brownian_key, _ = jr.split(key, 2)
            brownian_path = dfx.VirtualBrownianTree(
                t0=t0,
                t1=final_time,
                tol=tol_vbt,
                shape=(dynamics.state_evolution.bm_dim,),
                key=brownian_key,
            )
            terms = dfx.MultiTerm(  # type: ignore
                drift_term, dfx.ControlTerm(_diffusion_field, brownian_path)
            )

        solution = dfx.diffeqsolve(
            terms,
            t0=t0,
            t1=final_time,
            y0=x0,
            saveat=dfx.SaveAt(ts=saveat_times),
            args=control_path_eval,
            **diffeqsolve_settings,
        )
        return solution.ys

    # lax.cond (rather than a Python `if`) keeps the branch JAX-traceable
    # when t0 / the final time are traced values.
    interval_is_empty = t0 >= final_time
    return lax.cond(interval_is_empty, _hold_initial_state, _integrate)


def _emit_observations(
name: str,
dynamics,
Expand Down Expand Up @@ -716,17 +654,24 @@ class SDESimulator(BaseSimulator):
very high-dimensional latent path and is usually a **poor inference
strategy** for parameters. Prefer filtering (`Filter` with
`ContinuousTime*Config`) or particle methods instead.

Tip for speed:
- Use `source="em_scan"` if you are happy with a simple Euler-Maruyama forward simulation
(10–20x faster than Diffrax's implementation; see
[Diffrax Issue #517](https://github.com/patrick-kidger/diffrax/issues/517)).
- Use `source="diffrax"` if you want greater flexibility in the solver and step-size control.
"""

def __init__(
self,
solver: dfx.AbstractSolver = dfx.Heun(),
stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(),
dt0: float = 1e-4,
dt0: TimeLike = 1e-4,
tol_vbt: float | None = None,
max_steps: int | None = None,
n_simulations: int = 1,
source: Literal["diffrax", "em_scan"] = "diffrax",
):
"""Configure SDE integration settings.

Expand All @@ -739,7 +684,7 @@ def __init__(
adjoint: Diffrax adjoint strategy used for differentiation through the
solver (relevant when used under gradient-based inference). See
[Adjoints](https://docs.kidger.site/diffrax/api/adjoints/).
dt0: Initial step size passed to
dt0: Initial step size (float or JAX array) passed to
[`diffrax.diffeqsolve`](https://docs.kidger.site/diffrax/api/diffeqsolve/).
tol_vbt: Tolerance parameter for
[`diffrax.VirtualBrownianTree`](https://docs.kidger.site/diffrax/api/brownian/). If None,
Expand All @@ -748,28 +693,44 @@ def __init__(
max_steps: Optional hard cap on solver steps.
n_simulations: Number of independent trajectory simulations. When > 1,
states and observations have an extra leading dimension (n_simulations, T, ...).
source: SDE backend to use. `"diffrax"` uses Diffrax + Brownian tree.
`"em_scan"` uses a custom fixed-step Euler-Maruyama `lax.scan`
that advances at every `dt0` tick and also lands exactly on all
requested solve times.

Notes:
- `VirtualBrownianTree` draws randomness via `numpyro.prng_key()`, so
`SDESimulator` must be executed inside a seeded NumPyro context.
"""
dt0_arr = as_scalar_time_array(dt0, name="dt0")
self.diffeqsolve_settings = {
"solver": solver,
"stepsize_controller": stepsize_controller,
"adjoint": adjoint,
"dt0": dt0,
"dt0": dt0_arr,
"max_steps": max_steps,
}
self.n_simulations = n_simulations
self.source = source
if self.source not in {"diffrax", "em_scan"}:
raise ValueError(
"SDESimulator source must be one of {'diffrax', 'em_scan'}, "
f"got source={self.source!r}."
)

if tol_vbt is None:
self.tol_vbt = dt0 / 2.0
else:
self.tol_vbt = tol_vbt
self.tol_vbt: Time | None
if self.source == "diffrax":
if tol_vbt is None:
self.tol_vbt = dt0_arr / 2.0
else:
self.tol_vbt = as_scalar_time_array(tol_vbt, name="tol_vbt")

assert self.tol_vbt < dt0, (
"tol_vbt must be smaller than dt0 for statistically correct simulation."
)
assert self.tol_vbt < dt0_arr, (
"tol_vbt must be smaller than dt0 for statistically correct simulation."
)
else:
# tol_vbt is only used by the diffrax backend.
self.tol_vbt = None

def _simulate(
self,
Expand Down Expand Up @@ -858,13 +819,14 @@ def _simulate(
def _sim_one_trajectory(key: Array, x0: Array) -> tuple[Array, Array]:
"""Simulate one SDE trajectory and its emissions."""
k_solve, k_obs = jr.split(key, 2)
states_sol = _solve_de(
dynamics,
t0,
times,
x0,
control_path_eval,
self.diffeqsolve_settings,
states_sol = solve_sde(
source=self.source,
dynamics=dynamics,
t0=t0,
saveat_times=times,
x0=x0,
control_path_eval=control_path_eval,
diffeqsolve_settings=self.diffeqsolve_settings,
key=k_solve,
tol_vbt=self.tol_vbt,
)
Expand Down Expand Up @@ -1173,7 +1135,7 @@ def __init__(
solver: dfx.AbstractSolver = dfx.Tsit5(),
adjoint: dfx.AbstractAdjoint = dfx.RecursiveCheckpointAdjoint(),
stepsize_controller: dfx.AbstractStepSizeController = dfx.ConstantStepSize(),
dt0: float = 1e-3,
dt0: TimeLike = 1e-3,
max_steps: int = 100_000,
n_simulations: int = 1,
):
Expand All @@ -1187,17 +1149,18 @@ def __init__(
See [Adjoints](https://docs.kidger.site/diffrax/api/adjoints/).
stepsize_controller: Diffrax step-size controller (default:
[`dfx.ConstantStepSize`](https://docs.kidger.site/diffrax/api/stepsize_controller/)).
dt0: Initial step size passed to
dt0: Initial step size (float or JAX array) passed to
[`diffrax.diffeqsolve`](https://docs.kidger.site/diffrax/api/diffeqsolve/).
max_steps: Hard cap on solver steps.
n_simulations: Number of independent trajectory simulations. When > 1,
states and observations have shape (n_simulations, T, ...).
"""
dt0_arr = as_scalar_time_array(dt0, name="dt0")
self.diffeqsolve_settings = {
"solver": solver,
"stepsize_controller": stepsize_controller,
"adjoint": adjoint,
"dt0": dt0,
"dt0": dt0_arr,
"max_steps": max_steps,
}
self.n_simulations = n_simulations
Expand Down Expand Up @@ -1251,7 +1214,7 @@ def _simulate(

def _sim_one_trajectory(x0: Array, *, obs_key=None):
"""Simulate one ODE trajectory and emit observations."""
states = _solve_de(
states = solve_ode(
dynamics,
t0,
times,
Expand Down
6 changes: 6 additions & 0 deletions dynestyx/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Numerical solver backends for dynestyx simulators."""

from .odes import solve_ode
from .sde import solve_sde

__all__ = ["solve_ode", "solve_sde"]
46 changes: 46 additions & 0 deletions dynestyx/solvers/odes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""ODE solver backend for simulators."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import diffrax as dfx
import jax.numpy as jnp
from jax import Array, lax

from dynestyx.types import State, TimeLike, as_scalar_time_array


def solve_ode(
    dynamics: Any,
    t0: TimeLike,
    saveat_times: Array,
    x0: State,
    control_path_eval: Callable[[Array], Array | None],
    diffeqsolve_settings: dict[str, Any],
) -> Array:
    """Integrate one ODE trajectory with Diffrax, saving at the given times.

    Args:
        dynamics: Object exposing ``state_evolution.total_drift``.
        t0: Start time; passed through ``as_scalar_time_array`` with the
            dtype of ``saveat_times`` before use.
        saveat_times: Times at which the state is saved; the last entry is
            used as the end time of the solve.
        x0: Initial state.
        control_path_eval: Passed to Diffrax as ``args``; when not ``None``
            it is called with the current time to obtain the control input.
        diffeqsolve_settings: Extra keyword arguments forwarded verbatim to
            ``dfx.diffeqsolve``.

    Returns:
        The saved states (``sol.ys``), or ``x0`` broadcast over the save
        times when the start time is not before the final save time.
    """
    start_time = as_scalar_time_array(t0, name="t0", dtype=saveat_times.dtype)
    final_time = saveat_times[-1]

    def _vector_field(t, y, args):
        # `args` is the control path evaluator handed to diffeqsolve below.
        control = None if args is None else args(t)
        return dynamics.state_evolution.total_drift(x=y, u=control, t=t)

    def _hold_initial_state() -> Array:
        # Degenerate interval: repeat x0 once per requested save time.
        n_saves = len(saveat_times)
        return jnp.broadcast_to(x0, (n_saves, *jnp.shape(x0)))

    def _integrate() -> Array:
        solution = dfx.diffeqsolve(
            dfx.ODETerm(_vector_field),
            t0=start_time,
            t1=final_time,
            y0=x0,
            saveat=dfx.SaveAt(ts=saveat_times),
            args=control_path_eval,
            **diffeqsolve_settings,
        )
        return solution.ys

    # lax.cond selects the branch from array-valued times instead of a
    # Python-level comparison.
    interval_is_empty = start_time >= final_time
    return lax.cond(interval_is_empty, _hold_initial_state, _integrate)
Loading
Loading