Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 124 additions & 2 deletions pixi.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ classifiers = [
]
dynamic = [ "version" ]
dependencies = [
"beartype>=0.22",
"dags>=0.5.1",
"jax>=0.9",
"jaxopt>=0.8.5",
Expand Down
33 changes: 21 additions & 12 deletions src/skillmodels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
"""Skillmodels: A Python package for estimating latent factor models."""

# Enable 64-bit JAX before any skillmodels submodule -- and crucially before
# any transitive `import jaxopt` -- so jaxopt's module-level jit/sort
# kernels see int64 as the default integer type. Without this, jaxopt's
# `argsort` inside `LBFGSB.update` emits an `s32` accumulator into an
# `s64` scatter operand and XLA's permutation_sort_simplifier verifier
# rejects it on JAX >= 0.10 / cuda13. The package has always assumed
# x64 (every CHS / AF / AMN entry point sets it inside the function);
# centralising it at import time fixes the jaxopt path too and is a
# no-op for callers who already enable it.
# Enable 64-bit JAX before any skillmodels submodule. Every CHS / AF / AMN
# entry point already sets this inside its function body; centralising it
# here makes the package behave consistently for direct callers.
import os

os.environ.setdefault("JAX_ENABLE_X64", "1")

import contextlib

import jax
# Workaround for a JAX 0.10 XLA bug surfaced by jaxopt's `LBFGSB.update`.
# The `permutation_sort_simplifier` HLO pass mis-lowers the `argsort`
# inside `update`: it emits an s32 reduction accumulator into the s64
# scatter operand built by the rest of the optimizer, and the HLO
# verifier rejects the resulting mismatch with `INVALID_ARGUMENT:
# Reduction function's accumulator shape at index 0 differs from the
# init_value shape: s32[] vs s64[]`. Disabling just that one pass via
# `XLA_FLAGS` keeps every other XLA optimisation intact and is a no-op
# on JAX < 0.10 (pre-0.10 lacks the pass). Must be set *before* `import
# jax` because XLA reads `XLA_FLAGS` once at backend init.
_xla_pass_disable = "--xla_disable_hlo_passes=permutation_sort_simplifier" # noqa: S105
_existing_xla_flags = os.environ.get("XLA_FLAGS", "")
if _xla_pass_disable not in _existing_xla_flags:
os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_xla_pass_disable}".strip()

import contextlib # noqa: E402

import jax # noqa: E402

jax.config.update("jax_enable_x64", True) # noqa: FBT003

Expand Down
87 changes: 87 additions & 0 deletions src/skillmodels/_beartype_conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Per-exception `BeartypeConf` instances used at the skillmodels perimeter.

Decorators at user-facing entry points configure beartype to raise the
existing project exception class on parameter-type violations,
preserving the documented exception hierarchy in
`skillmodels.exceptions`.

The constructors and call sites decorated through this module are the
"perimeter": ModelSpec / FactorSpec / AnchoringSpec / Normalizations,
the three estimation-options dataclasses, and every public function
exposed from the top-level package or the subpackage `__init__`s. The
internal helpers below the perimeter are unannotated for beartype and
trust the perimeter to have already validated parameter types.
"""

from collections.abc import Callable

from beartype import BeartypeConf, BeartypeStrategy, beartype

from skillmodels.exceptions import (
DiagnosticsCallError,
EstimationCallError,
InferenceCallError,
ModelSpecInitializationError,
OptionsInitializationError,
SimulationCallError,
)


def _conf(exc: type[Exception]) -> BeartypeConf:
    """Return the `BeartypeConf` shared by one family of perimeter decorators.

    Parameter-type violations detected under the returned configuration
    are raised as `exc`, so callers keep seeing the project's own
    exception classes.

    Two further settings are the same for every configuration built here:

    - `strategy=BeartypeStrategy.On` validates containers in full (O(n)),
      so every bad entry of a mapping/sequence is reported instead of a
      single sampled element. The decorated entry points are called
      rarely (construction, estimate, simulate, plot), so the per-call
      cost is invisible next to the JIT-compiled hot path each one
      kicks off.
    - `is_pep484_tower=True` honours the PEP-484 numeric tower, i.e. an
      `int` satisfies a `float`-typed parameter (the same implicit
      numeric conversion that Python and ruff's PYI041 assume).
    """
    perimeter_conf = BeartypeConf(
        strategy=BeartypeStrategy.On,
        is_pep484_tower=True,
        violation_param_type=exc,
    )
    return perimeter_conf


def beartype_init(conf: BeartypeConf) -> Callable[[type], type]:
    """Build a class decorator that type-checks `__init__` and nothing else.

    Applying bare `@beartype` to a class wraps every method. That would
    surface annotation drift on non-public instance methods (e.g. a
    helper that takes a JAX array loosely typed as `Any`), which is
    unrelated to validating parameters at construction time. Only the
    public-facing `__init__` parameters are curated at the perimeter, so
    the returned decorator wraps `cls.__init__` with
    `@beartype(conf=conf)`, assigns it back onto the class, and returns
    the same class object.
    """

    def _decorate(cls: type) -> type:
        checked_init = beartype(conf=conf)(cls.__init__)
        cls.__init__ = checked_init  # ty: ignore[invalid-assignment]
        return cls

    return _decorate


# One `BeartypeConf` per project exception class, each built by `_conf`.
# Under a given configuration, parameter-type violations raise the
# exception class named in the `_conf(...)` call. The comment above each
# constant lists the entry points it is meant for.

# Construction of the four user-facing model-spec dataclasses
# (ModelSpec, FactorSpec, AnchoringSpec, Normalizations).
MODEL_SPEC_CONF = _conf(ModelSpecInitializationError)

# Construction of CHSEstimationOptions, AFEstimationOptions,
# AMNEstimationOptions.
OPTIONS_CONF = _conf(OptionsInitializationError)

# `get_maximization_inputs`, `get_filtered_states`, `estimate_af`,
# `estimate_amn`, `get_af_posterior_states`,
# `get_amn_posterior_states`.
ESTIMATION_CONF = _conf(EstimationCallError)

# `compute_af_standard_errors`, `compute_amn_standard_errors`.
INFERENCE_CONF = _conf(InferenceCallError)

# `simulate_dataset`, `simulate_policy_effect`.
SIMULATION_CONF = _conf(SimulationCallError)

# Diagnostics + visualisation entry points.
DIAGNOSTICS_CONF = _conf(DiagnosticsCallError)
3 changes: 3 additions & 0 deletions src/skillmodels/af/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import numpy as np
import optimagic as om
import pandas as pd
from beartype import beartype
from jax import Array

from skillmodels._beartype_conf import ESTIMATION_CONF
from skillmodels.af.initial_period import estimate_initial_period
from skillmodels.af.params import get_measurements_per_factor
from skillmodels.af.transition_period import estimate_transition_period
Expand All @@ -29,6 +31,7 @@
from skillmodels.common.process_model import process_model


@beartype(conf=ESTIMATION_CONF)
def estimate_af( # noqa: PLR0915
model_spec: ModelSpec,
data: pd.DataFrame,
Expand Down
3 changes: 3 additions & 0 deletions src/skillmodels/af/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
from beartype import beartype
from jax import Array

from skillmodels._beartype_conf import INFERENCE_CONF
from skillmodels.af.batching import auto_n_obs_per_batch
from skillmodels.af.estimate import _extract_period_data
from skillmodels.af.halton import create_halton_nodes_and_weights
Expand Down Expand Up @@ -119,6 +121,7 @@ class AFInferenceResult:
"""Number of bootstrap replicates drawn."""


@beartype(conf=INFERENCE_CONF)
def compute_af_standard_errors(
result: AFEstimationResult,
data: pd.DataFrame,
Expand Down
82 changes: 65 additions & 17 deletions src/skillmodels/af/jaxopt_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,25 @@

import os

# Ensure x64 is on *before* `from jaxopt import LBFGSB` -- jaxopt's
# module-level jit kernels resolve the default integer dtype at import
# time. With x64 off, `jnp.argsort` inside `LBFGSB.update` emits int32
# indices that scatter into the int64 operand the rest of the optimizer
# builds, and XLA's permutation_sort_simplifier verifier rejects the
# resulting mismatch on JAX >= 0.10. `skillmodels/__init__.py` sets the
# same flag at package import; this is a belt-and-suspenders guard for
# callers that import this module directly.
# Belt-and-suspenders for callers that import this module directly without
# going through `skillmodels/__init__.py`. Two things must be set before
# `import jax` / `from jaxopt import LBFGSB`:
#
# 1. `JAX_ENABLE_X64=1` — the AF pipeline assumes float64 throughout.
# 2. `XLA_FLAGS=--xla_disable_hlo_passes=permutation_sort_simplifier` —
# works around a JAX 0.10 bug where the `permutation_sort_simplifier`
# HLO pass mis-lowers the `argsort` inside `LBFGSB.update`: it emits
# an s32 reduction accumulator into an s64 scatter operand, and the
# HLO verifier rejects the resulting mismatch. See
# `skillmodels/__init__.py` for the full explanation.
os.environ.setdefault("JAX_ENABLE_X64", "1")

import jax
_xla_pass_disable = "--xla_disable_hlo_passes=permutation_sort_simplifier" # noqa: S105
_existing_xla_flags = os.environ.get("XLA_FLAGS", "")
if _xla_pass_disable not in _existing_xla_flags:
os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_xla_pass_disable}".strip()

import jax # noqa: E402

jax.config.update("jax_enable_x64", True) # noqa: FBT003

Expand Down Expand Up @@ -136,24 +144,64 @@ def objective_and_grad(free_vec: Array) -> tuple[Array, Array]:
val, grad = loglike_and_grad(full_vec)
return val, grad[free_idx]

# Match scipy_lbfgsb's stopping rule: stop when EITHER
# * max|projected_grad| < gtol_abs ("gtol channel"), OR
# * (f_k - f_{k+1}) / max(|f_k|, |f_{k+1}|, 1) < ftol_rel
# ("ftol channel"; this is the criterion that typically fires in
# practice for skill-formation likelihoods that go locally flat
# before the gradient does).
# Accept the canonical scipy keys so the same `optimizer_options`
# dict works for both backends; fall back to historical jaxopt
# names for compatibility.
gtol_abs = float(options.pop("convergence_gtol_abs", options.pop("tol", 1e-5)))
ftol_rel = float(options.pop("convergence_ftol_rel", 2.22e-9))
maxiter = int(options.pop("stopping_maxiter", options.pop("maxiter", 15_000)))
history_size = int(options.pop("history_size", 10))

solver = LBFGSB(
fun=objective_and_grad,
value_and_grad=True,
maxiter=int(options.pop("maxiter", 500)),
tol=float(options.pop("tol", 1e-6)),
history_size=int(options.pop("history_size", 10)),
# `maxiter` is also handed to jaxopt as its own iteration cap; the
# outer Python loop below drives stopping and is bounded by the same
# value.
maxiter=maxiter,
tol=gtol_abs,
history_size=history_size,
**options,
)
opt_step = solver.run(free_initial, bounds=(free_lower, free_upper))

final_full = full_template.at[free_idx].set(opt_step.params) # noqa: PD008
bounds = (free_lower, free_upper)
state = solver.init_state(free_initial, bounds=bounds)
params = free_initial
prev_val = jnp.inf
stopped_on = "maxiter"
n_iter = 0
# fallback if `maxiter == 0` and the loop body never executes.
for n_iter in range(1, maxiter + 1): # noqa: B007
params, state = solver.update(params, state, bounds=bounds)
cur_val = state.value
# gtol channel
if bool(state.error < gtol_abs):
stopped_on = "gtol"
break
# ftol channel (skip first iteration where prev_val == inf)
denom = jnp.maximum(
jnp.maximum(jnp.abs(prev_val), jnp.abs(cur_val)),
1.0,
)
rel_drop = jnp.abs(prev_val - cur_val) / denom
if bool(jnp.isfinite(prev_val)) and bool(rel_drop < ftol_rel):
stopped_on = "ftol"
break
prev_val = cur_val

final_full = full_template.at[free_idx].set(params) # noqa: PD008
result_df = full_params_df.copy()
result_df["value"] = np.asarray(jax.device_get(final_full))

n_iter = int(opt_step.state.iter_num)
return JaxoptResult(
params=result_df,
fun=float(jax.device_get(opt_step.state.value)),
success=n_iter < solver.maxiter,
fun=float(jax.device_get(state.value)),
success=stopped_on != "maxiter",
n_iter=n_iter,
)
27 changes: 14 additions & 13 deletions src/skillmodels/af/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
"""

import functools
from collections.abc import Callable
from collections.abc import Callable, Mapping
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from jax import Array

from skillmodels.af.types import ChainLink
Expand Down Expand Up @@ -230,7 +231,7 @@ def _parse_initial_params(

def _map_over_obs(
f: Callable,
*xs: Array,
*xs: Array | np.ndarray,
n_obs_per_batch: int | None,
) -> Array:
"""Map ``f`` over the leading axis of ``xs``, optionally in batches.
Expand Down Expand Up @@ -557,7 +558,7 @@ def af_per_obs_loglike_transition(
prev_control_params: Array,
prev_loadings_flat: Array,
prev_meas_sds: Array,
prev_distribution: dict[str, Array],
prev_distribution: Mapping[str, Array | np.ndarray],
chain_links: tuple[ChainLink, ...],
obs_factor_values_chain: Array,
joint_nodes: Array,
Expand Down Expand Up @@ -660,7 +661,7 @@ def af_loglike_transition(
prev_control_params: Array,
prev_loadings_flat: Array,
prev_meas_sds: Array,
prev_distribution: dict[str, Array],
prev_distribution: Mapping[str, Array | np.ndarray],
chain_links: tuple[ChainLink, ...],
obs_factor_values_chain: Array,
joint_nodes: Array,
Expand Down Expand Up @@ -868,7 +869,7 @@ def _transition_loglike_per_obs(
prev_meas_mask: Array,
prev_full_loadings: Array,
prev_meas_sds: Array,
prev_distribution: dict[str, Array],
prev_distribution: Mapping[str, Array | np.ndarray],
chain_links: tuple[ChainLink, ...],
obs_factor_values_chain: Array,
joint_nodes: Array,
Expand Down Expand Up @@ -967,8 +968,8 @@ def _single_obs(
def _compute_investment(
theta_prev: Array,
obs_factor_values: Array,
inv_eq_params: Array,
inv_sds: Array,
inv_eq_params: Array | np.ndarray,
inv_sds: Array | np.ndarray,
eps_i: Array,
n_endogenous_factors: int,
n_state_factors: int,
Expand Down Expand Up @@ -1001,8 +1002,8 @@ def _rebuild_chain_at_period(
z_state: Array,
z_inv_per_step: Array,
z_shock_per_step: Array,
initial_mean: Array,
initial_chol: Array,
initial_mean: Array | np.ndarray,
initial_chol: Array | np.ndarray,
chain_links: tuple[ChainLink, ...],
obs_factor_values_at_obs_per_step: Array,
n_state_factors: int,
Expand Down Expand Up @@ -1043,7 +1044,7 @@ def _rebuild_chain_at_period(
step), shape (n_state_factors,). When `chain_links` is empty,
returns the period-0 state directly.
"""
theta = initial_mean + initial_chol @ z_state
theta = jnp.asarray(initial_mean + initial_chol @ z_state)
for step_idx, link in enumerate(chain_links):
z_inv = z_inv_per_step[step_idx]
z_shock = z_shock_per_step[step_idx]
Expand Down Expand Up @@ -1080,9 +1081,9 @@ def _integrate_transition_single_obs(
prev_meas_mask: Array,
prev_full_loadings: Array,
prev_meas_sds: Array,
obs_cond_weights: Array,
obs_cond_means: Array,
cond_chols: Array,
obs_cond_weights: Array | np.ndarray,
obs_cond_means: Array | np.ndarray,
cond_chols: Array | np.ndarray,
chain_links: tuple[ChainLink, ...],
obs_factor_values_chain: Array,
joint_nodes: Array,
Expand Down
3 changes: 3 additions & 0 deletions src/skillmodels/af/posterior_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
from beartype import beartype
from jax import Array

from skillmodels._beartype_conf import ESTIMATION_CONF
from skillmodels.af.halton import create_halton_nodes_and_weights
from skillmodels.af.initial_period import _build_loading_mask, _get_ordered_measures
from skillmodels.af.likelihood import _log_normal_pdf
Expand All @@ -21,6 +23,7 @@
from skillmodels.common.state_ranges import create_state_ranges


@beartype(conf=ESTIMATION_CONF)
def get_af_posterior_states(
af_result: AFEstimationResult,
model_spec: ModelSpec,
Expand Down
Loading
Loading