Problem
Our column type aliases in src/ttsim/typing.py are defined using jaxtyping with jax.Array only:
```python
from typing import TYPE_CHECKING, TypeAlias

if TYPE_CHECKING:
    from jaxtyping import Array, Bool, Float, Int

    BoolColumn: TypeAlias = Bool[Array, " n_obs"]
    IntColumn: TypeAlias = Int[Array, " n_obs"]
    FloatColumn: TypeAlias = Float[Array, " n_obs"]
```
Since jaxtyping.Array resolves to jax.Array and numpy.ndarray is not a subclass of jax.Array, static type checkers flag every call site that passes a numpy array where these types are expected.
This creates a split between two type-checking environments:
- ty (no jax installed): jaxtyping.Array is jax.Array, which cannot be resolved without jax. The column type aliases become Unknown, so ty cannot flag array type mismatches at all. Non-array type checking still works.
- ty-jax (jax installed): the aliases resolve to jax.Array, producing legitimate invalid-argument-type errors everywhere numpy arrays are used, since numpy.ndarray is not a subclass of jax.Array.
Currently we run only ty in CI and have ty-jax commented out. This means we have no static type checking of array types — all other type checking works fine.
Affected areas
- src/ttsim/tt/aggregation.py: all overloaded functions (grouped_sum, grouped_count, sum_by_p_id, etc.)
- src/ttsim/tt/param_objects.py: PiecewisePolynomialParamValue, ConsecutiveIntLookupTableParamValue
- src/ttsim/tt/piecewise_polynomial.py
- All test files that create numpy arrays and pass them to typed functions
- src/ttsim/interface_dag_elements/fail_if.py: conditional jax import
Research: current state of array typing
- jaxtyping supports union types for multi-backend annotations: Float[np.ndarray | jax.Array, "batch features"] is equivalent to Float[np.ndarray, "batch features"] | Float[jax.Array, "batch features"]. This could work, but it requires jax to be importable at type-checking time.
- jax.typing.ArrayLike accepts jax.Array, numpy.ndarray, and Python scalars. However, it is too broad (it accepts scalars) and carries no shape/dtype information.
- array-api-typing (from the data-apis consortium) aims to provide backend-agnostic array protocols (HasArrayNamespace), but it is explicitly pre-release, with no published versions yet.
- NumPy's position (numpy/numpy#28665, closed as not planned): the array API's simplified typing approach conflicts with NumPy's granular type system, and the community is deferring to array-api-typing for cross-library compatibility.
- Static type checker support for jaxtyping is partial: ty, mypy, and pyright all treat Float[ArrayType, "shape"] as plain ArrayType, ignoring shape and dtype. Protocols and genericised subprotocols have known issues across checkers.
Possible approaches
- Union column types with conditional import
```python
from typing import TYPE_CHECKING, TypeAlias

if TYPE_CHECKING:
    import numpy as np

    try:
        from jaxtyping import Array, Bool, Float, Int

        # Accept either backend while keeping the jaxtyping shape annotation.
        FloatColumn: TypeAlias = Float[np.ndarray | Array, " n_obs"]
        IntColumn: TypeAlias = Int[np.ndarray | Array, " n_obs"]
        BoolColumn: TypeAlias = Bool[np.ndarray | Array, " n_obs"]
    except ImportError:
        # Fallback without jax: plain ndarray, shape info is lost.
        FloatColumn: TypeAlias = np.ndarray
        IntColumn: TypeAlias = np.ndarray
        BoolColumn: TypeAlias = np.ndarray
```
Pro: Preserves jaxtyping shape annotations, works with both backends.
Con: Requires jax to be importable for the jaxtyping Array import; the fallback loses shape info. The conditional import inside TYPE_CHECKING may confuse some checkers.
- Use numpy.ndarray as the base type and accept jax duck-typing
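```python
from typing import TYPE_CHECKING, TypeAlias

if TYPE_CHECKING:
    import numpy as np
    from jaxtyping import Bool, Float, Int

    # Only numpy and jaxtyping are needed; jax is never imported here.
    FloatColumn: TypeAlias = Float[np.ndarray, " n_obs"]
    IntColumn: TypeAlias = Int[np.ndarray, " n_obs"]
    BoolColumn: TypeAlias = Bool[np.ndarray, " n_obs"]
```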
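Pro: Works without jax installed; at runtime, jax arrays duck-type as ndarray in many contexts.
Con: Doesn't reflect that jax arrays are a different type. Static checkers compare these types nominally, so with jax installed they would flag call sites that pass jax arrays (jax.Array is not a subclass of numpy.ndarray), mirroring the current problem; jax-specific type errors may also be missed.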
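The duck-typing trade-off can be illustrated without jax. FakeDeviceArray below is a hypothetical stand-in for a jax array: like jax.Array, it is not an np.ndarray subclass but implements the __array__ conversion hook. NumPy accepts it at runtime, while a nominal isinstance check, analogous to the subtype check a static checker applies against an np.ndarray annotation, rejects it.

```python
from typing import Any

import numpy as np


class FakeDeviceArray:
    """Hypothetical stand-in for a non-numpy array such as jax.Array.

    Not an np.ndarray subclass, but it implements __array__ so numpy
    can convert it, which is what makes runtime duck typing work.
    """

    def __init__(self, data: list[float]) -> None:
        self._data = data

    def __array__(self, dtype: Any = None, copy: Any = None) -> np.ndarray:
        return np.array(self._data, dtype=dtype)


x = FakeDeviceArray([1.0, 2.0, 3.0])

# Runtime duck typing: numpy converts the object via __array__.
total = float(np.sum(np.asarray(x)))

# Nominal typing: the object is not an ndarray instance.
is_ndarray = isinstance(x, np.ndarray)
```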
- Wait for array-api-typing
Use the protocol-based approach once array-api-typing is stable.
Pro: Standards-based, future-proof.
Con: No timeline; currently pre-release with no published versions.
- Define a local ArrayLike protocol
```python
from typing import Any, Protocol


class ArrayLike(Protocol):
    @property
    def shape(self) -> tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...
```
Then use Float[ArrayLike, " n_obs"].
Pro: Backend-agnostic, no external dependencies.
Con: May not work with jaxtyping's __class_getitem__; loses concrete type info.
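A minimal sketch of the protocol used directly as an annotation, not wrapped in Float[...], since whether jaxtyping accepts it there is the open question above. runtime_checkable is added only so the structural match can be demonstrated with isinstance, and n_obs is a hypothetical helper. Whether np.ndarray and jax.Array satisfy the protocol statically depends on each library's stubs declaring compatible shape and dtype types.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ArrayLike(Protocol):
    """Structural type: any object with .shape and .dtype matches."""

    @property
    def shape(self) -> tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...


def n_obs(column: ArrayLike) -> int:
    """Hypothetical helper: number of observations in a 1-D column."""
    return column.shape[0]


np_col = np.array([1, 2, 3, 4])

# isinstance only checks that the attributes exist, not their types.
np_matches = isinstance(np_col, ArrayLike)
list_matches = isinstance([1, 2, 3], ArrayLike)  # lists have no .shape
count = n_obs(np_col)
```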