
Pin user-supplied floats to canonical dtype at every API boundary #345

Merged
hmgaudecker merged 136 commits into main from feat/canonical-float-dtype
May 11, 2026

@hmgaudecker
Member

@hmgaudecker hmgaudecker commented May 6, 2026

Context

Continues the int-side normalisation in #340 with the float side, though for a completely different reason.

The constraint:

def borrowing_constraint(
    consumption: ContinuousAction,    # action grid, fp32 (quantized via jnp.float32)
    cash_on_hand: FloatND,
    consumption_floor: float,         # Python float — fp64
    equivalence_scale: FloatND,
) -> BoolND:
    return consumption <= cash_on_hand + consumption_floor * equivalence_scale

fired when it should not have -- transfers always ensure that the consumption floor is supported.

What lands on each side without dtype barriers (under jax_enable_x64=True,
which aca_model/__init__.py sets at import):

  • LHS consumption: action grid quantized to jnp.float32 in the
    runtime-consumption-points path. Promoted to fp64 for the comparison —
    but promotion preserves the quantization error, it doesn't undo it.
  • RHS consumption_floor * equivalence_scale: consumption_floor is a
    Python float (fp64 precision), so the RHS keeps fp64 throughout.

When cash_on_hand took large negative values, the two sides differed by less
than one fp32 quantization step at that magnitude: the rounding error of the
fp32 grid point, carried into fp64 by the promotion. <= flips, and
validate_initial_conditions raises InvalidInitialConditionsError.
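
A minimal sketch of the mechanism, using only the standard library. The value `0.1` and the helper `round_to_fp32` are illustrative assumptions, not taken from the model; the point is only that a value rounded to fp32 and then promoted keeps its rounding error, so `<=` against the unrounded fp64 value can flip.

```python
import struct


def round_to_fp32(x: float) -> float:
    """Round an fp64 value to the nearest fp32 value and return it as fp64.

    This mimics storing a number in a float32 array and promoting it back
    to float64 for a comparison.
    """
    return struct.unpack("f", struct.pack("f", x))[0]


# Hypothetical bound that the lowest grid point is supposed to sit on.
bound_fp64 = 0.1

# The grid point was quantized to fp32, then promoted for the comparison.
grid_point = round_to_fp32(bound_fp64)

# Mixed precision: the promoted grid point keeps its fp32 rounding error,
# which for 0.1 happens to round upwards.
feasible_mixed = grid_point <= bound_fp64

# Consistent precision: both sides went through the same rounding.
feasible_consistent = grid_point <= round_to_fp32(bound_fp64)
```

With both operands pinned to one dtype, the comparison at the boundary is exact again, which is the invariant this PR enforces.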

This was very annoying to debug. To have one less thing to worry about, this PR makes sure all floats have a consistent dtype.

Overview

Adds canonical_float_dtype() and safe_to_float_dtype next to the int
helpers from #340, and applies them at the same boundaries (params,
initial conditions, transition outputs, V-arrays).

What lands

src/lcm/dtypes.py

  • canonical_float_dtype() returns jnp.float64 under
    jax_enable_x64=True, else jnp.float32. Read at call time.
  • safe_to_float_dtype(value, *, name) casts to the canonical dtype
    and raises OverflowError (with the leaf's qualified name) when
    down-casting a value above float32 magnitude. Up-casts and
    same-width casts skip the range check; precision loss within range
    is not an error.
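
The contract described above can be sketched as follows. This is a NumPy stand-in under stated assumptions, not the code in src/lcm/dtypes.py: the `_sketch` names and the explicit `x64_enabled` argument are hypothetical (the real helper reads the JAX flag at call time), and the treatment of non-finite values is a guess.

```python
import numpy as np


def canonical_float_dtype_sketch(*, x64_enabled: bool) -> np.dtype:
    """Stand-in for the real helper, which reads the JAX x64 flag itself."""
    return np.dtype(np.float64 if x64_enabled else np.float32)


def safe_to_float_dtype_sketch(value, *, name: str, x64_enabled: bool) -> np.ndarray:
    """Cast to the canonical dtype; refuse down-casts that would overflow."""
    target = canonical_float_dtype_sketch(x64_enabled=x64_enabled)
    arr = np.asarray(value)
    # Only a float-to-narrower-float cast can overflow; up-casts and
    # same-width casts skip the range check.
    is_down_cast = arr.dtype.kind == "f" and arr.dtype.itemsize > target.itemsize
    if is_down_cast:
        flat = np.atleast_1d(arr)
        finite = flat[np.isfinite(flat)]  # assumption: inf/NaN pass through
        if finite.size and np.abs(finite).max() > np.finfo(target).max:
            msg = f"{name}: value exceeds the {target.name} range"
            raise OverflowError(msg)
    # Precision loss within range is accepted silently.
    return arr.astype(target)
```

For example, `safe_to_float_dtype_sketch(1e39, name="regime.consumption_floor", x64_enabled=False)` raises `OverflowError`, while `0.1` is cast to float32 without complaint.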

Params boundary (src/lcm/params/processing.py)

Simulate boundary (src/lcm/simulation/initial_conditions.py)

Transition boundary (src/lcm/simulation/transitions.py)

  • _update_states_for_subjects unconditionally casts
    next_state_values to the storage dtype. The cross-kind escape
    hatch added in #340 (so an int-typed user initial condition for a
    continuous state would not be coerced) is no longer needed — the
    initial-state cast above pins storage to the canonical float dtype
    upstream of this site.

Tests

  • tests/test_float_dtype_invariants.py (10 tests): helper round-trips,
    initial-state casts, params casts, grid materialisation, V-array
    dtype, multi-period state-dtype stability.
  • tests/test_dtypes.py: 7 additional float-helper unit tests.
  • tests/test_validate_param_types.py:
    numpy_array_param_rejected -> numpy_array_param_accepted_and_normalised. With the boundary cast in place numpy arrays are
    auto-converted; the historical rejection-by-isinstance is obsolete.

928 pass, 5 skip; prek + ty clean.

Stacked on

This branch is stacked on feat/simulate-aot-n-subjects (#340) — the
base ref for this PR. Merge order: #340 first, then this. The diff
view here only shows the float-side changes.

Out of scope

hmgaudecker and others added 30 commits April 29, 2026 06:29
Extends the existing runtime-points mechanism (previously state-only)
to continuous action grids. With this change, an action declared as
`IrregSpacedGrid(n_points=N)` adds an `{action_name: {"points":
"Float1D"}}` entry to the regime params template, and `state_action_space()`
substitutes the runtime-supplied points into `continuous_actions` at
solve / simulate time.

Motivation: aca-dev's structural retirement model has a `consumption`
action grid whose lower bound is the per-iteration `consumption_floor`
parameter. Without this change the c-grid bounds would have to be
fixed at build time, which forces either an over-wide grid (wasted
density) or model rebuilds per estimation iteration (recompilation).

Mirrors the existing state-grid treatment:
- `regime_template.py`: walks `regime.actions` alongside `regime.states`,
  factoring the shared shadowing check into helpers.
- `interfaces.InternalRegime.state_action_space()`: builds both
  state and continuous-action replacements in a single sweep over
  `self.grids`, then calls `_base_state_action_space.replace(...)`
  with whichever side actually had substitutions.
- `pandas_utils._is_runtime_grid_param`: also recognises action grids
  so column extraction in `to_dataframe()` keeps working.

Tests (TDD): four new tests in `tests/test_runtime_params.py`,
mirroring the state-grid counterparts — params-template entry,
solve, runtime-vs-fixed equivalence, and a sanity check that
varying runtime points actually changes V.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model now declares `consumption` as `IrregSpacedGrid(n_points=N)`
with runtime-supplied points. The bench builder now passes
`model=model` to `get_benchmark_params` so consumption gridpoints
are injected into params before solving.

aca-model rev: adc8a19 → 4123fe9 (feature/runtime-consumption-points)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`IrregSpacedGrid(n_points=N)` declares a continuous grid whose values
are supplied at runtime via `params[regime][grid_name]['points']`.
Substitution happens inside
`InternalRegime.state_action_space(regime_params=...)` at solve /
simulate time. Any code path that calls `to_jax()` on the base grid
before substitution silently got `jnp.full(N, jnp.nan)` and went on to
compute against the placeholder.

That is exactly what fired in `validate_initial_conditions` for
`task_simulate_aca`: the validator built the action grid by calling
`internal_regime.grids[name].to_jax()` (placeholder NaNs), then asked
`borrowing_constraint(consumption=NaN, wealth=W)` whether each
gridpoint was feasible. NaN comparisons are False, so every action
was reported infeasible for every subject in every initial regime.
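
A tiny illustration of why the placeholder grid made every action look infeasible. The constraint body and the numbers are hypothetical; only the NaN comparison semantics matter.

```python
import math

# Placeholder grid, as returned for a runtime-supplied grid before substitution.
placeholder_grid = [math.nan] * 4
wealth = 10.0  # hypothetical value


def borrowing_constraint_sketch(consumption: float, wealth: float) -> bool:
    """Hypothetical stand-in for a feasibility constraint."""
    return consumption <= wealth


# Every ordered comparison involving NaN evaluates to False.
feasible = [borrowing_constraint_sketch(c, wealth) for c in placeholder_grid]
any_feasible = any(feasible)
```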

Make the invariant explicit: `IrregSpacedGrid.to_jax()` raises
`GridInitializationError` for runtime-supplied grids, with a message
pointing the caller at `state_action_space(regime_params=...)` for
real values or `.n_points` for shape. Confine the legitimate
"placeholder needed for AOT tracing" caller (the base state-action
space) to a private helper in `state_action_space.py` that uses NaN
explicitly. Reroute `_check_regime_feasibility` through the
substituted state-action space.

Add regression tests covering both runtime action and runtime state
grids round-tripping `simulate(check_initial_conditions=True)`, and
unit tests pinning down the new raise + the existing NaN-source
mechanics in `map_coordinates` / `get_irreg_coordinate`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move late `DiscreteGrid`, `map_coordinates`, and `get_irreg_coordinate`
imports to the module top level (PLC0415), drop the unnecessary `val`
assignment before return (RET504), and mark the unused `wealth` arg in
the local `borrow` constraint as `# noqa: ARG001`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A regime function whose output is then re-indexed by a discrete state
inside another consumer (function, constraint, or transition) is a
silent footgun: pylcm broadcasts function outputs to per-cell scalars
before consumption, so the indexing silently produces NaN at runtime
instead of the intended scalar. The aca-baseline benchmark hit this
via `bequest(... utility_scale_factor[pref_type])` where
`utility_scale_factor` is registered as a regime function — the dead
regime's V came back all-NaN with no actionable error.

Adds an AST-walking validator in `validate_logical_consistency` that
inspects every consumer (functions, constraints, transition) for a
`Subscript(Name=X, slice=Name=Y)` pattern where `X` is in
`regime.functions` and `Y` is a `DiscreteGrid` state. If any clash is
found, raises `RegimeInitializationError` listing each clash and
pointing the user at the safe pattern (function takes the state,
returns a scalar — see `discount_factor`).
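
The pattern match itself can be sketched with the standard `ast` module (Python 3.9+, where `Subscript.slice` holds the expression directly). The function name `find_subscript_clashes` and the example consumer are hypothetical; the real validator's bookkeeping and error type are not reproduced here.

```python
import ast


def find_subscript_clashes(
    source: str, function_names: set[str], discrete_names: set[str]
) -> list[tuple[str, str]]:
    """Return (X, Y) pairs for every `X[Y]` with X a function output, Y a discrete grid."""
    clashes = []
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            and isinstance(node.slice, ast.Name)
            and node.value.id in function_names
            and node.slice.id in discrete_names
        ):
            clashes.append((node.value.id, node.slice.id))
    return clashes


consumer_source = """
def bequest(wealth, utility_scale_factor, pref_type):
    return wealth * utility_scale_factor[pref_type]
"""

clashes = find_subscript_clashes(
    consumer_source,
    function_names={"utility_scale_factor"},
    discrete_names={"pref_type"},
)
```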

Three TDD tests in `tests/test_function_output_state_indexing.py`:
- the clash raises (functions case)
- the safe pattern (function takes the state, scalar return) builds
- the check applies to constraints too

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model `feature/runtime-consumption-points` 4123fe9 → 1342861
(refactors `utility_scale_factor` to take `pref_type` and return a
scalar, eliminating the regime-function-output / state-indexed-input
clash that produced NaN in the dead regime's V).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ion space

`create_regime_state_action_space` (used during forward simulation) was
calling `create_state_action_space` directly, which leaves
`pass_points_at_runtime=True` IrregSpacedGrid action grids as their NaN
placeholder. The placeholder fed straight into
`argmax_and_max_Q_over_a` and `_lookup_values_from_indices`, so optimal
actions came back NaN, the source regime's `next_state` propagated NaN
into every target regime's namespaced state, and `validate_V` raised on
the first downstream regime whose utility depended on those states
(the dead regime in aca-model: assets/pref_type both NaN).

Route through `internal_regime.state_action_space(regime_params=...)`
(the same path solve uses) and overlay the per-subject states. Add a
TDD regression test in tests/test_runtime_params.py covering the
simulate path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…grid values

`LogSpacedGrid` previously inherited only the generic continuous-grid
checks (start < stop, n_points > 0). With `start <= 0`, `to_jax()`
silently returned NaN/-inf, and the bug would only surface deep
inside an interpolation kernel. Now refuses at construction.

While here, tighten two adjacent silent-failure modes:

- `_validate_continuous_grid` rejects non-finite `start`/`stop`.
  `start >= stop` is False for NaN, so a NaN bound previously slipped
  through every check.
- `_validate_irreg_spaced_grid` rejects non-finite points. The
  ascending-order test uses `>=`, which is False for NaN, so a NaN
  point previously passed the order check silently.

Both matter for runtime-supplied grids: e.g. `geomspace(consumption_floor,
MAX, N)` with a bad `consumption_floor` produces all-NaN points, and we
want that caught at the grid layer rather than as a downstream V_arr
NaN diagnostic.
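
A short sketch of the failure mode and the fix, with hypothetical function names and error messages: an ordering check alone lets NaN through, so finiteness has to be tested first.

```python
import math


def naive_bounds_ok(start: float, stop: float) -> bool:
    """Ordering check only: rejects a pair only if `start >= stop` is True."""
    return not (start >= stop)


def validate_bounds_sketch(start: float, stop: float) -> None:
    """Check finiteness before ordering, so NaN and inf cannot slip through."""
    if not (math.isfinite(start) and math.isfinite(stop)):
        raise ValueError("start and stop must be finite")
    if start >= stop:
        raise ValueError("start must be less than stop")


# `nan >= 1.0` is False, so the naive check accepts a NaN bound.
nan_slips_through = naive_bounds_ok(math.nan, 1.0)
```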

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… banners

- tests/test_single_feasible_action.py: drop three decorative section
  banners (AGENTS.md prohibits `# ---...---` separators); fold the
  banner prose into the docstrings of the tests/helpers below.
- tests/test_single_feasible_action.py: type-annotate `_crra_bequest`
  and `_alive_utility`'s pref_type / consumption_weight /
  coefficient_rra arguments (DiscreteState / FloatND).
- tests/test_runtime_params.py: type-annotate `_make_action_grid_model`
  and `_make_action_grid_model_with_stateful_dead`.
- src/lcm/simulation/transitions.py: re-run `_validate_all_states_present`
  in the new `create_regime_state_action_space` (the substitution
  switch from `create_state_action_space(states=...)` to
  `base.replace(states=...)` had silently dropped this check).
- src/lcm/params/regime_template.py: docstring on
  `_fail_if_runtime_grid_shadows_function`; fix stale phrasing in
  `create_regime_params_template` ("matching the state name" →
  "matching the state or action name").
- src/lcm/interfaces.py: comment why the `_ShockGrid` substitution
  branch is gated on `in_states` only (state-only by design,
  AGENTS.md forbids ShockGrids as actions; gate is the explicit
  enforcement of that invariant).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The validator's error message already explains why; the class docstring
only needs the contract.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rid path

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Derived categoricals (`regime.derived_categoricals`, function outputs
that pylcm treats as categoricals — see
https://pylcm.readthedocs.io/en/latest/pandas-interop/#derived-categoricals)
suffer the same per-cell broadcast clash as discrete states. Extend
`discrete_state_names` in `_validate_function_output_state_indexing`
to include them; add a TDD test.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…module)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
pylcm is a general library; references to a particular companion
application become stale fast and force readers to know unrelated
projects to follow the test rationale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The variable previously named `discrete_state_names` accumulated state
DiscreteGrids, derived categoricals, and now discrete actions — all
three suffer the same per-cell broadcast clash when a consumer does
`func_output[X]`. Renamed the variable, the two helpers
(`_validate_function_output_grid_indexing`,
`_find_function_output_grid_indexing`), the test module
(`test_function_output_grid_indexing.py`), and the error-message
wording ("discrete state" → "discrete grid"). Added a TDD test for
the discrete-action case.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion

The previous docstring claimed the indexing 'silently produces NaN', but a
disabled-validator probe shows otherwise:

- When the producer takes the discrete grid as input, its output is a
  per-cell scalar; `func_output[grid]` raises `IndexError: Too many indices`
  at trace time. This is the real footgun the validator should catch.
- When the producer does NOT take the discrete grid as input, its output
  stays array-shaped and `func_output[grid]` is correct code that solves
  to sensible V values.

The previous validator flagged both shapes — including the safe one — as a
clash. Tighten: only fire when the producing function also takes the
discrete grid as input. Update the description to match observed behaviour
(IndexError, not NaN). Add a regression test that exercises the
array-valued-producer + state-indexed-consumer shape and asserts it builds
without raising.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #334 introduced a deferred-diagnostics accumulator that appends every
(regime, period) NaN/Inf flag to a Python list, stacks the lists at end
of solve, and `.tolist()`s the stacks to host. On a 16 GB V100 at
production aca-baseline grid sizes the stacked reduction graph holds the
per-period `isnan(V_arr)` / `isinf(V_arr)` intermediates alive
simultaneously; the post-loop `.tolist()` then asks XLA to compile the
fan-in and OOMs on a ~7.3 GiB allocation on top of the already-resident
solution V arrays. Symptom: backward induction reports every age as
"finished in ~14 ms" (dispatch-async times), then
`JaxRuntimeError: RESOURCE_EXHAUSTED` at the first `.tolist()`.

Fix: replace the per-period list-append with a running scalar OR; add a
per-period `block_until_ready()` so each period's reduction kernel
finishes (and its intermediate is freed) before the next period
dispatches. `block_until_ready` is device-only — no host transfer, no
PCIe round-trip — so it doesn't reintroduce the per-period sync that
#334 removed; in practice the small reduction has finished by the time
`max_Q_over_a` (~14 ms/period) returns.

End of solve: one `.item()` per running scalar. On a healthy solve those
two bools are False and we return without materialising any per-row
state. Failure paths (`running_any_nan` / `running_any_inf` True) walk
`diagnostic_rows` and materialise one bool per row to localise the
offender — same total host transfers as the prior code, but only on the
failure path.
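
The accumulation pattern, sketched with NumPy as a stand-in for device arrays. `find_nan_periods_sketch` is a hypothetical name, and the device-side `block_until_ready()` step has no NumPy equivalent, so it is omitted.

```python
import numpy as np


def find_nan_periods_sketch(v_arrays: list[np.ndarray]) -> list[int]:
    """Keep one running flag during the loop; localise only on failure."""
    running_any_nan = np.bool_(False)
    for v in v_arrays:
        # One scalar OR per period instead of appending a flag to a list.
        running_any_nan = np.logical_or(running_any_nan, np.isnan(v).any())

    # Healthy solve: a single scalar read, no per-row state materialised.
    if not running_any_nan.item():
        return []

    # Failure path: walk the rows to find the offending periods.
    return [period for period, v in enumerate(v_arrays) if np.isnan(v).any()]
```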

Debug-stats path (`log_level="debug"`) still appends min/max/mean per
period; a single per-period `block_until_ready` after the appends frees
those intermediates too. The end-of-solve `_log_per_period_stats` keeps
the existing per-(regime, period) log line.

`_StackedReductions`, `_emit_deferred_diagnostics`, and the old
`_raise_if_nan` / `_warn_if_inf` (taking pre-materialised flag lists)
are replaced by `_emit_post_loop_diagnostics` (orchestrator),
`_raise_first_nan_row`, `_warn_inf_rows`, and `_log_per_period_stats`.

Tests: new `tests/solution/test_diagnostics.py` covering the four log
levels — happy-path warning, NaN-raise with `(regime, age)` in the
message, off-level skip, and per-period debug stats.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each `_DiagnosticRow` previously held the active-period
`state_action_space`, the rolling `next_regime_to_V_arr`, the regime's
flat params, and a `compute_intermediates` closure (which itself
captured the state_action_space). At production grid sizes — 50+
periods × ~6 active regimes — the accumulated references pin every
period's full-shape V mapping in device memory, OOMing the V100 16 GB
mid-loop on `block_until_ready` (the next allocation that has nowhere
to go).

The streaming NaN/Inf reduction that landed earlier addressed only the
per-period reduction buffers; the row-level retention is the larger
leak. Strip `_DiagnosticRow` to the three Python scalars actually
needed for failure-path localisation (`regime_name`, `period`, `age`)
and reconstruct the heavy bits from `solution`, `internal_regimes`,
and `internal_params` inside `_raise_at`. The reconstruction mirrors
the loop's roll-forward semantics: for each regime, take the smallest
later period in `solution` where the regime was active, falling back
to a zeros template — the same value the rolling
`next_regime_to_V_arr` slot held during the live dispatch.

Also lock the row's shape via a structural test so future changes
that re-introduce device-backed fields fail loudly in CI rather than
silently regressing OOM behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two changes targeting the NaN-in-V failure path:

1. Fail-fast at age boundary. Adds a per-period
   `running_any_nan.item()` host transfer right after the existing
   `block_until_ready`. On True, the loop breaks out and the existing
   post-loop emitter raises immediately. Cost: one scalar bool transfer
   per period — negligible next to `max_Q_over_a`. Without this,
   backward induction would finish the entire age range (potentially
   ~2h on production grids) before raising at the first-NaN row,
   leaving the user staring at an idle-looking solve.
   Inf stays non-fatal; the post-loop warning still fires for any
   period that flagged it.

2. Drop the misleading "re-solve with debug logging" suggestion from
   `validate_V`. The diagnostic [NOTE] is added inline by
   `_enrich_with_diagnostics` whenever `compute_intermediates` is
   wired up — i.e. on the default path — so suggesting a re-solve to
   "produce" diagnostics is wrong: they were already produced. Replace
   with a pointer to the [NOTE] for the per-axis breakdown plus a
   mention of `log_path=...` for snapshot persistence (the only thing
   debug-mode actually adds beyond the inline diagnostic).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When `log_path` is configured, the failure path already calls
`save_solve_snapshot(...)` (`model.py:223-230` and `:334-341`) before
re-raising — but the path it returns wasn't surfaced anywhere, so the
user saw a generic "pass `log_path=...`" hint pointing them to do
something they had already done. Capture the returned `snap_dir` and
attach it via `exc.add_note(f"Snapshot saved to {snap_dir}")`. The
note appears alongside the diagnostic-summary note that
`_enrich_with_diagnostics` adds, so the user sees both the per-axis
NaN breakdown and the exact `solve_snapshot_NNN/` directory in one
exception.

Drop the now-redundant `log_path=...` suggestion from `validate_V`'s
message. Replace with a short pointer to the [NOTE] block: when
`log_path` is set, the second note has the path; when it isn't, the
inline diagnostic still pinpoints the offending intermediate. The
debugging-guide link stays.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When the user declares the simulate batch size up front via
`Model(n_subjects=N)`, the first matching `simulate(...)` call now AOT-
compiles every unique simulate function for that shape in parallel
(`ThreadPoolExecutor` over `lower(...).compile()`), mirroring solve's
existing AOT path in `solve_brute._compile_all_functions`. Subsequent
calls with the same size hit the cache; calls with a mismatching size
warn once per size and fall back to the runtime-traced path.

Also normalises `period_to_regime_to_V_arr` at the entry of `simulate`
so every period dispatches with the same pytree (active-regime padding
with zeros). Without this the last period's empty `next_regime_to_V_arr`
breaks both the AOT pytree signature and JAX's own JIT cache reuse.

`n_subjects=None` (the default) preserves the previous lazy behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous round padded `next_regime_to_V_arr` to all 19 regime keys at
every period inside `simulate.simulate(...)`. That was a workaround for a
pytree mismatch I'd introduced on the AOT side, not a real requirement —
runtime has always passed only the active-at-P+1 regime keys (or `{}`
past the last period), and `argmax_and_max_Q_over_a` traced fine against
that sparse mapping. Padding everywhere widened the live device footprint
of every dispatch (aca-baseline benchmark went 539 MB → 1.03 GB peak GPU,
+11% execution time).

Fix: keep the runtime path sparse and have AOT compile against the same
sparse pytree per period. `_collect_unique_simulate_functions` now keys
the argmax dedup on `(func_id, active_at_next_period)` so two periods
sharing the same Q_and_F closure but seeing different active-regime sets
at P+1 each get their own compiled program. The lower-args template is
built per period from those active regimes only.
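
The dedup key can be illustrated as follows. This is a sketch under assumptions: `collect_unique_sketch`, the use of `id(func)` as the function identifier, and the regime names are hypothetical and only show how the key separates periods that share a closure but see different active sets.

```python
from collections.abc import Callable


def collect_unique_sketch(
    period_to_func: dict[int, Callable],
    period_to_active_next: dict[int, set[str]],
) -> dict[tuple[int, frozenset[str]], list[int]]:
    """Group periods by (function identity, active regimes at the next period)."""
    unique: dict[tuple[int, frozenset[str]], list[int]] = {}
    for period, func in period_to_func.items():
        key = (id(func), frozenset(period_to_active_next[period]))
        unique.setdefault(key, []).append(period)
    return unique


def shared_q_and_f() -> None:
    """Placeholder for a closure shared across periods."""


groups = collect_unique_sketch(
    {0: shared_q_and_f, 1: shared_q_and_f, 2: shared_q_and_f},
    {0: {"working", "retired"}, 1: {"working", "retired"}, 2: {"retired"}},
)
```

Periods 0 and 1 share one entry, while period 2 gets its own because the set of active regimes at the next period differs.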

Net effect:
- Default (lazy) path: identical pytree to before this PR; the
  benchmark regression goes away.
- AOT path: same correctness, programs sized to the actual runtime
  signature, dedup still effective when consecutive periods share the
  active set.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…odel

Exercises the AOT-simulate path so the benchmark actually measures it.
The benchmark env pins aca-model by SHA. The previous SHA pre-dates
`create_benchmark_model(n_subjects=...)`, so the aca-baseline benchmark
fails at `setup_cache` with `unexpected keyword argument 'n_subjects'`.
Bump to the tip of `feature/runtime-consumption-points`.
aca-model now requires `max_consumption` on every `create_model*`
factory (no default) — pass `_MAX_CONSUMPTION=300_000.0` to
`create_benchmark_model` so the benchmark builds.
hmgaudecker and others added 16 commits May 8, 2026 11:14
…uction

ScalarFloat, ScalarInt, and ScalarBool now stand for JAX scalars only,
so downstream annotations (e.g. aca-model DAG functions) carry the
"post-cast invariant" guarantee accurately.

Changes that follow from the tightening:

- UniformContinuousGrid (LinSpacedGrid, LogSpacedGrid) and
  IrregSpacedGrid use a manual __init__ to accept Python literals at
  the user-facing API and store start/stop/points as JAX scalars at
  canonical_float_dtype(). Grid dtype is now sticky to construction
  time x64 mode.

- Coordinate helpers (linspace, logspace, get_*_coordinate,
  Grid.get_coordinate) widen each numeric slot to
  `float | ScalarFloat` / `int | ScalarInt` so they remain callable
  from setup-time Python code as well as the JIT'd DAG.

- simulate.py replaces `enumerate(ages.values)` with index-based
  iteration so `age` carries a proper JAX-scalar type; transitions.py
  follows.

- Display/diagnostic age parameters in error_handling.py and
  logging.py widen to `int | float | ScalarInt | ScalarFloat` so
  Python literals from `_DiagnosticRow` keep working.

Test changes: parametrised dtype-invariant test now constructs grids
inside the test body so the x64_disabled fixture is in effect; the
returning-int test in test_regime_state_mismatch flips to `-> int`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`linspace`, `logspace`, `get_*_coordinate` are pylcm-internal: every
production caller (Grid methods, piecewise dispatchers) hands them
JAX scalars. Drop the `float | ScalarFloat` widening on `start` /
`stop` / `value` so the helpers pin the post-cast contract.

Conversion of user input now happens once at the public-API boundary,
inside `Grid.get_coordinate`, via a small `_to_jax_scalar` helper. The
helper-direct tests in test_grid_helpers.py wrap their literals with
`jnp.asarray` to match.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`Model.__init__` lifts `fixed_params` Python scalars to JAX arrays via
the boundary dtype cast, which initialises CUDA in the parent process
when running under cuda12. ASV forks the benchmark worker from that
parent; the inherited CUDA context is unusable in the child and
surfaces as `CUDA_ERROR_NOT_INITIALIZED` on the first device op.

Wrap `_build()` in `jax.default_device(cpu)` so all setup-time array
creations stay on CPU. The worker process initialises CUDA freshly
when `simulate(...)` runs in `setup`/method bodies; JAX moves the
deserialised arrays to GPU on demand.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…sult

When `Model(n_subjects=N)` triggers an AOT compile, every
`InternalRegime.simulate_functions` field carries a `jax.stages.Compiled`
that holds an unpicklable `LoadedExecutable`. The snapshot already
side-loads the V-array via HDF5; widen the strip pass to overwrite
`SimulationResult._internal_regimes` with `model.internal_regimes`
(the lazy regimes — same metadata, JIT'd `PjitFunction`s pickle cleanly,
which is why `model.pkl` survives the same round-trip).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ASV's forkserver runs `preimport` to discover benchmarks across every
`bench_*.py` module before forking workers. Importing JAX at module
top loads the multithreaded XLA backend into the forkserver; every
subsequent `os.fork()` (for any benchmark, not just this one) inherits
a corrupted CUDA context and the first device op in the worker aborts
with `CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of
the forkserver and confine it to the worker process.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Continues the dtype-barrier work by promoting internal scalar metadata
to JAX-typed forms wherever it lives strictly inside pylcm:

- `UniformContinuousGrid.n_points` and `Piece.n_points` are stored as
  `jnp.int32` JAX scalars, converted from the Python literals at
  construction. `_init_uniform_grid` casts `start` / `stop` /
  `n_points` at the boundary before validation; the validator can then
  assume strict `ScalarFloat` / `ScalarInt` arguments and only check
  value invariants. Coordinate helpers (`linspace`, `logspace`,
  `get_*_coordinate`) tighten `n_points` to `ScalarInt` so the
  conversion happens once at the boundary instead of at every call.
- `Grid.get_coordinate` reverts to `ScalarFloat | Array` (no Python
  float). The single production caller in `regime_building/V.py`
  always passes a JAX array; tests that called the helpers with
  Python literals wrap them with `jnp.asarray` / `jnp.int32`.
- `Period` aliases `ScalarInt` and `Age` aliases `ScalarInt | ScalarFloat`
  for the JIT-internal scalar contexts. `AgeGrid.period_to_age` and
  `age_to_period` use plain `int | float` directly since they are
  user-facing API methods returning Python values.
- `_simulate_regime_in_period` and the `transitions.py` helpers now take
  `period: ScalarInt`. The simulation loop derives
  `period = jnp.int32(period_idx)` once per iteration and passes it through; dict-key
  lookups (`argmax_and_max_Q_over_a[period_idx]`,
  `period_to_regime_to_V_arr.get(period_idx + 1)`) keep using the
  Python int.
- `FlatRegimeParams` tightens to `MappingProxyType[str, Array]` —
  post-whitelist every leaf is a JAX array, the prior `bool | float |
  Array` union was stale.
- `safe_to_int32` renamed to `safe_to_int_dtype` to mirror
  `safe_to_float_dtype`.
- `_strip_V_arr_from_result` made fully kw-only.
- `pyproject.toml` ignores `ARG001` for
  `tests/test_float_dtype_invariants.py` so per-test
  `# noqa: ARG001` comments drop out and signatures collapse to a
  single line.
- `Piece` becomes `init=False` with a manual `__init__` that lifts
  `n_points` to `jnp.int32`, mirroring `UniformContinuousGrid`.

Test-side fallout addressed in the same commit: literals wrapped with
`jnp.asarray` / `jnp.int32` where helpers tightened, redundant
`# ty: ignore` comments dropped, and three "validator rejects
non-numeric" tests reframed to assert the boundary cast catches them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`jnp.linspace`/`jnp.logspace`'s `num` parameter is annotated `int` in
JAX's stubs but accepts `jnp.int32` JAX scalars in eager mode (verified
on cuda12). Pass `n_points: ScalarInt` through directly and silence the
type-check mismatch at the single call site rather than materialising
the JAX scalar to a Python int.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the Python `sum(generator, start=jnp.int32(0))` with a single
`_piece_n_points.sum()` reduction. The cached `Int1D` is already
populated by `_init_piecewise_grid_cache`, the property is read after
`__post_init__`, and the result is the same `ScalarInt`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pull in the consumption-grid pinning, borrowing-constraint kink fix, and
precision-workaround cleanups so the GPU benchmark CI runs the
benchmark-aca-baseline kernel that aca-dev currently tracks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The period_idx / period split was noisy: every loop iteration computed
both a Python int (for dict-key indexing and `period in active_periods`)
and a JAX scalar (for the JIT'd compute call). Drop the JAX-scalar
shadow; iterate `for period, age in enumerate(ages.values)` once.
`_simulate_regime_in_period(period: int)` keeps the integer through
dict lookups and casts to `jnp.int32(period)` only at the
`argmax_and_max_Q_over_a` / next-state JIT boundaries. Same pattern
for transitions.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When `Model(n_subjects=N)` is set, simulate-side XLA compilation used
to run lazily on the first matching `simulate(...)` call — strictly
after `solve(...)` returned. On production aca-baseline that adds
several minutes to the end-to-end wall clock for nothing: solve is
GPU-bound, simulate compile is CPU-bound XLA work, so they overlap
trivially.

Add `_maybe_start_simulate_compile_async` and call it from `solve(...)`
right after parameters are processed. It spawns a single-worker
`ThreadPoolExecutor` that runs `compile_all_simulate_functions` in the
background and parks the result on `_simulate_compile_future`.
`_resolve_simulate_internal_regimes` awaits the future before
populating the cache, so the lazy fallback path (no `solve` call,
direct `simulate(...)`) still works.

`__getstate__` / `__setstate__` drop the future on the way out and
reset to `None` on the way in — `concurrent.futures.Future` is tied to
its originating thread pool and can't survive a process boundary.
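
A minimal standard-library sketch of the pattern described above: start the compile in the background during `solve`, await it in `simulate`, and drop the future when the object's state is exported. `ModelSketch` and `_compile_all` are hypothetical stand-ins. Only the attribute name `_simulate_compile_future` is taken from this commit message, and the real pylcm code may look different.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def _compile_all(n_subjects):
    """Hypothetical stand-in for the CPU-bound simulate-side compile."""
    time.sleep(0.01)
    return {"n_subjects": n_subjects}


class ModelSketch:
    """Illustrative only; not pylcm's Model class."""

    def __init__(self, n_subjects):
        self.n_subjects = n_subjects
        self._simulate_compile_future = None

    def solve(self):
        # Kick off the compile in a single background worker, then "solve".
        executor = ThreadPoolExecutor(max_workers=1)
        self._simulate_compile_future = executor.submit(
            _compile_all, self.n_subjects
        )
        # wait=False returns immediately; the submitted job still completes.
        executor.shutdown(wait=False)
        return "solved"

    def simulate(self):
        # Await the background compile if one exists, else compile lazily.
        if self._simulate_compile_future is not None:
            return self._simulate_compile_future.result()
        return _compile_all(self.n_subjects)

    def __getstate__(self):
        # Export a copy of the state without the future.
        state = self.__dict__.copy()
        state["_simulate_compile_future"] = None
        return state

    def __setstate__(self, state):
        # Restore the state and make sure no stale future is carried over.
        self.__dict__.update(state)
        self._simulate_compile_future = None
```

`pickle` and `copy` call `__getstate__` / `__setstate__` automatically, so a restored instance falls back to the lazy path in `simulate()`.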

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pulls in the aca-model CI workflow's matching pylcm pin so the GPU
benchmark CI runs the same aca-model rev that aca-dev now tracks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
solve() no longer touches simulate-side compile state. simulate() is the
sole driver: spawns the AOT compile in a background thread when
n_subjects is set and the batch shape matches, then runs solve (if
period_to_regime_to_V_arr is None) and awaits the future at the
state-action-space dispatch point. Both public methods share an internal
_solve_compiled() body for the snapshot/error handling.

Drops _simulate_compile_future from instance state — the future lives in
a local variable on the simulate() stack, so there's no per-process
state to gate against. The lock keeps protecting _simulate_compile_cache
and _warned_n_subjects; the rest of the "maybe spawn" logic collapses
into a single inline check at the simulate() call site.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move the ARG001 ignore for the x64_disabled / x64_enabled fixture
pattern into pyproject.toml's per-file-ignores for test_dtypes.py and
test_float_dtype_invariants.py, then drop the per-call noqa comments
and the now-redundant -> None return annotations (tests/* already
ignores ANN). Single-arg signatures collapse to one line; longer ones
stay wrapped, but without the trailing comma noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The call with `period=1, age=1.0, **flat_regime_params={...float...}`
had its type error suppressed with
`# ty: ignore[invalid-argument-type]` to keep the call site short. Now
that `ScalarInt` / `ScalarFloat` are tightened to JAX-only, the fix is
to pass `jnp.int32(1)` / `jnp.asarray(1.0)` (and to wrap the float
param leaves in `jnp.asarray`). The ignore comments come out and the
call site genuinely type-checks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker marked this pull request as ready for review May 9, 2026 11:29
@hmgaudecker hmgaudecker requested review from mj023 and timmens May 9, 2026 11:29
hmgaudecker and others added 2 commits May 9, 2026 15:28
`SimulationResult.to_pickle()` (and any cloudpickle.dumps on the
result) hit `cannot pickle 'jaxlib._jax.LoadedExecutable'` when the
result carried the AOT-compiled `internal_regimes`. The compiled
callables (`argmax_and_max_Q_over_a`, `next_state`,
`compute_regime_transition_probs`) wrap a `LoadedExecutable` that
can't survive a process boundary.

`to_dataframe` only reads `simulate_functions.functions /
constraints / transitions / stochastic_transition_names` — none of
which the AOT pass replaces. So after `simulate(...)` runs, the
result has no use for the compiled callables: `model.simulate()`
swaps them out for the lazy `self.internal_regimes` before
returning.

Add a TDD test that round-trips the result through cloudpickle
under `n_subjects` matching, which is the failure mode pytask hit
on HPC.
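
The swap-before-return idea can be sketched with the standard library. Everything here is a hypothetical stand-in: `simulate_sketch` is not pylcm's `simulate`, plain dicts replace the regime objects, and a `threading.Lock` plays the role of the unpicklable `LoadedExecutable`.

```python
import pickle
import threading


def simulate_sketch(lazy_regimes, compiled_regimes):
    """Hypothetical simulate(): compute with compiled state, return lazy state."""
    # The compiled handle is only needed while the simulation runs.
    if compiled_regimes["executable"] is None:
        raise ValueError("expected a compiled executable")
    values = [float(i) for i in range(3)]  # placeholder for simulated output
    # Hand back the picklable lazy regimes instead of the compiled ones.
    return {"values": values, "internal_regimes": lazy_regimes}


lazy = {"functions": ["utility"], "constraints": ["borrowing_constraint"]}
# threading.Lock stands in for an object that cannot be pickled.
compiled = {**lazy, "executable": threading.Lock()}

result = simulate_sketch(lazy, compiled)
restored = pickle.loads(pickle.dumps(result))  # round-trips cleanly
```

Pickling `compiled` directly would raise a `TypeError`, which mirrors the failure the commit describes.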

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`simulate(...)` previously kicked off `compile_all_simulate_functions`
in a single-thread background executor and ran solve concurrently;
`_resolve_simulate_internal_regimes` then awaited the future at the
state-action-space dispatch point. With realistic worker counts the
parallel XLA compile pool stayed busy through a substantial chunk of
the backward-induction loop, contending for CPU and XLA front-end
resources and stretching the solve time of mid-loop ages by an order
of magnitude.

Drop the future / executor entirely. simulate() now calls
compile_all_simulate_functions inline before _solve_compiled, so the
entire AOT compile (including its own internal worker pool) finishes
before backward induction starts. Same total compile work; predictable
timing; lower transient host-RAM peak because the AOT pool's
intermediate Lowered objects are released before solve allocates its
per-period V buffers.

_resolve_simulate_internal_regimes loses its compile_future parameter
and only consults the cache. _spawn_simulate_compile is gone, as are
the `Future` and `ThreadPoolExecutor` imports.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Collaborator

@mj023 mj023 left a comment

Looks good. It is definitely better to use the JAX types everywhere.

Comment thread src/lcm/ages.py

 from lcm.exceptions import GridInitializationError, format_messages
-from lcm.typing import Age, Float1D, Int1D
+from lcm.typing import Float1D, Int1D
Collaborator

Is there a reason not to use the `Age` type anymore?

Under `jax_enable_x64=False`, JAX truncates `jnp.int64` /
`jnp.float64` requests to 32 bits and only emits a `UserWarning`. The
default test config (`filterwarnings = []`) let those warnings pass,
so a stray `int64` literal in src/ would slip through CI as a warning
that someone would have to spot by eye.

Switch the filter to `error:Explicitly requested dtype.*:UserWarning`.
Combined with the existing `--precision=32` job (`tests-32bit`),
every wide-dtype literal in src/ now fails the suite.

The three dtype-invariant test modules (`test_int_dtype_invariants`,
`test_float_dtype_invariants`, `test_dtypes`) opt back to the warning
default via a module-level `pytestmark` — they exist to *exercise* the
cast at the barrier and legitimately pass `int64` / `float64` inputs.

Add `tests/test_explicit_dtype_filter.py` with two tests confirming
the filter is in effect: each requests a wide dtype and asserts the
warning surfaces as `UserWarning`. Addresses the review on #340
without the false-positive surface of a literal-string grep.
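
The effect of the `error:Explicitly requested dtype.*:UserWarning` filter can be sketched with Python's `warnings` module, without importing JAX. `request_wide_dtype` is a hypothetical stand-in that emits a warning whose text only imitates the start of JAX's truncation message. The filter's three fields correspond to the `action`, `message` and `category` arguments used below.

```python
import warnings


def request_wide_dtype():
    """Hypothetical stand-in for a call that requests int64 with x64 disabled."""
    warnings.warn(
        "Explicitly requested dtype int64 is not available (illustrative text)",
        UserWarning,
        stacklevel=2,
    )


def escalates_to_error():
    """Return True if the filter turns the warning into an exception."""
    with warnings.catch_warnings():
        # Ignore everything else, then put the targeted error filter in front.
        warnings.simplefilter("ignore")
        warnings.filterwarnings(
            "error",
            message="Explicitly requested dtype.*",
            category=UserWarning,
        )
        try:
            request_wide_dtype()
        except UserWarning:
            return True
    return False
```

With the filter installed the warning is raised as an exception, which is what makes a wide-dtype literal fail a test run instead of scrolling past in the log.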

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Base automatically changed from feat/simulate-aot-n-subjects to main May 11, 2026 05:43
hmgaudecker and others added 2 commits May 11, 2026 07:46
The squash-merge of #340 onto main carried a small int-cast loop
inside `broadcast_to_template` that duplicates work already done by
`cast_params_to_canonical_dtypes` (the float-side reshuffle separated
broadcast and cast into two passes). Drop it.

Bump the `benchmarks` feature's aca-model rev to 9ac2043 so this
branch carries the same pin PR #347 was opened for; #347 can close.

Lockfile updated to track the merged pylcm HEAD and the new
aca-model rev.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker merged commit 99a5e31 into main May 11, 2026
10 checks passed
@hmgaudecker hmgaudecker deleted the feat/canonical-float-dtype branch May 11, 2026 07:14

3 participants