
Adopt beartype runtime type checking across lcm#355

Merged
hmgaudecker merged 104 commits into main from feat/beartype-perimeter on May 18, 2026

@hmgaudecker hmgaudecker commented May 13, 2026

Closes #176.

Runtime type checking via beartype across the entire lcm package,
plus the annotation cleanup, type-discipline work, and bug fixes the
checking surfaced. Originally split as #355 / #356 / #357 (perimeter +
scoped claw + annotation cleanup) — collapsed into one PR because the
follow-on work was a single arc.

Layers

Package-wide claw with explicit user-boundary decorators

src/lcm/__init__.py registers a single
beartype_package("lcm", conf=INTERNAL_CONF) before any submodule loads.
Every module loads with AST-rewritten runtime type checks; violations
raise beartype's default BeartypeCallHintViolation, which signals an
internal pylcm bug.

User-facing constructors stack their own @beartype(conf=PROJECT_CONF)
on top: Regime / MarkovTransition / Model, every grid and shock
class, AgeGrid, DiscreteGrid, the categorical(...) factory, the
as_leaf factory. The explicit decorator wins, so the documented
project exception (GridInitializationError, ModelInitializationError,
InvalidParamsError, RegimeInitializationError,
CategoricalDefinitionError) keeps surfacing at the user boundary.

Type-system: user-input vs. canonical

The boundary-vs-internal split is reified in lcm.typing:

  • _UserParamsLeaf / _ParamsLeaf, with UserMappingLeaf /
    MappingLeaf and UserSequenceLeaf / SequenceLeaf as the
    corresponding leaf containers. The wider user variant accepts Python
    scalars, numpy arrays, pd.Series; the narrow canonical variant
    carries only canonical-dtype JAX arrays.
    cast_params_to_canonical_dtypes is typed UserParams → Params.
  • initial_conditions: Mapping[StateName | Literal["regime"], Array | np.ndarray]
    at boundary entries (Model.simulate, canonicalize_initial_conditions);
    Mapping[..., FloatND | IntND] downstream.
  • State/action-keyed mappings (state_action_values, grid mappings,
    action_grids) use the StateOrActionName / ActionName aliases.

Canonicalization at every user-supplied boundary

_ShockGrid.params returns MappingProxyType[str, ScalarFloat | ScalarInt]
— Python scalars on dataclass fields are cast to 0-d JAX scalars on
access. _params_to_jax and the per-call casts in
regime_building/processing.py / regime_building/next_state.py /
interfaces.py are gone.

Bugs surfaced by the claw

  • _validate_regime_transition_single closed a Python-int period
    over jax.vmap; under x64 it traced as int64 and broke the
    aca-model claw on income()'s Period hint at the cross-package
    boundary. Fixed by jnp.int32(period) before the closure.
  • AVERAGE_CONSUMPTION = 20000 in aca-data flowed in as int32 under
    the GPU benchmark and violated utility_scale_factor's ScalarFloat
    hint. Fixed at the source and in the frozen benchmark fixture.
  • _default_H's discount_factor annotation stays FloatND; test fixtures
    that passed integer literals (1, 0) now pass floats (1.0, 0.0).

Annotation cleanup

Drift sweep across the now-clawed modules: int / float / bare
Mapping / bare Callable / wrong-protocol / forward-ref annotations
tightened to the appropriate lcm.typing alias or Protocol.
__annotate__ is stripped from functools.wraps (via a reduced
_WRAPPER_ASSIGNMENTS_NO_ANNOTATIONS) so generic
(*args: Any, **kwargs: Any) wrappers in lcm.utils.functools don't
inherit user-model annotations the claw would then mistakenly enforce.
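
The wrapper fix can be illustrated with the standard library alone. This is a sketch under assumptions: `_NO_ANNOTATIONS` and `allow_only_kwargs_sketch` are illustrative names, not the helpers in `lcm.utils.functools`, and the filter drops both `__annotations__` and `__annotate__` so it behaves the same across Python versions.

```python
import functools
from typing import Any

# Copy everything functools.wraps normally copies, except the annotations.
_NO_ANNOTATIONS = tuple(
    name
    for name in functools.WRAPPER_ASSIGNMENTS
    if name not in ("__annotations__", "__annotate__")
)


def allow_only_kwargs_sketch(func):
    """Generic forwarding wrapper that keeps its own (Any-typed) annotations."""

    @functools.wraps(func, assigned=_NO_ANNOTATIONS)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    return wrapper


def utility(health: int, wealth: float) -> float:
    """Hypothetical user-model function with specific annotations."""
    return wealth * health


wrapped = allow_only_kwargs_sketch(utility)
```

Here `wrapped` keeps the name and docstring of `utility`, but its annotations remain those of the generic wrapper, so a runtime checker has no user-model hints to enforce on it.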

@categorical(ordered=...) now requires every field to be annotated
as ScalarInt and raises CategoricalDefinitionError at decoration
time on anything else.

Includes

Tests

pixi run -e tests-cpu pytest -n 4: 990 passed, 10 skipped.
pixi run -e type-checking ty: clean. prek run --all-files: clean.
GPU benchmark-pr + platform main CI workflows: green.

New regression coverage: period reaches the transition function as
int32; LinSpacedGrid with a bad arg still raises
GridInitializationError under the unified claw (mirroring the
existing checks for Regime and Model); one
test_claw_checks_lcm_* per newly-clawed module.

mj023 and others added 29 commits May 4, 2026 14:06
Add a `distributed=True` flag on `DiscreteGrid` to shard the grid
across JAX devices, thread the distribution pattern through
`solve_brute._get_regime_V_shapes_and_shardings`, and validate the
device-count match at runtime via a new check in
`InternalRegime.state_action_space`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the flat `{regime}__{state}` qname-joined state container with
nested `MappingProxyType[RegimeName, MappingProxyType[StateName, Array]]`.
Pylcm code no longer constructs or parses the `__` separator on the
simulation read/write path.

Type aliases (lcm/typing.py):
- `RegimeStates = MappingProxyType[StateName, Array]`
- `StatesPerRegime = MappingProxyType[RegimeName, RegimeStates]`
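
To make the shape change concrete, here is a small sketch that converts a flat qname-keyed dict into the nested read-only form. The regime and state names are hypothetical, tuples of floats stand in for JAX arrays, and `nest_states` is an illustrative helper rather than a pylcm function.

```python
from types import MappingProxyType

# Assumption: tuples of floats stand in for per-subject JAX arrays.
flat_states = {
    "working__wealth": (1.0, 2.0),
    "working__health": (0.0, 1.0),
    "retired__wealth": (3.0, 4.0),
}


def nest_states(flat):
    """Group `{regime}__{state}` keys into a read-only regime -> state mapping."""
    nested: dict[str, dict[str, tuple[float, ...]]] = {}
    for qname, values in flat.items():
        regime_name, state_name = qname.split("__", maxsplit=1)
        nested.setdefault(regime_name, {})[state_name] = values
    return MappingProxyType(
        {regime: MappingProxyType(states) for regime, states in nested.items()}
    )


states_per_regime = nest_states(flat_states)
# Lookup by two keys; no separator is built or parsed at the read site.
wealth_working = states_per_regime["working"]["wealth"]
```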

`_update_states_for_subjects` becomes `_advance_states_for_subjects`,
takes paired `current_states_per_regime` / `next_states_per_regime`
arguments, and is a pure StatesPerRegime-in/out merge. The `next_`
prefix strip moves upstream into `calculate_next_states`, immediately
after `next_state_vmapped(...)`, so both arguments share inner-key
naming.

Touchpoints:
- `simulation/initial_conditions.py:build_initial_states` returns
  nested StatesPerRegime; per-regime inner mappings replace the flat
  `regime__state` dict.
- `simulation/transitions.py:60` (`create_regime_state_action_space`)
  accesses `current_states_per_regime[regime_name][sn]`.
- `simulation/transitions.py:125-142` (`calculate_next_states`) strips
  the `next_` prefix and calls the renamed function.
- `simulation/transitions.py:262-310` (`_advance_states_for_subjects`):
  nested merge with no string concat and no removeprefix.
- `simulation/simulate.py:294-296` filter collapses to
  `states[regime_name]`.

Tracks pylcm#343 Phase 1; Phase 2 (nest the DAG function dict and
rewrite three remaining `tree_path_from_qname` introspection sites)
lands in a follow-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Removes the flat-iterate + `tree_path_from_qname` decode pattern from
the three simulation introspection sites (`transitions.py:111`,
`compile.py:381`, `result.py:436`) — iterate the nested
`transitions: Mapping[RegimeName, Mapping[TransitionFunctionName, ...]]`
directly and assemble qnames at the boundary via `qname_from_tree_path`.

Per-leaf stochastic factories in `next_state.py` now take `target` and
`next_state_name` as separate args:
- `_create_discrete_stochastic_next_func` builds the wrapper-kwarg
  qname (`weight_<target>__<next>` / `key_<target>__<next>`) locally
  instead of receiving the pre-joined `name`.
- `_create_continuous_stochastic_next_func` takes the nested
  `all_grids` and the structured `(target, next_state_name)`, dropping
  the `name.split("next_")[1]` / `name.replace("next_", "")` parse
  used to recover components from the qname.

`processing.py:599-611`, `next_state.py:143`, and `Q_and_F.py:516`
replace `f"{regime}__next_{shock}"` / `f"weight_{regime}__{key}"`
f-string concat with `qname_from_tree_path` calls. The wrapper-kwarg
strings still exist inside `dags`'s qname encoding (that's the public
DAG-arg convention `dags` ships), but pylcm code no longer produces
or parses them outside the introduction boundary.

DAG topology unchanged: per-(target, shock) wrappers stay independent
nodes, `dags.concatenate_functions` keeps its pruning behaviour.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop redundant `current_` prefix: `current_states_per_regime` →
  `states_per_regime` throughout simulation transitions.
- Specialise `create_regime_state_action_space` to take `regime_states:
  RegimeStates` instead of the full per-regime carrier; caller indexes by
  regime name.
- Annotate state-keyed inner mappings with `StateName` alias rather than
  bare `str` (build_initial_states locals, _advance_states_for_subjects,
  initial_states/subject_states helper params).
- Rename misleading test locals: `flat`/`nested`/`updated` →
  `states_per_regime`/`next_states`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- Revert encode-only `qname_from_tree_path` calls back to plain f-strings
  where the result is consumed by an outer f-string or the arg list itself
  contains f-strings (`next_state.py:142`, `Q_and_F.py:517`,
  `processing.py:599-611`, `result.py:432`).
- Keep `qname_from_tree_path` where it stands alone in nested iteration
  (`transitions.py:109`, `compile.py:376`) or as an explicit assignment in
  the per-leaf factories. Rename the assigned local from `name` to `qname`
  (and the matching param in `_create_ar1_next_func` / `_create_iid_next_func`).
- Rename `prev_state_name` → `state_name` in the continuous-shock factory
  family. It's the state-name without the `next_` prefix; "prev" was relative
  to the transition output and confused the lookup against the carrier.
- Sharpen the `labels` docstring in `_create_discrete_stochastic_next_func`:
  category codes the discrete state can take, drawn via `jax.random.choice`
  weighted by `weight_<qname>`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Internal `InternalRegime.variable_info` flips from `pd.DataFrame` with
five boolean columns (`is_state`, `is_action`, `is_continuous`,
`is_discrete`, `is_shock`) to `MappingProxyType[StateOrActionName,
VariableInfo]` where `VariableInfo` is a frozen dataclass with three
fields:

- `kind: Literal["state", "action"]`
- `topology: Literal["continuous", "discrete"]`
- `is_shock: bool`

`VariableInfo` + `VariableInfoMapping` alias live in `interfaces.py`
next to the other internal data types. The constructor in
`regime_building/variable_info.py` is rewritten in pure Python, with
shock variables tagged `topology="discrete"` (matching the existing
`is_continuous = ContinuousGrid and not _ShockGrid` semantics — shock
grids approximate a continuous random variable but are processed by
discrete-sweep machinery).

ty gains visibility into every `variable_info` access. The 15 call
sites that previously used `.query("is_state")` etc. now use
comprehensions like `[k for k, v in vi.items() if v.kind == "state"]`.
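
A sketch of the dataclass described above and the comprehension pattern that replaces `.query(...)`. The three fields follow the commit message; the variable names in the mapping are hypothetical.

```python
from dataclasses import dataclass
from types import MappingProxyType
from typing import Literal


@dataclass(frozen=True)
class VariableInfo:
    kind: Literal["state", "action"]
    topology: Literal["continuous", "discrete"]
    is_shock: bool


# Hypothetical regime with two states (one of them a shock) and one action.
variable_info = MappingProxyType(
    {
        "wealth": VariableInfo(kind="state", topology="continuous", is_shock=False),
        "income_shock": VariableInfo(kind="state", topology="discrete", is_shock=True),
        "consumption": VariableInfo(kind="action", topology="continuous", is_shock=False),
    }
)

# Equivalent of the former DataFrame query on the `is_state` column.
state_names = [k for k, v in variable_info.items() if v.kind == "state"]
```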

Closes part 1 of #176.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Wrap the per-regime states+actions in a frozen `Variables(Mapping[...,
VariableInfo])` class with pre-computed name-tuple views (state_names,
shock_names, discrete_state_names, ...). Replaces the 16 sites that did
`[name for name, info in variable_info.items() if info.kind == "state"]`
with single attribute accesses.
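
One way such a wrapper could be structured is sketched here. It is an assumption-laden illustration: `VariablesSketch` is not the pylcm class, only three of the views are shown, and `cached_property` is one possible way to compute each view once.

```python
from collections.abc import Iterator, Mapping
from dataclasses import dataclass
from functools import cached_property
from typing import Literal


@dataclass(frozen=True)
class VariableInfo:
    kind: Literal["state", "action"]
    topology: Literal["continuous", "discrete"]
    is_shock: bool


class VariablesSketch(Mapping[str, VariableInfo]):
    """Read-only mapping with name-tuple views computed once on first access."""

    def __init__(self, infos: Mapping[str, VariableInfo]) -> None:
        self._infos = dict(infos)

    def __getitem__(self, name: str) -> VariableInfo:
        return self._infos[name]

    def __iter__(self) -> Iterator[str]:
        return iter(self._infos)

    def __len__(self) -> int:
        return len(self._infos)

    @cached_property
    def state_names(self) -> tuple[str, ...]:
        return tuple(k for k, v in self._infos.items() if v.kind == "state")

    @cached_property
    def shock_names(self) -> tuple[str, ...]:
        return tuple(k for k, v in self._infos.items() if v.is_shock)

    @cached_property
    def discrete_state_names(self) -> tuple[str, ...]:
        return tuple(
            k
            for k, v in self._infos.items()
            if v.kind == "state" and v.topology == "discrete"
        )


variables = VariablesSketch(
    {
        "wealth": VariableInfo("state", "continuous", False),
        "income_shock": VariableInfo("state", "discrete", True),
        "consumption": VariableInfo("action", "continuous", False),
    }
)
```

Call sites then read `variables.state_names` instead of repeating the filtering comprehension.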

`InternalRegime.variable_info` becomes `InternalRegime.variables`.
`get_variable_info(regime)` becomes the classmethod
`Variables.from_regime(regime)`. The intermediate file
`regime_building/variable_info.py` is gone; `get_grids` moves to
`src/lcm/variables.py` next to the class.

The earlier rewrite kept isinstance checks under
`name in state_names_set`, mirroring the pre-Variables logic. The
named views already filter on `kind == "state"` plus the matching
topology, so the isinstance branches become redundant — shocks live
in `discrete_state_names` (topology `"discrete"`), non-shock
ContinuousGrid states in `continuous_state_names`.

- transitions.py / result.py: inline `state_names` / `action_names` /
  `relevant_state_names` rebinds — `internal_regime.variables.state_names`
  is a cached attribute, so the rebind buys nothing.
- error_handling.py: same for `target_state_names`.
- initial_conditions.py: drop `_get_regime_state_names` — three callers,
  each one line shorter with the attribute access inlined.
- variables.py: collapse two `def get_grids(\n    regime: Regime,\n)` /
  `_raw_variable_info` signatures to one line; both fit comfortably.

`src/lcm/_beartype_conf.py` exposes five module-level `BeartypeConf`
instances, each configured to raise the existing project exception
class on parameter-type violations. Subsequent commits decorate
user-facing constructors and runtime methods with these configs.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`@beartype(conf=REGIME_CONF)` on `Regime` and `MarkovTransition`
catches parameter-type violations at construction time and raises
`RegimeInitializationError` (via `BeartypeConf(violation_param_type=...)`)
preserving the existing exception contract.

- `validate_attribute_types` in `regime_building/validation.py` is
  replaced by `validate_mapping_contents`, a slim aggregator that
  covers what beartype can't deep-check: exhaustive iteration of
  `Mapping[..., Callable]` and `Mapping[..., Protocol]` value
  types. Beartype's `O(n)` strategy still only samples Mapping
  entries when the value is a Protocol/Callable.
- `MarkovTransition.__post_init__`'s manual `callable(self.func)`
  check is dropped — beartype covers it.
- Protocols in `lcm.typing` (`UserFunction`, `ActiveFunction`, etc.)
  are marked `@runtime_checkable` so beartype's isinstance checks
  succeed.
- Strategy bumped to `BeartypeStrategy.On` for linear-time
  container validation (cheap at construction sites).
- 8 tests in `test_model.py` / `test_regime.py` updated from
  message-text matching to parameter-name matching.
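
The first and third bullets can be sketched together with the standard library. The names here are hypothetical (`UserFunctionSketch`, `validate_mapping_contents_sketch`); the point is that a `@runtime_checkable` Protocol supports `isinstance`, and that an explicit loop visits every mapping entry and aggregates all failures.

```python
from collections.abc import Mapping
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class UserFunctionSketch(Protocol):
    """Structural type: anything with a `__call__` member."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


def validate_mapping_contents_sketch(functions: Mapping[str, Any]) -> list[str]:
    """Check every entry (not a sample) and collect all error messages."""
    return [
        f"functions[{name!r}] is not callable: {value!r}"
        for name, value in functions.items()
        if not isinstance(value, UserFunctionSketch)
    ]


errors = validate_mapping_contents_sketch(
    {"utility": lambda consumption: consumption**0.5, "wage": 3.0}
)
```

Without `@runtime_checkable`, the `isinstance` call against the Protocol would raise `TypeError` instead of returning a result.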

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…l_inputs

`@beartype(conf=MODEL_CONF)` on `Model.__init__` raises
`ModelInitializationError` on parameter-type violations.

- `validate_model_inputs` drops the `isinstance(n_periods, int)`
  check (n_periods is structurally an `int` from
  `ages.n_periods`) and the "All items in regimes must be
  instances of lcm.Regime" early exit (beartype on
  `regimes: Mapping[RegimeName, Regime]` covers it). The value
  and cross-field checks below (n_periods >= 2, regime name
  separator, terminal/non-terminal counts, etc.) remain.
- `_ParamsLeaf` widened to include `int` — beartype caught an
  annotation drift where the runtime accepted ints but the type
  declared only `bool | float | ...`.
- `test_n_subjects_validation_rejects_non_int` expects
  `ModelInitializationError` instead of `TypeError`, aligning with
  the rest of `Model.__init__`'s error contract.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`@beartype_init(GRID_CONF)` on 6 grids and 7 shock grids checks
parameter types at construction time and raises
`GridInitializationError` while leaving each class's other methods
unwrapped. Bare `@beartype` on a class wraps every method, which
surfaces annotation drift in helpers like
`compute_gridpoints(**kwargs: float)` where runtime kwargs are JAX
arrays.

- `_beartype_conf.beartype_init` is the class decorator that only
  wraps `__init__`.
- Decorated classes: `DiscreteGrid`, `LinSpacedGrid`, `LogSpacedGrid`,
  `IrregSpacedGrid`, `PiecewiseLinSpacedGrid`, `PiecewiseLogSpacedGrid`;
  `Uniform`, `Normal`, `LogNormal`, `NormalMixture`, `Tauchen`,
  `Rouwenhorst`, `TauchenNormalMixture`. `AgeGrid.__init__` (a plain
  method, not a dataclass) gets a direct `@beartype(conf=GRID_CONF)`.
- `categorical` decorator factory gets `@beartype(conf=CATEGORICAL_CONF)`.
- `BeartypeConf` flipped to `is_pep484_tower=True` so `int` satisfies
  `float`-typed parameters, matching Python's numeric tower and
  ruff's PYI041 — without this, `LinSpacedGrid(start=1, ...)`
  failed because `1: int` is not `float` at the type level.
- Updated 6 test sites in `tests/test_grids.py` from message-text
  matching to parameter-name matching, plus type from `TypeError`
  to `GridInitializationError` where applicable.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`@beartype(conf=PARAMS_CONF)` on the two runtime entry points
catches parameter-type violations (bad `params` structure, wrong
`initial_conditions` types, malformed `period_to_regime_to_V_arr`)
and raises `InvalidParamsError`. Per-call cost is invisible at the
construction/run cadence these methods see.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
read-the-docs-community Bot commented May 13, 2026

hmgaudecker and others added 14 commits May 14, 2026 23:32
Extends the beartype claw to `lcm.utils.error_handling` with
`INTERNAL_CONF`. Narrows `validate_V` / `_enrich_with_diagnostics`
parameters from bare `MappingProxyType` / `Mapping` to
`MappingProxyType[RegimeName, FloatND]` and `FlatRegimeParams`, and
switches the `Model` annotation on `validate_transition_probs` to the
fully-qualified `lcm.model.Model` forward reference so the claw
resolves it at first call rather than tripping the model.py import
cycle at module-init time.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Extends the beartype claw to `lcm.state_action_space` with
`INTERNAL_CONF`. The module is already fully annotated with narrowed
`lcm.typing` aliases, so no annotation changes are needed.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Extends the beartype claw to `lcm.interfaces` with `INTERNAL_CONF`.
The `SolveFunctions` / `SimulateFunctions` / `InternalRegime`
dataclasses store dags-wrapped callables in Protocol-typed fields;
the claw's `__init__` checks structurally accept those callables, so
the regime_building / solution / simulation suites stay green.
Narrows `InternalRegime.name` from bare `str` to `RegimeName`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Extends the beartype claw to `lcm.regime` with `INTERNAL_CONF`. The
explicit `@beartype(conf=REGIME_CONF)` decorator on the `Regime` and
`MarkovTransition` constructors still wins, so construction-time type
violations keep surfacing as `RegimeInitializationError`.

`_default_H` receives state/value arrays whose dtype follows user
input — a `discount_factor` supplied as `1` arrives as an int32 array,
as `0.95` a float32 array — so its parameters are annotated `NumericND`
(a new `FloatND | IntND` alias in `lcm.typing`). The named alias keeps
`dags.tree`'s params-template extraction printing `NumericND` rather
than a bare `Union`. The `_IdentityTransition.__call__` discrete test
now passes an int32 array, matching the `DiscreteState` contract the
claw enforces.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Extends the beartype claw to `lcm.model` with `INTERNAL_CONF`. The
explicit `@beartype` decorators on `Model.__init__` / `solve` /
`simulate` still win, so construction- and call-time type violations
keep surfacing as `ModelInitializationError` / `InvalidParamsError`.
Swaps the inline `MappingProxyType[int, MappingProxyType[RegimeName,
FloatND]]` annotations for the `PeriodToRegimeToVArr` alias.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`_default_H`'s `discount_factor` should always receive a float, never
an int — the prior commit smuggled `FloatND | IntND` via a new
`NumericND` alias to accommodate test fixtures passing `1`/`0`, which
was the wrong way round. Pass floats from the tests, keep `_default_H`
strictly `FloatND`, drop the alias.

- `_default_H` reverted to `FloatND` on all three params + return
- `NumericND` removed from `lcm.typing`
- `tests/test_models/shock_grids.py`: `"discount_factor": 1` → `1.0` (×2)
- `tests/test_solution_on_toy_model_{deterministic,stochastic}.py`:
  parametrize `[0, 0.5, 0.9, 1.0]` → `[0.0, 0.5, 0.9, 1.0]` (×4)
- `tests/simulation/test_simulate.py`: `discount_factor=1, interest_rate=0`
  → `1.0, 0.0`
- `tests/regime_building/test_create_regime_params_template.py`:
  `"NumericND"` expectation → `"FloatND"` (×3)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`_ShockGrid.params` now returns `MappingProxyType[str, ScalarFloat |
ScalarInt]` — Python `bool`/`int`/`float` dataclass fields are cast to
0-d JAX scalars on access. Every downstream consumer
(`compute_gridpoints`, `compute_transition_probs`,
`_create_ar1_next_func`, `_create_iid_next_func`,
`StateActionSpace.replace`, the regime-building runtime closures)
already required `ScalarFloat | ScalarInt`-valued mappings and was
doing the cast itself via `_params_to_jax`. With the cast hoisted to
the boundary, the helper is no longer needed.

- `_ShockGrid.params` returns canonical 0-d scalars
- `_params_to_jax` deleted (orphaned)
- `weights_func_runtime` / `next_stochastic_state` closures: pass
  `shock_kw` through directly; annotations tightened to
  `dict[str, FloatND | IntND]`; `with_signature` runtime-param strings
  switched from `"float"` to `"FloatND"`
- `StateActionSpace.replace`: `shock_kw` retyped `ScalarFloat |
  ScalarInt`, redundant `_params_to_jax` call dropped

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Add `UserMappingLeaf` / `UserSequenceLeaf` as the boundary leaf types
accepted by `Model.__init__` / `Model.solve` / `Model.simulate`, and
narrow `MappingLeaf` / `SequenceLeaf` to canonical-only subclasses
emitted by `cast_params_to_canonical_dtypes`. The static type system
now distinguishes user input from post-canonicalization values; the
runtime constructor signatures keep `Mapping[str, Any]` /
`Sequence[Any]` so beartype doesn't fire on user-supplied scalars.

- `_UserParamsLeaf` covers Python scalars, numpy/pandas, JAX arrays,
  and `UserMappingLeaf` / `UserSequenceLeaf`; `UserParams` is the
  boundary `Mapping` type alias.
- `_ParamsLeaf` narrows to `FloatND | IntND | BoolND | MappingLeaf |
  SequenceLeaf`; `Params` is the post-canonicalization `Mapping` type.
- Both leaf variants registered as separate JAX pytrees so
  `jax.tree.map` round-trips each in its own type.
- `_cast_leaves_to_canonical_dtype` accepts the `User...Leaf` base
  classes and always emits the canonical narrow subclasses.
- `convert_series_in_params` preserves the boundary variant — it runs
  between broadcast and canonicalization.
- `_make_immutable` / `_make_mutable` / `has_series` switch their
  `isinstance` checks to the base classes (covering both variants).
- `as_leaf` returns the boundary `User...Leaf` variant.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…cols

`FlatRegimeParams` now includes `MappingLeaf` and `SequenceLeaf` in its
value union, matching what `cast_params_to_canonical_dtypes` actually
emits and surfacing the leaf type to every downstream consumer of
`internal_params`.

- `InternalUserFunction`, `RegimeTransitionFunction`,
  `VmappedRegimeTransitionFunction`, and `NextStateSimulationFunction`
  accept `MappingLeaf | SequenceLeaf` in `*args` / `**kwargs` — every
  call site already passes them via `**regime_params`.
- `InternalRegime.state_action_space` narrows `_ParamsLeaf` to `Array`
  / `ScalarFloat | ScalarInt` via `cast` at the two runtime-grid /
  shock-grid substitution points (the only two slots where
  `all_params[...]` is statically wider than the value's known shape).
- Tests touching `internal_params["..."]["..."].dtype` / `.shape` (or
  passing the leaf to `float()` / `int()` / `assert_allclose`) gain
  targeted `# ty: ignore` directives — runtime values are arrays at
  those sites, but the static type is the wider canonical-leaf union.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- Boundary inputs (`Model.simulate`, `canonicalize_initial_conditions`):
  `Mapping[StateName | Literal["regime"], Array | np.ndarray]` —
  honest user-facing type, accepts both JAX and numpy arrays.
- Post-canonicalization signatures (`simulate`, `validate_initial_conditions`,
  `build_initial_states` and the other helpers, `pandas_utils`):
  `Mapping[StateName | Literal["regime"], FloatND | IntND]` or
  `Mapping[StateName, FloatND | IntND]` (no regime key) — narrow canonical
  dtypes flow downstream.
- Persistence snapshots: `Mapping[StateName | Literal["regime"], Array]` —
  neutral storage form.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…mappings

`MappingProxyType[str, ...]` / `dict[str, ...]` at sites whose keys are
state and/or action names: tighten the key type to the existing
`StateOrActionName` alias (state-or-action) or `ActionName` (action-only).

- `validate_regime_transition_probs.state_action_values` (×2 sigs):
  → `MappingProxyType[StateOrActionName, FloatND | IntND] | None`
- `_validate_regime_transition_single` local `grids` / `point`:
  → `dict[StateOrActionName, FloatND | IntND]`
- `_lookup_values_from_indices.grids` + return:
  → `MappingProxyType[StateOrActionName, FloatND | IntND]`
- `build_initial_states` local `action_grids`:
  → `dict[ActionName, FloatND | IntND]`

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Decorate every public grid and shock constructor with
@beartype(conf=GRID_CONF) so a wrong-typed argument always raises
GridInitializationError, independent of which beartype claw covers the
module. Prepares the perimeter for collapsing the per-area claws into
a single lcm-package claw without losing the project-specific
exception at user-facing construction sites.

Affected types:
- LinSpacedGrid, LogSpacedGrid (via UniformContinuousGrid.__init__ and
  LogSpacedGrid.__init__)
- IrregSpacedGrid
- Piece, PiecewiseLinSpacedGrid, PiecewiseLogSpacedGrid
- DiscreteGrid
- Uniform, Normal, LogNormal, NormalMixture
- Tauchen, Rouwenhorst, TauchenNormalMixture

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Replace the eleven per-area `beartype_package(...)` calls with one
`beartype_package("lcm", conf=INTERNAL_CONF)`. The project-specific
exception mapping survives at every user-facing constructor through
the explicit `@beartype(conf=...)` decorators (Model, Regime,
MarkovTransition, every grid and shock, `@categorical`, `as_leaf`);
internal helpers raise beartype's own `BeartypeCallHintViolation`,
matching how the previous `INTERNAL_CONF` packages already behaved.

Other drift surfaced by extending the claw to previously uncovered
modules and resolved inline:

- `lcm.variables` and `lcm.persistence` carry `TYPE_CHECKING`-only
  forward references (`Regime`, `Model`, `SimulationResult`) to break
  import cycles. Inject the resolved names into both modules in
  `lcm/__init__.py` after the cycle settles so beartype can walk the
  references at call time.
- `lcm.utils.dispatchers._base_productmap_batched.batched_vmap` and its
  inner closure now annotate kwargs/return as `Any`. The wrapper
  composes arbitrary user functions whose value-pytrees include
  `MappingProxyType` containers; constraining to `FloatND | IntND |
  BoolND` was too narrow.
- `lcm.utils.functools.{allow_args,allow_only_kwargs}` strip
  `__annotate__` from `functools.wraps` via a reduced
  `_WRAPPER_ASSIGNMENTS_NO_ANNOTATIONS` tuple. The `(*args: Any,
  **kwargs: Any)` wrappers must not inherit the wrapped function's
  per-parameter annotations, otherwise beartype enforces user-model
  types (e.g. `Int1D` on `health`) on a generic forwarding wrapper.
- `lcm.params.as_leaf` carries explicit `@beartype(conf=PARAMS_CONF)`
  so a non-Mapping / non-Sequence argument keeps raising
  `InvalidParamsError` rather than `BeartypeCallHintViolation`.

Add a regression test on `LinSpacedGrid` mirroring the existing
`Regime` / `Model` checks: an explicitly-decorated user-facing
constructor must keep raising its project exception even though the
package claw covers the surrounding module.

Update `tests/test_dispatchers.py` for the wider claw: the affected
tests previously fed non-tuple `variables` or non-callable `func`
arguments and relied on internal validation to catch them. Pass
correctly-typed arguments instead, and assert
`BeartypeCallHintViolation` for the genuinely-bad-literal case where
internal validation is now unreachable.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Rewrite `_beartype_conf.py` and `tests/test_beartype_claw.py` headers
to describe the current setup: a single package-wide claw on `lcm`
with `INTERNAL_CONF`, layered with explicit `@beartype(conf=...)`
decorators on user-facing constructors that map violations to project
exceptions. Drop the unused `REGIME_BUILDING_CONF`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hmgaudecker hmgaudecker changed the title from "Adopt beartype runtime type checking: perimeter + scoped claw + annotation cleanup" to "Adopt beartype runtime type checking across lcm" May 15, 2026
hmgaudecker and others added 8 commits May 15, 2026 08:26
Split the overloaded `"regime"` string into two semantically explicit names:

- Input boundary (`initial_conditions` dict key, DataFrame column for
  `model.simulate`): renamed to `"regime_id"`. Affects every
  `initial_conditions["regime"]` lookup, `{"regime": ...}` dict literal
  in initial-conditions context, the `if name == "regime"` branch in
  `canonicalize_initial_conditions`, and every `Literal["regime"]` type
  annotation.
- Output DataFrame column (produced by `SimulationResult.to_dataframe`):
  renamed to `"regime_name"`. Affects the column-construction site,
  `dtypes["regime"]`, the categorical conversion branch, the column
  list, the post-`to_dataframe` masks/filters in tests, and prose
  references to "the regime column".

Regression pickles in tests/data/ have their `regime` column renamed to
`regime_name` to match the new `to_dataframe()` output.

The DataFrame side of `initial_conditions_from_dataframe` carries
regime-label strings, mirroring how `DiscreteGrid` state columns
already carry string category labels. The dict-side `"regime_id"`
key continues to hold integer codes; only the DataFrame column
name changes.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…trings

The dtype-follows-grid rule is already visible in the type annotation.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The internal dict is already typed as
`dict[StateName | Literal["regime_id"], FloatND | IntND]`; the
signature was widening it to bare `str` keys.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Mirrors the UserParams / Params split:
- UserInitialConditions: boundary form accepted by `Model.simulate`
  and `canonicalize_initial_conditions` (Array | np.ndarray values).
- InitialConditions: post-canonicalization form emitted by
  `canonicalize_initial_conditions` / `initial_conditions_from_dataframe`
  and consumed by `validate_initial_conditions`, `simulate`, and
  persistence (FloatND | IntND values).

Both aliases are read-protocol (`Mapping[...]`) so callers can pass a
plain dict; pylcm producers still wrap returns in `MappingProxyType`
at runtime.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- Pin `active_periods` array to `int32` in `_validate_initial_regime_active`
  (matches the int32-period invariant established for sibling arrays).
- Hoist `lcm.params` imports in `lcm.utils.containers` to module top; the
  dual lazy imports inside `mapping_leaf` / `sequence_leaf` keep the cycle
  resolved.
- Drop the redundant class-level `@beartype` on `LogSpacedGrid` (the
  `__init__`-level decorator already covers it; matches `LinSpacedGrid`).
- Delete the dead `ValueError` branch in `vmap_1d` (unreachable under the
  package claw's `Literal` check); the test asserts the claw violation
  directly.
- Replace ad-hoc `_persistence.Model = Model` injection in `lcm.__init__`
  with `_bind_forward_refs` helpers in `lcm.persistence` and `lcm.variables`;
  add a regression test that the bindings survive a fresh import.
- Tighten `test_beartype_claw.py` test docstrings to behaviour sentences.
- Reconcile `UserMappingLeaf` "frozen" docstring with `__hash__ = None`.
- Add `UserAge = int | Fraction` alias; sweep `AgeGrid.__init__`,
  `_validate_age_grid`, `_validate_range`, `_validate_values` to use it.
- Trim Claudish prose in the `lcm.__init__` claw block and the
  `_emit_post_loop_diagnostics` / `_raise_first_nan_row` docstrings.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Per-subject vectors carry a single broadcast axis (length n_subjects).
The validator already enforces rank-1 by checking equal lengths across
arrays; tightening the type makes the contract explicit and lets
beartype catch rank slips at the canonicalize boundary.

- `InitialConditions` alias narrows from `FloatND | IntND` to
  `Float1D | Int1D`.
- `RegimeStates` alias likewise narrows; `build_initial_states`,
  `_advance_states_for_subjects`, and the local intermediate in
  `pandas_utils.initial_conditions_from_dataframe` follow.
- `UserInitialConditions` stays wide at the boundary so user-supplied
  NumPy arrays of any rank pass through to canonicalization.

Also bump the pylcm-benchmarks pin of aca-model from 6e282a5
(pre-regime rename) to bce9101 — the CI run failed with
`KeyError: 'regime_id'` because the cached benchmark snapshot still
emitted the old `"regime"` key.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous tightening commit forgot to stage pixi.lock; CI's
locked install failed on the stale rev.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Under cuda12 + 32-bit precision + cumulative test-suite state, the
package claw's deep-check on `Sequence[float]` fires inside this inner
helper before the manual `isinstance(p, int | float)` loop can raise
the user-facing `GridInitializationError`. Reproduced locally with
`pixi run -e tests-cuda12 tests-32bit`; CI's GPU 32-bit job hit the
same.

Setting the element type to `Any` tells beartype not to deep-check,
restoring the manual loop as the single source of truth for the
"non-numeric points" error.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

@timmens timmens left a comment

Very nice cleanups and very cool what beartype brings to the table here!

I actually don't have any comments.

hmgaudecker and others added 2 commits May 18, 2026 07:15
Three small fixes:

- `dispatchers.py`: `vmapped.__signature__ = signature` gets a
  `ty: ignore[invalid-assignment]` — the runtime mutation is
  intentional; the `FunctionWithArrayReturn` TypeVar bound doesn't
  declare `__signature__` as writable. Same TypeVar bound also lacks
  `__name__`, so the error-message access uses `getattr(...)` with a
  `repr(func)` fallback.
- `typing.ActiveFunction`: `age: int | float` is too strict — contravariance
  means it forbids model authors from annotating `age: int` (annual grids)
  or `age: float` (sub-annual). Switch to `age: Any` (matching the
  `UserFunction` precedent) so authors can pin the annotation to whichever
  type matches their grid.
- `mahler_yum_2024._model`: `final_age_alive` / `initial_age` are
  derived from `ages.values[...]` which jaxtyping reports as `Array`;
  wrap with `int(...)` to land on a plain Python int (correct for the
  `step="2Y"` grid this model uses).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker merged commit 8689b41 into main May 18, 2026
8 of 10 checks passed
@hmgaudecker hmgaudecker deleted the feat/beartype-perimeter branch May 18, 2026 06:46
Development

Successfully merging this pull request may close these issues.

ENH: Improve type annotations and checking

3 participants