Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
9f3937b
First changes
mj023 May 4, 2026
2cab0af
Merge branch 'main' into distributed
mj023 May 4, 2026
e3bd7e4
Add second distribution pattern
mj023 May 8, 2026
1b2baa3
Add parallelization across multiple devices during solve (#346)
hmgaudecker May 11, 2026
c1aa68d
Merge branch 'distributed' of https://github.com/OpenSourceEconomics/…
mj023 May 11, 2026
7eeaf37
Fix AOT + Add Simulation
mj023 May 11, 2026
8389fc3
Merge branch 'main' into distributed
hmgaudecker May 12, 2026
da3b3f0
Fix tests
mj023 May 12, 2026
1f3a975
Merge branch 'distributed' of https://github.com/OpenSourceEconomics/…
mj023 May 12, 2026
33311ea
Fix Typing and Tests
mj023 May 12, 2026
5aa62dc
Add error to simulate
mj023 May 12, 2026
dc6cf79
Phase 1: flip simulation state carrier to nested StatesPerRegime
hmgaudecker May 13, 2026
9c0a7f6
Phase 2: rewrite qname introspection + structured stochastic factories
hmgaudecker May 13, 2026
9b7f8a2
Refactor: simplify state-carrier names, tighten StateName annotations
hmgaudecker May 13, 2026
c786e55
Merge remote-tracking branch 'origin/feat/states-per-regime' into fea…
hmgaudecker May 13, 2026
914d9d6
Review feedback: prefer f-strings for encode, clarify state-name naming
hmgaudecker May 13, 2026
a5d932a
Replace variable_info DataFrame with typed mapping
hmgaudecker May 13, 2026
77d2131
Merge branch 'main' into feat/nested-dag-namespaces
hmgaudecker May 13, 2026
1b587ca
Merge remote-tracking branch 'origin/feat/nested-dag-namespaces' into…
hmgaudecker May 13, 2026
784e90c
Merge branch 'main' into distributed
mj023 May 13, 2026
1ad392a
Promote variable_info to a Variables container with named views
hmgaudecker May 13, 2026
d503dae
V.py: use Variables.{discrete,continuous}_state_names directly
hmgaudecker May 13, 2026
607dba5
Drop intermediate variables and one-call helpers
hmgaudecker May 13, 2026
d1e8dae
Merge remote-tracking branch 'origin/main' into feat/variable-info-ty…
hmgaudecker May 13, 2026
80d6095
Add beartype dep + per-exception BeartypeConf helpers
hmgaudecker May 13, 2026
3e0d621
Decorate Regime + MarkovTransition with @beartype
hmgaudecker May 13, 2026
4799dff
Decorate Model with @beartype; trim type-check block of validate_mode…
hmgaudecker May 13, 2026
24029ec
Decorate grids + shocks + categorical with @beartype_init
hmgaudecker May 13, 2026
b0f9fb8
Decorate Model.solve + Model.simulate with @beartype
hmgaudecker May 13, 2026
ca69bbd
Dont run dist code if no grid distributed
mj023 May 13, 2026
04294e8
Merge branch 'main' into distributed
mj023 May 13, 2026
67e64d3
Remove distributed states from state space
mj023 May 13, 2026
2034263
Merge branch 'distributed' of https://github.com/OpenSourceEconomics/…
mj023 May 13, 2026
45ea93a
Annotation drift sweep — prep for whole-package beartype claw
hmgaudecker May 13, 2026
51a7b4b
Enable scoped beartype claw on lcm.grids/shocks/params
hmgaudecker May 13, 2026
c6d5078
Fix sharding correctness in solve and forbid distributed actions
hmgaudecker May 13, 2026
0e46529
Hoist non-parameter field set on shock grids
hmgaudecker May 13, 2026
984fac0
Fix DiscreteGrid.distributed and abstract docstring
hmgaudecker May 13, 2026
e107250
Add KeyArray alias for shock-sampler PRNGKey parameters
hmgaudecker May 13, 2026
80fc815
Merge branch 'distributed' into feat/variable-info-typed-mapping
hmgaudecker May 13, 2026
b96f4ce
Merge branch 'feat/variable-info-typed-mapping' into feat/beartype-pe…
hmgaudecker May 13, 2026
cbeaeb2
Merge branch 'feat/beartype-perimeter' into feat/beartype-claw-cleanup
hmgaudecker May 13, 2026
6e34efb
Rename v_array → V_arr for consistency across the solve path
hmgaudecker May 13, 2026
e098d53
Merge remote-tracking branch 'origin/distributed' into feat/variable-…
hmgaudecker May 13, 2026
8892728
Merge branch 'feat/variable-info-typed-mapping' into feat/beartype-pe…
hmgaudecker May 13, 2026
e0a1694
Merge branch 'feat/beartype-perimeter' into feat/beartype-claw-cleanup
hmgaudecker May 13, 2026
59aca3f
Activate beartype claw at lcm.__init__ + fix surfaced drift
hmgaudecker May 13, 2026
fda74bf
Tighten pylcm-internal types: drop float / Mapping widenings; cast at…
hmgaudecker May 13, 2026
f210a3c
Activate beartype claw on lcm.regime_building
hmgaudecker May 13, 2026
1a92dff
Merge remote-tracking branch 'origin/feat/beartype-claw-cleanup' into…
hmgaudecker May 14, 2026
37eb6e6
Drop _dags_forwarders shim; import dags wrappers directly
hmgaudecker May 14, 2026
17890f5
Narrow type hints repo-wide: drop redundant unions, scalar/array over…
hmgaudecker May 14, 2026
66c83d1
Merge branch 'feat/beartype-claw-cleanup' into feat/beartype-claw-extend
hmgaudecker May 14, 2026
43c0cd0
Tighten create_params_template input to MappingProxyType
hmgaudecker May 14, 2026
8d672dc
Eliminate remaining bare Array annotations across pylcm
hmgaudecker May 14, 2026
e9deac3
Make GPU-peak-mem subprocess parsing robust to stdout noise
hmgaudecker May 14, 2026
af44b27
Keep IntND at user boundaries; add RealND for ndimage/argmax primitives
hmgaudecker May 14, 2026
976ddd2
Drop RealND; ndimage/argmax use FloatND/IntND, tests pin int32
hmgaudecker May 14, 2026
342bce6
Merge branch 'feat/beartype-claw-cleanup' into feat/beartype-claw-extend
hmgaudecker May 14, 2026
bee52c8
Make jaxtyping "..." sentinel survive pickling; fix int64 ndimage uni…
hmgaudecker May 14, 2026
ce1aba1
Rename KeyArray to PRNGKeyND; fix mis-typed key params and regime-pro…
hmgaudecker May 14, 2026
e0bb357
Merge remote-tracking branch 'origin/feat/beartype-claw-extend' into …
hmgaudecker May 14, 2026
2ecd905
Merge branch 'feat/beartype-claw-cleanup' into feat/beartype-claw-extend
hmgaudecker May 14, 2026
adb99e4
Type RegimeTransitionFunction by its real return; PRNGKeyND via jaxty…
hmgaudecker May 14, 2026
da3d7de
Merge branch 'feat/beartype-claw-cleanup' into feat/beartype-claw-extend
hmgaudecker May 14, 2026
bfc7c35
Drop redundant @beartype_init decorators; the claw already covers gri…
hmgaudecker May 14, 2026
273ef76
Merge branch 'feat/beartype-claw-cleanup' into feat/beartype-claw-extend
hmgaudecker May 14, 2026
a59b3d6
Widen Model.simulate initial_conditions hint to accept pd.DataFrame.
hmgaudecker May 14, 2026
3371206
Activate beartype claw on lcm.solution and lcm.simulation
hmgaudecker May 14, 2026
3f8dd99
Pin __annotations__ on regime-transition-prob wrappers
hmgaudecker May 14, 2026
b6b9c48
Merge branch 'main' into feat/variable-info-typed-mapping
hmgaudecker May 14, 2026
ab9ae26
Merge branch 'feat/variable-info-typed-mapping' into feat/beartype-pe…
hmgaudecker May 14, 2026
249d0f4
Pin period arrays to int32 in initial-conditions validation
hmgaudecker May 14, 2026
fc704af
Bump benchmarks aca-model pin to feat/distributed-assets-grid tip
hmgaudecker May 14, 2026
3a0cf93
Merge remote-tracking branch 'origin/main' into feat/beartype-perimeter
hmgaudecker May 14, 2026
a38c1c9
Re-lock pixi.lock after aca-model pin bump and main merge
hmgaudecker May 14, 2026
8014d81
Pin period to int32 in regime-transition validation
hmgaudecker May 14, 2026
601cb47
Boilerplate: bump pixi-version to v0.68.1, update .ai-instructions
hmgaudecker May 14, 2026
01c84cd
Bump benchmarks aca-model pin to the average_consumption_equiv fix
hmgaudecker May 14, 2026
326b760
Claw lcm.utils.error_handling behind the construction perimeter
hmgaudecker May 14, 2026
2bef7ac
Claw lcm.state_action_space behind the construction perimeter
hmgaudecker May 14, 2026
7e7bd18
Claw lcm.interfaces behind the construction perimeter
hmgaudecker May 14, 2026
7b9183b
Claw lcm.regime behind the construction perimeter
hmgaudecker May 14, 2026
67ea957
Claw lcm.model behind the construction perimeter
hmgaudecker May 14, 2026
72cdfea
Fix discount_factor float-typed everywhere; drop NumericND
hmgaudecker May 15, 2026
4387640
Canonicalize shock-grid params at the boundary
hmgaudecker May 15, 2026
9b8124c
Split params leaves into User (boundary) and canonical variants
hmgaudecker May 15, 2026
fd32bd6
Sweep params signatures: tighten `FlatRegimeParams`, widen call proto…
hmgaudecker May 15, 2026
67dec68
Retype initial_conditions/initial_states with semantic keys
hmgaudecker May 15, 2026
2bd1664
Use StateOrActionName / ActionName aliases in state-and-action-keyed …
hmgaudecker May 15, 2026
a1e0a09
Add explicit @beartype to grid and shock public constructors
hmgaudecker May 15, 2026
28350e9
Collapse the perimeter-claw fan-out into a single lcm-package claw
hmgaudecker May 15, 2026
d6d4454
Refresh beartype claw docstrings for the unified package claw
hmgaudecker May 15, 2026
9bb6c3c
Rename "regime" to "regime_id" / "regime_name" by context
hmgaudecker May 15, 2026
97f4224
DataFrame initial-conditions column is `regime_name`, not `regime_id`
hmgaudecker May 15, 2026
b46f865
Drop redundant FloatND | IntND rationale from action/state field docs…
hmgaudecker May 15, 2026
1376cc2
Tighten initial_conditions_from_dataframe return type
hmgaudecker May 15, 2026
52b6a0a
Add UserInitialConditions / InitialConditions type aliases
hmgaudecker May 15, 2026
e54b802
Post-review cleanup sweep
hmgaudecker May 15, 2026
415a85c
Tighten InitialConditions / RegimeStates to Float1D | Int1D
hmgaudecker May 15, 2026
748caac
Re-lock for the benchmarks aca-model rev bump (bce9101)
hmgaudecker May 15, 2026
36a531c
Loosen _validate_irreg_spaced_grid element type to Any
hmgaudecker May 15, 2026
39c0f66
Use dags main.
hmgaudecker May 18, 2026
3070466
Clear ty diagnostics surfaced by the relock
hmgaudecker May 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ai-instructions
Submodule .ai-instructions updated 1 files
+56 −39 AGENTS.md
10 changes: 5 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.67.2
pixi-version: v0.68.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cpu
Expand Down Expand Up @@ -59,7 +59,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.67.2
pixi-version: v0.68.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: type-checking
Expand All @@ -82,7 +82,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.67.2
pixi-version: v0.68.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cuda12
Expand All @@ -101,7 +101,7 @@ jobs:
- uses: actions/checkout@v6
- uses: prefix-dev/setup-pixi@v0.9.5
with:
pixi-version: v0.67.2
pixi-version: v0.68.1
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: tests-cuda12
Expand All @@ -116,7 +116,7 @@ jobs:
# - uses: actions/checkout@v6
# - uses: prefix-dev/setup-pixi@v0.9.5
# with:
# pixi-version: v0.67.2
# pixi-version: v0.68.1
# cache: true
# cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
# environments: tests-cpu
Expand Down
19 changes: 12 additions & 7 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ loaded = SimulationResult.from_pickle("path/to/file.pkl")

### Initial Conditions Format

Initial conditions use a flat dictionary with state names plus `"regime"`:
Initial conditions use a flat dictionary with state names plus `"regime_id"`:

```python
initial_conditions = {
"wealth": jnp.array([1.0, 2.0, 3.0]),
"health": jnp.array([0.5, 0.8, 0.3]),
"regime": jnp.array([RegimeId.working, RegimeId.working, RegimeId.retired]),
"regime_id": jnp.array([RegimeId.working, RegimeId.working, RegimeId.retired]),
}
```

Expand Down Expand Up @@ -527,11 +527,16 @@ Code structure should be self-evident from function names and ordering.
- **Helper function names follow `{verb}_{qualifier}_noun` patterns.** E.g.,
`get_irreg_coordinate`, `find_irreg_coordinate`, `get_linspace_coordinate` — not
`get_coordinate_irreg`.
- **Use `@overload` when a function accepts both scalar and array inputs.** When a
function works with both `ScalarFloat` and `Array`, add overload declarations so the
type checker can track `(ScalarFloat) -> ScalarFloat` and `(Array) -> Array`
separately. Concrete subclass methods need their own overloads too (not just the
abstract base).
- **Pick the single narrowest jaxtyping alias — never scalar/array `@overload` pairs,
never `ScalarX | XND` unions.** ty erases jaxtyping shape annotations: `ScalarFloat`,
`Float1D`, and `FloatND` all reveal as `Array`, so scalar/array `@overload` pairs and
`ScalarFloat | FloatND`-style unions add zero static precision — they are pure noise.
At runtime, beartype treats a 0-d float array as satisfying both `ScalarFloat` and
`FloatND`, so `ScalarFloat ⊆ FloatND` and the union is redundant. Annotate each slot
with the one alias that matches its genuine rank: `ScalarFloat`/`ScalarInt` for
fixed-0-d, `Float1D`/`Int1D` for fixed-1-d, `FloatND`/`IntND` for genuinely rank-
polymorphic. Never use a bare `Array` annotation — always reach for the narrowest
`lcm.typing` alias.
- **`func` for callable abbreviations** — use `func`, `func_name`, `func_params` (never
`fn`). Full word `function(s)` in dataclass field names and public method names.
- **Singular `state_names` / `action_names`** — not `states_names` / `actions_names`.
Expand Down
17 changes: 15 additions & 2 deletions benchmarks/_gpu_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class MahlerYumGpuPeakMem(GpuPeakMem):
# Project root: the directory containing the benchmarks/ package.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Marks the peak-memory line on the subprocess's stdout. The subprocess imports
# lcm, whose beartype claw can emit diagnostics to stdout, so the parent locates
# this line instead of parsing stdout wholesale.
_PEAK_MARKER = "__PEAK_BYTES_IN_USE__"


def measure_gpu_peak(bench_module: str, bench_class: str) -> int:
"""Run a benchmark in a subprocess and return peak GPU bytes.
Expand Down Expand Up @@ -58,7 +63,15 @@ def measure_gpu_peak(bench_module: str, bench_class: str) -> int:
f"stderr: {result.stderr!r}"
)
raise RuntimeError(msg)
return int(result.stdout.strip())
for line in result.stdout.splitlines():
if line.startswith(_PEAK_MARKER):
return int(line.removeprefix(_PEAK_MARKER).strip())
msg = (
"GPU memory subprocess produced no peak-bytes line.\n"
f"stdout: {result.stdout!r}\n"
f"stderr: {result.stderr!r}"
)
raise RuntimeError(msg)


def _track_gpu_peak_mem(self):
Expand Down Expand Up @@ -104,4 +117,4 @@ def setup(self):
import jax

stats = jax.local_devices()[0].memory_stats()
print(stats["peak_bytes_in_use"])
print(f"{_PEAK_MARKER} {stats['peak_bytes_in_use']}")
2 changes: 1 addition & 1 deletion benchmarks/bench_mahler_yum.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _build(self):
self.model_params = {"alive": common_params}
self.initial_conditions = {
**initial_states,
"regime": jnp.full(
"regime_id": jnp.full(
_N_SUBJECTS,
self.model.regime_names_to_ids["alive"],
dtype=jnp.int32,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/bench_precautionary_savings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _make_initial_conditions(n_subjects):
"age": jnp.full(n_subjects, 20.0),
"wealth": jnp.full(n_subjects, 5.0),
"income": jnp.full(n_subjects, 0.0),
"regime": jnp.zeros(n_subjects, dtype=jnp.int32),
"regime_id": jnp.zeros(n_subjects, dtype=jnp.int32),
}


Expand Down
2 changes: 1 addition & 1 deletion docs/development/benchmarking.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class TimeMyModel:
self.model_params = my_model.get_params()
self.initial_conditions = {
"wealth": jnp.full(1_000, 5.0),
"regime": jnp.zeros(1_000, dtype=jnp.int32),
"regime_id": jnp.zeros(1_000, dtype=jnp.int32),
}

# JIT warmup (timed separately)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/mahler_yum_2024.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ result = MAHLER_YUM_MODEL.simulate(
params={"alive": params},
initial_conditions={
**initial_states,
"regime": jnp.full(
"regime_id": jnp.full(
n_subjects,
MAHLER_YUM_MODEL.regime_names_to_ids["alive"],
),
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/mortality.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ result = model.simulate(
initial_conditions={
"age": jnp.full(100, model.ages.values[0]),
"wealth": jnp.linspace(1, 100, 100),
"regime": jnp.full(100, model.regime_names_to_ids["working_life"]),
"regime_id": jnp.full(100, model.regime_names_to_ids["working_life"]),
},
period_to_regime_to_V_arr=None,
seed=1234,
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/precautionary_savings.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ result = model.simulate(
"age": jnp.full(100, model.ages.values[0]),
"wealth": jnp.linspace(1, 10, 100),
"income": jnp.zeros(100),
"regime": jnp.full(100, model.regime_names_to_ids["alive"]),
"regime_id": jnp.full(100, model.regime_names_to_ids["alive"]),
},
period_to_regime_to_V_arr=None,
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/precautionary_savings_health.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ result = model.simulate(
"age": jnp.full(1_000, model.ages.values[0]),
"wealth": jnp.full(1_000, 1.0),
"health": jnp.full(1_000, 1.0),
"regime": jnp.full(1_000, model.regime_names_to_ids["working_life"]),
"regime_id": jnp.full(1_000, model.regime_names_to_ids["working_life"]),
},
period_to_regime_to_V_arr=None,
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tiny.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ params = get_params()

initial_df = pd.DataFrame(
{
"regime": "working_life",
"regime_name": "working_life",
"age": model.ages.values[0],
"wealth": np.linspace(1, 20, 100),
}
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/benchmarking.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class TimeSolveSimulate:
self.initial_conditions = {
"age": jnp.full(500, 25.0),
"wealth": jnp.full(500, 5.0),
"regime": jnp.zeros(500, dtype=jnp.int32),
"regime_id": jnp.zeros(500, dtype=jnp.int32),
}

# --- JAX warmup --------------------------------------------------
Expand Down
6 changes: 3 additions & 3 deletions docs/user_guide/pandas_interop.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ is DataFrame in, DataFrame out.
## Initial Conditions as a DataFrame

Pass a pandas DataFrame directly to `simulate()` as `initial_conditions`. One row per
agent, one column per state variable, plus a `"regime"` column:
agent, one column per state variable, plus a `"regime_name"` column:

```python
df = pd.DataFrame(
{
"regime": ["working", "working", "retired"],
"regime_name": ["working", "working", "retired"],
"wealth": [10.0, 50.0, 30.0],
"health": ["good", "bad", "good"],
"age": [25.0, 25.0, 25.0],
Expand All @@ -31,7 +31,7 @@ result = model.simulate(
)
```

- `"regime"` column is required. Use regime names as strings (e.g., `"working"`).
- `"regime_name"` column is required. Use regime names as strings (e.g., `"working"`).
- Discrete states use string labels from the model's categorical classes (e.g., `"good"`
instead of `0`). Labels are validated and mapped to integer codes automatically.
- Continuous states pass through as-is.
Expand Down
12 changes: 6 additions & 6 deletions docs/user_guide/solving_and_simulating.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ import pandas as pd

df = pd.DataFrame(
{
"regime": ["working_life", "working_life", "retirement", "working_life"],
"regime_name": ["working_life", "working_life", "retirement", "working_life"],
"age": [25.0, 25.0, 25.0, 25.0],
"wealth": [1.0, 5.0, 10.0, 20.0],
"health": ["good", "bad", "bad", "good"], # string labels, auto-converted
Expand Down Expand Up @@ -102,7 +102,7 @@ initial_conditions = {
"age": jnp.array([25.0, 25.0, 25.0, 25.0]),
"wealth": jnp.array([1.0, 5.0, 10.0, 20.0]),
"health": jnp.array([0, 1, 1, 0]), # integer codes for discrete states
"regime": jnp.array(
"regime_id": jnp.array(
[
RegimeId.working_life,
RegimeId.working_life,
Expand All @@ -114,7 +114,7 @@ initial_conditions = {
```

- Every non-shock state must have an entry.
- `"regime"` must be included, with integer codes from the `regime_id_class`.
- `"regime_id"` must be included, with integer codes from the `regime_id_class`.
- All arrays must have the same length (= number of agents).
- Shock states are drawn automatically.

Expand All @@ -140,7 +140,7 @@ Subjects can start at different ages:
initial_conditions = {
"age": jnp.array([40.0, 60.0]),
"wealth": jnp.array([50.0, 50.0]),
"regime": jnp.array(
"regime_id": jnp.array(
[
model.regime_names_to_ids["working_life"],
model.regime_names_to_ids["working_life"],
Expand All @@ -160,7 +160,7 @@ earlier periods are omitted, not filled with placeholders.
df = result.to_dataframe()
```

Returns a pandas DataFrame with columns: `subject_id`, `period`, `age`, `regime`,
Returns a pandas DataFrame with columns: `subject_id`, `period`, `age`, `regime_name`,
`value`, plus all states and actions. Discrete variables are pandas Categorical with
string labels.

Expand Down Expand Up @@ -240,7 +240,7 @@ params = {

# 3. Prepare initial conditions as a DataFrame
initial_df = pd.DataFrame({
"regime": "working_life",
"regime_name": "working_life",
"age": model.ages.values[0],
"wealth": np.linspace(1, 50, 100),
})
Expand Down
Loading
Loading