Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
136 commits
Select commit Hold shift + click to select a range
1c65f45
Support runtime-supplied points on continuous-action IrregSpacedGrids
hmgaudecker Apr 29, 2026
1792279
benchmarks: bump aca-model to runtime-consumption-points version
hmgaudecker Apr 29, 2026
cf00e99
Fail loudly when reading runtime IrregSpacedGrid before substitution
hmgaudecker Apr 29, 2026
db98cde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
03ba800
Fix remaining ruff check errors in test_single_feasible_action.py
hmgaudecker Apr 29, 2026
3769d2d
Raise on regime-function-output indexed by discrete state in a consumer
hmgaudecker Apr 29, 2026
282542f
benchmarks: bump aca-model to dead-regime-NaN fix
hmgaudecker Apr 29, 2026
72c83f7
Substitute runtime-supplied action gridpoints in simulate's state-act…
hmgaudecker Apr 29, 2026
db6214f
Guard log-spaced grids against non-positive start; reject non-finite …
hmgaudecker Apr 30, 2026
f56b9be
Address PR review: docstrings, type hints, validation, drop separator…
hmgaudecker Apr 30, 2026
3b7be82
LogSpacedGrid docstring: drop redundant rationale
hmgaudecker May 1, 2026
8589146
IrregSpacedGrid.to_jax docstring: shorter, point at the substituted-g…
hmgaudecker May 1, 2026
541392c
IrregSpacedGrid.to_jax error message: same shape as docstring
hmgaudecker May 1, 2026
54d22b0
Use jnp.isfinite in grid validators; drop math import
hmgaudecker May 1, 2026
ba38876
Drop cryptic aca_model reference from validator docstring
hmgaudecker May 1, 2026
4d11dea
Validator: also flag function-output indexed by a derived categorical
hmgaudecker May 1, 2026
01608bb
create_regime_state_action_space docstring: trim rationale
hmgaudecker May 1, 2026
6241056
state_action_space: move private helpers below public function (deep …
hmgaudecker May 1, 2026
2efe9e1
Drop application-specific (aca) references from test docstrings
hmgaudecker May 1, 2026
f1c4d5d
Validator: rename to discrete_grid_names, also include discrete actions
hmgaudecker May 1, 2026
d9f37ce
Validator: tighten to actual footgun shape; correct behaviour descrip…
hmgaudecker May 1, 2026
2cc46ff
solve_brute: stream NaN/Inf reductions instead of stacking-and-flushing
hmgaudecker May 1, 2026
365da07
solve_brute: stop pinning per-period V templates in diagnostic_rows
hmgaudecker May 1, 2026
bf1cdf4
solve_brute: fail-fast on NaN per period; rewrite stale diagnostic hint
hmgaudecker May 4, 2026
c7745f3
solve/simulate: surface snapshot path in NaN exception note
hmgaudecker May 4, 2026
bc067c1
Model.n_subjects: AOT-compile simulate functions for fixed batch shape
hmgaudecker May 1, 2026
8bb8259
simulate AOT: match runtime's sparse pytree, drop runtime padding
hmgaudecker May 2, 2026
54c72a0
bench_aca_baseline: pass n_subjects=_N_SUBJECTS to create_benchmark_m…
hmgaudecker May 2, 2026
92d038c
benchmarks: bump aca-model rev to carry n_subjects on factories
hmgaudecker May 2, 2026
648afcc
benchmarks: bump aca-model rev + pass max_consumption to factory
hmgaudecker May 2, 2026
596f150
simulate AOT: re-jit `next_state` / `compute_regime_transition_probs`
hmgaudecker May 2, 2026
9fd0524
simulate AOT: only compile active-period argmax, not the full age range
hmgaudecker May 3, 2026
dfb0e8b
simulate AOT: int32 for discrete state lower-args (match runtime)
hmgaudecker May 3, 2026
43723c4
build_initial_states: cast discrete states to grid dtype (one-shot)
hmgaudecker May 3, 2026
99f3f15
DiscreteGrid: pin to_jax() to int32 regardless of x64 mode
hmgaudecker May 3, 2026
458d36f
Lock integer dtype to int32 end-to-end (#341)
hmgaudecker May 3, 2026
e8ede00
benchmarks: bump aca-model rev; drop max_consumption kwarg
hmgaudecker May 3, 2026
4316b6e
solve_brute: merge resolved_fixed_params into NaN diagnostic regime_p…
hmgaudecker May 3, 2026
866a5bb
benchmarks: bump aca-model rev to 714fee0 (assets-floor margin)
hmgaudecker May 4, 2026
a51edae
regime_template: exempt next_<state> names from fixed_param extraction
hmgaudecker May 4, 2026
71a6146
Merge feature/next-state-deps-in-transitions: exempt next_<state> fro…
hmgaudecker May 4, 2026
ac93eec
Merge improve/lazy-solve-diagnostics (incl. next_<state> exempt-set fix)
hmgaudecker May 4, 2026
d66d85a
Boilerplate refresh: dags module, current pixi/hook pins, drop stale …
hmgaudecker May 4, 2026
f18d27c
Bump .ai-instructions: pyproject-fmt + ruff + check-jsonschema rev pins
hmgaudecker May 4, 2026
5e01de6
Merge feature/runtime-action-grids: boilerplate refresh + ai-instruct…
hmgaudecker May 4, 2026
f730e78
Merge feature/next-state-deps-in-transitions: boilerplate refresh + a…
hmgaudecker May 4, 2026
749f83a
Merge improve/lazy-solve-diagnostics: boilerplate refresh + ai-instru…
hmgaudecker May 4, 2026
c3e1838
Merge main: pick up #338 (runtime action grids + validator tightening)
hmgaudecker May 4, 2026
01ba1f3
Merge feature/next-state-deps-in-transitions (carries main → #338)
hmgaudecker May 4, 2026
3229a15
Merge improve/lazy-solve-diagnostics (carries main → #338 → #342)
hmgaudecker May 4, 2026
e45cf4f
regime_template: reject next_<state> params on regular DAG functions
hmgaudecker May 4, 2026
5ff174d
Merge feature/next-state-deps-in-transitions: harden next_<state> val…
hmgaudecker May 4, 2026
5bae789
Merge improve/lazy-solve-diagnostics: harden next_<state> validator
hmgaudecker May 4, 2026
c585f4b
Bump aca-model benchmark pin to 83f22500 (post pension correction)
hmgaudecker May 4, 2026
ffe8820
Merge feature/next-state-deps-in-transitions: bump aca-model benchmar…
hmgaudecker May 4, 2026
e912de4
Merge improve/lazy-solve-diagnostics: bump aca-model benchmark pin to…
hmgaudecker May 4, 2026
f0dd7b5
Revert aca-model pin: this branch lacks Model.n_subjects (introduced …
hmgaudecker May 4, 2026
149ac78
Merge feature/next-state-deps-in-transitions: revert aca-model pin to…
hmgaudecker May 4, 2026
588c9c4
Merge improve/lazy-solve-diagnostics
hmgaudecker May 4, 2026
c00b610
Bump aca-model pin to 83f22500 on #340 (carries pension correction)
hmgaudecker May 4, 2026
e89d5e4
Revert "regime_template: reject next_<state> params on regular DAG fu…
hmgaudecker May 4, 2026
d07d897
Merge feature/next-state-deps-in-transitions: revert validator that c…
hmgaudecker May 4, 2026
53443fc
Merge improve/lazy-solve-diagnostics: revert validator + carry over
hmgaudecker May 4, 2026
e92eeec
Bump aca-model pin to 3453080 (filters stale benchmark_params key)
hmgaudecker May 4, 2026
c117f4c
Bump aca-model pin to b2e90bb (synthesise shifted imputation arrays)
hmgaudecker May 4, 2026
e422876
Bump aca-model pin to 35eddcc (declare target_his derived categorical)
hmgaudecker May 4, 2026
4b9bea3
Bump aca-model pin to 64d6567 (rename shifted-array level to target_his)
hmgaudecker May 4, 2026
9c7edb6
Bump aca-model pin to f09b5e3 (per-target next_assets, dead-target te…
hmgaudecker May 4, 2026
e6066fa
Roll #340 aca-model pin back to 63d2a38 (pre-pension-correction)
hmgaudecker May 4, 2026
a908c84
get_next_state_function_for_simulation: per-target DAG mirrors solve
hmgaudecker May 5, 2026
18d4ade
Revert aca-model rollback: restore f09b5e3 pin (with pension correction)
hmgaudecker May 5, 2026
6c64a77
get_next_state_function_for_simulation: per-target DAG mirrors solve
hmgaudecker May 5, 2026
fed28cd
Merge feature/next-state-deps-in-transitions: pull pylcm simulate-pat…
hmgaudecker May 5, 2026
c969b1a
next_state: real signature for combined; fix trivially-passing test
hmgaudecker May 6, 2026
8a2ad4f
Address #342 review: simulate-path uses concatenate_functions; cleanups
hmgaudecker May 6, 2026
bcdd358
Merge feature/next-state-deps-in-transitions: address #342 review fee…
hmgaudecker May 6, 2026
17347c8
Get rid of H_variables entirely in regime_template.
hmgaudecker May 6, 2026
076b9b6
Bump .ai-instructions: TDD always; behavior-focused docstrings
hmgaudecker May 6, 2026
e6f8e41
Merge feature/next-state-deps-in-transitions: bump .ai-instructions f…
hmgaudecker May 6, 2026
164a88b
AGENTS.md: inline TDD-always testing section directly
hmgaudecker May 6, 2026
07c5ae0
Merge feature/next-state-deps-in-transitions: inline TDD section in A…
hmgaudecker May 6, 2026
5261e29
Address #339 review: drop field-count test; tighten claudish docstrings
hmgaudecker May 6, 2026
110cc0b
regime_template: collapse H_variables into single variables set
hmgaudecker May 6, 2026
4b2895c
Merge feature/next-state-deps-in-transitions: collapse H_variables
hmgaudecker May 6, 2026
a7b9e9a
solve_brute: drop misleading "~2 MB each" magic number from comment
hmgaudecker May 6, 2026
16b570f
AGENTS.md: docstring style — describe state, no PR refs, bulleted lists
hmgaudecker May 6, 2026
1251f0b
Merge feature/next-state-deps-in-transitions: docstring style + .ai-i…
hmgaudecker May 6, 2026
ff65261
solve_brute: apply docstring style — drop PR ref, magic number; bulle…
hmgaudecker May 6, 2026
2539374
Merge improve/lazy-solve-diagnostics: docstring style, .ai-instructio…
hmgaudecker May 6, 2026
e9b7cc5
test_next_state: update to assert nested output shape from #339 merge
hmgaudecker May 6, 2026
eb69432
Merge main: pick up squash of #342
hmgaudecker May 6, 2026
838473e
validate_initial_conditions: per-constraint admissibility in error me…
hmgaudecker May 6, 2026
07f951a
Merge improve/lazy-solve-diagnostics: pick up main-squash + downstrea…
hmgaudecker May 6, 2026
e4cae2a
_per_constraint_feasibility: filter args per single-constraint feasib…
hmgaudecker May 6, 2026
62392c1
Merge branch 'main' into feat/simulate-aot-n-subjects
hmgaudecker May 6, 2026
50f78f0
Address #340 review: docstring style, TDD, x64+AOT guard, period dtype
hmgaudecker May 6, 2026
7e81532
Package A: int dtype barriers at the API boundary
hmgaudecker May 6, 2026
e881313
Package B: float dtype barriers at the API boundary
hmgaudecker May 6, 2026
63b740c
Fix Package A 32-bit precision tests: build overflow fixtures with numpy
hmgaudecker May 6, 2026
2ee56ca
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 6, 2026
ef180d0
Fix Package B 32-bit precision test: build float overflow fixture wit…
hmgaudecker May 6, 2026
568cbcb
Address #340 review-2: counterfactuals, multi-assertion tests, dedup …
hmgaudecker May 6, 2026
1a42ffb
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 6, 2026
2f19aa1
compile: free lower-args after lowering, free Lowered after compile
hmgaudecker May 7, 2026
75c8d25
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 7, 2026
143a3ae
solve_brute: rename diag_params to effective_regime_params
hmgaudecker May 7, 2026
1947560
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 7, 2026
09f3d03
Address PR #345 review
hmgaudecker May 7, 2026
b8dc490
bench_aca_baseline: pass pref_type_grid to create_benchmark_model
hmgaudecker May 7, 2026
3c3af21
Merge cleanup/aca-bench-no-defaults
hmgaudecker May 7, 2026
cbe65be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2026
a6b0ac1
bench_aca_baseline: hoist aca_model + lcm imports to module top
hmgaudecker May 8, 2026
1e26926
process_params: cast Python int leaves to jnp.int32
hmgaudecker May 8, 2026
381dc25
Merge feat/simulate-aot-n-subjects + reorder reshape-before-cast
hmgaudecker May 8, 2026
9d26643
_validate_param_types: drop dead branches post-whitelist
hmgaudecker May 8, 2026
530d50c
Tighten Scalar* aliases to JAX-only; convert grid endpoints at constr…
hmgaudecker May 8, 2026
f4515ec
Keep coordinate helpers strict; convert at Grid.get_coordinate boundary
hmgaudecker May 8, 2026
9fc2f49
bench_aca_baseline: build on CPU to keep parent process CUDA-free
hmgaudecker May 8, 2026
88a85ae
save_simulate_snapshot: strip AOT-compiled regimes before pickling re…
hmgaudecker May 8, 2026
f2d18fa
bench_aca_baseline: defer aca_model + lcm imports back into _build
hmgaudecker May 8, 2026
fff1537
Tighten internal types: ScalarInt n_points, JAX-only Period/Age, kw-only
hmgaudecker May 8, 2026
aea8735
linspace/logspace: drop int(n_points) cast in favour of ty:ignore
hmgaudecker May 8, 2026
f4069c1
Piecewise n_points: sum the cached _piece_n_points array
hmgaudecker May 8, 2026
bf12b61
benchmarks: bump aca-model pin to 67edfe0f
hmgaudecker May 8, 2026
1bae789
simulate: keep period: int through the loop, cast at the JIT boundary
hmgaudecker May 8, 2026
14f81fc
solve: kick off simulate AOT compile in a background thread
hmgaudecker May 8, 2026
2f486dc
Merge branch 'feat/simulate-aot-n-subjects' into feat/canonical-float…
hmgaudecker May 8, 2026
c419377
benchmarks: bump aca-model pin to d9339ab
hmgaudecker May 8, 2026
61c2436
simulate orchestrates simulate-AOT compile, not solve
hmgaudecker May 8, 2026
1deed36
tests: drop noqa: ARG001 + collapse x64-fixture signatures
hmgaudecker May 9, 2026
00f3b4a
test_next_state: pass JAX scalars instead of ty:ignore-ing Python ones
hmgaudecker May 9, 2026
ca66ba9
simulate: swap AOT-compiled regimes for lazy ones on the result
hmgaudecker May 9, 2026
7547ac3
simulate: AOT-compile blocks before solve to avoid contention
hmgaudecker May 10, 2026
7d7d2a0
pytest: promote 'Explicitly requested dtype' UserWarning to an error
hmgaudecker May 11, 2026
6417124
Merge origin/main into feat/canonical-float-dtype: subsume the squash…
hmgaudecker May 11, 2026
5b49bea
Complete merge of origin/main; pin aca-model to 9ac2043
hmgaudecker May 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions benchmarks/bench_aca_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,30 @@


def _build() -> tuple[object, object, object]:
"""Build the aca-baseline model, params, and initial conditions."""
"""Build the aca-baseline model, params, and initial conditions.

aca_model and lcm imports are deferred to the function body — ASV's
forkserver runs `preimport` to discover benchmarks across every
`bench_*.py` module before forking workers. Importing JAX at module
top loads the multithreaded XLA backend into the forkserver; every
subsequent `os.fork()` inherits a corrupted CUDA context and the
first device op in the worker aborts with
`CUDA_ERROR_NOT_INITIALIZED`. Per-call imports keep JAX out of the
forkserver and confine it to the worker process.
"""
from aca_model.agent.preferences import BenchmarkPrefType
from aca_model.benchmark import (
create_benchmark_model,
get_benchmark_initial_conditions,
get_benchmark_params,
)

model = create_benchmark_model(n_subjects=_N_SUBJECTS)
from lcm import DiscreteGrid

model = create_benchmark_model(
n_subjects=_N_SUBJECTS,
pref_type_grid=DiscreteGrid(BenchmarkPrefType),
)
_, model_params = get_benchmark_params(model=model)
initial_conditions = get_benchmark_initial_conditions(
model=model, n_subjects=_N_SUBJECTS, seed=0
Expand Down
8 changes: 4 additions & 4 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 19 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ tests-cuda13 = { features = [ "tests", "cuda13" ], solve-group = "cuda13" }
tests-metal = { features = [ "tests", "metal" ], solve-group = "metal" }
type-checking = { features = [ "type-checking", "tests" ], solve-group = "default" }
[tool.pixi.feature.benchmarks.pypi-dependencies]
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "f09b5e34102ff42f739b95be5a9d388795b734a1" }
aca-model = { git = "https://github.com/OpenSourceEconomics/aca-model.git", rev = "9ac20430f499a8b1cdb056af85bc2a26e850bad2" }
[tool.pixi.feature.cuda12]
platforms = [ "linux-64" ]
system-requirements = { cuda = "12" }
Expand Down Expand Up @@ -242,6 +242,15 @@ per-file-ignores."tests/*" = [
"S301", # Use of pickle
"SLF001", # Private member access
]
per-file-ignores."tests/test_dtypes.py" = [
"ARG001", # Unused function argument (x64_enabled / x64_disabled fixtures)
]
per-file-ignores."tests/test_explicit_dtype_filter.py" = [
"ARG001", # Unused function argument (x64_disabled fixture)
]
per-file-ignores."tests/test_float_dtype_invariants.py" = [
"ARG001", # Unused function argument (x64_disabled fixture)
]
per-file-ignores."tests/test_next_state.py" = [
"ARG001", # Unused function argument
"ARG005", # Unused lambda argument
Expand Down Expand Up @@ -294,7 +303,15 @@ ini_options.addopts = [
"--dist",
"loadfile",
]
ini_options.filterwarnings = []
ini_options.filterwarnings = [
# JAX emits this UserWarning when user code asks for a dtype wider
# than the active x64 setting allows. Under `--precision=32` it
# surfaces every stray `jnp.int64` / `jnp.float64` / `dtype=int64`
# literal in src/ — the only files that legitimately trigger it are
# the dtype-invariant test modules, which opt out via a local
# `pytestmark` filter.
"error:Explicitly requested dtype.*:UserWarning",
]
ini_options.markers = [
"illustrative: Tests are designed for illustrative purposes",
"gpu: Tests that require a GPU (skipped on CPU-only machines)",
Expand Down
14 changes: 8 additions & 6 deletions src/lcm/ages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax.numpy as jnp

from lcm.exceptions import GridInitializationError, format_messages
from lcm.typing import Age, Float1D, Int1D
from lcm.typing import Float1D, Int1D

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason not to use the Age Type anymore?


STEP_UNITS: MappingProxyType[str, Fraction] = MappingProxyType(
{
Expand Down Expand Up @@ -129,7 +129,7 @@ def exact_step_size(self) -> int | Fraction | None:
"""
return self._exact_step_size

def period_to_age(self, period: int) -> Age:
def period_to_age(self, period: int) -> int | float:
"""Convert a period index to the corresponding age.

Args:
Expand All @@ -151,7 +151,7 @@ def period_to_age(self, period: int) -> Age:
return int(self._values[period])
return float(self._values[period])

def age_to_period(self, age: Age) -> int:
def age_to_period(self, age: float) -> int:
"""Convert an age to the corresponding period index.

Args:
Expand All @@ -172,12 +172,14 @@ def age_to_period(self, age: Age) -> int:
raise ValueError(msg) from None

@functools.cached_property
def _age_to_period_map(self) -> dict[Age, int]:
def _age_to_period_map(self) -> dict[int | float, int]:
if self._is_integer:
return {int(v): i for i, v in enumerate(self._exact_values)}
return {float(v): i for i, v in enumerate(self._exact_values)}

def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]:
def get_periods_where(
self, predicate: Callable[[int | float], bool]
) -> tuple[int, ...]:
"""Get period indices where predicate is True.

Args:
Expand All @@ -187,7 +189,7 @@ def get_periods_where(self, predicate: Callable[[Age], bool]) -> tuple[int, ...]
Tuple of period indices where predicate(age) is True.

"""
_convert: Callable[[object], Age] = int if self._is_integer else float # ty: ignore[invalid-assignment]
_convert: Callable[[object], int | float] = int if self._is_integer else float # ty: ignore[invalid-assignment]
return tuple(
period
for period in range(self.n_periods)
Expand Down
60 changes: 54 additions & 6 deletions src/lcm/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,32 @@
Used at every API boundary that accepts user data (params, initial
conditions, regime-id arrays) — always called from Python, never inside
JIT. Each helper validates that the value fits the target dtype and
raises a clearly-named error if not.

Casts further down the simulate stack (e.g. transition outputs landing
in the state pool) use plain `.astype` and rely on the boundary cast
above them having already pinned the canonical dtype.
raises a clearly-named error if not. Once an input has crossed the
boundary it carries the canonical dtype unchanged through the simulate
stack; downstream code does not re-cast.
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax import Array

_INT32_MIN = int(np.iinfo(np.int32).min)
_INT32_MAX = int(np.iinfo(np.int32).max)
_FLOAT32_MAX = float(np.finfo(np.float32).max)


def canonical_float_dtype() -> jnp.dtype:
"""Return pylcm's canonical float dtype, derived from `jax_enable_x64`.

Returns `jnp.float64` if `jax.config.jax_enable_x64` is True,
otherwise `jnp.float32`. The value is read at call time, not at
import, so toggling the JAX config (e.g. between tests) is honoured.
"""
return jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32


def safe_to_int32(value: object, *, name: str) -> Array:
def safe_to_int_dtype(value: object, *, name: str) -> Array:
"""Cast a scalar, sequence, or array to `jnp.int32`, checking int32 range.

Args:
Expand Down Expand Up @@ -46,3 +56,41 @@ def safe_to_int32(value: object, *, name: str) -> Array:
)
raise ValueError(msg)
return jnp.asarray(np_value, dtype=jnp.int32)


def safe_to_float_dtype(value: object, *, name: str) -> Array:
"""Cast a scalar, sequence, or array to the canonical float dtype.

Range check fires only on a down-cast:

- Down-cast (float64 → float32 under `jax_enable_x64=False`): raise
`OverflowError` if any element exceeds float32 magnitude rather
than letting JAX silently saturate to ``±inf``.
- Up-cast or same-width cast: skip the range check. Precision loss
within range is not an error — it is an inherent consequence of
`jax_enable_x64=False`.

Args:
value: A Python float, numpy/JAX scalar, or array-like.
name: Qualified name of the leaf — surfaced in the error message.

Returns:
A JAX array at `canonical_float_dtype()` (0-d if `value` was a
scalar).

Raises:
OverflowError: If down-casting to `float32` would saturate any
element to `±inf`. The message names the leaf via `name`.

"""
target_dtype = canonical_float_dtype()
np_value = np.asarray(value)
if target_dtype == jnp.float32 and np_value.size > 0:
max_mag = float(np.max(np.abs(np_value)))
if max_mag > _FLOAT32_MAX:
msg = (
f"{name}: float32 overflow — max |value| {max_mag:g} "
f"exceeds float32 max {_FLOAT32_MAX:g}."
)
raise OverflowError(msg)
return jnp.asarray(np_value, dtype=target_dtype)
Loading
Loading