Support static type checking for both numpy and jax array types #87

@hmgaudecker

Problem

Our column type aliases in src/ttsim/typing.py are defined using jaxtyping with jax.Array only:

from typing import TYPE_CHECKING, TypeAlias

if TYPE_CHECKING:
    from jaxtyping import Array, Bool, Float, Int

    BoolColumn: TypeAlias = Bool[Array, " n_obs"]
    IntColumn: TypeAlias = Int[Array, " n_obs"]
    FloatColumn: TypeAlias = Float[Array, " n_obs"]

Since jaxtyping.Array resolves to jax.Array and numpy.ndarray is not a subclass of jax.Array, a static type checker that can resolve jax flags every call site that passes a numpy array where these types are expected.

This creates a split between two type-checking environments:

  • ty (no jax installed): jaxtyping.Array is jax.Array, which is unresolvable without jax. The column type aliases become Unknown, so ty cannot flag array type mismatches at all. Non-array type checking still works.
  • ty-jax (jax installed): the types resolve to jax.Array, so ty reports legitimate invalid-argument-type errors everywhere numpy arrays are passed, since numpy.ndarray is not a subclass of jax.Array.

Currently we run only ty in CI and have ty-jax commented out. This means we have no static type checking of array types — all other type checking works fine.

Affected areas

  • src/ttsim/tt/aggregation.py: all overloaded functions (grouped_sum, grouped_count, sum_by_p_id, etc.)
  • src/ttsim/tt/param_objects.py: PiecewisePolynomialParamValue, ConsecutiveIntLookupTableParamValue
  • src/ttsim/tt/piecewise_polynomial.py
  • All test files that create numpy arrays and pass them to typed functions
  • src/ttsim/interface_dag_elements/fail_if.py: conditional jax import

Research: current state of array typing

jaxtyping supports union types for multi-backend annotations: Float[np.ndarray | jax.Array, "batch features"] is equivalent to Float[np.ndarray, "batch features"] | Float[jax.Array, "batch features"]. This could work, but requires jax to be importable at type-checking time.

jax.typing.ArrayLike accepts jax.Array, numpy.ndarray, and Python scalars. However, it is too broad (it accepts scalars) and does not carry shape/dtype information.

array-api-typing (from the data-apis consortium) aims to provide backend-agnostic array protocols (HasArrayNamespace), but is explicitly pre-release with no published versions yet.

NumPy's position (numpy/numpy#28665, closed as not planned): the array API's simplified typing approach conflicts with NumPy's granular type system. The community is deferring to array-api-typing for cross-library compatibility.

Static type checker support for jaxtyping is partial. ty, mypy, and pyright all treat Float[ArrayType, "shape"] as just ArrayType, ignoring shape/dtype. Protocols and genericised subprotocols have known issues across checkers.

Possible approaches

  1. Union column types with conditional import

    from typing import TYPE_CHECKING, TypeAlias

    if TYPE_CHECKING:
        import numpy as np
        try:
            from jaxtyping import Array, Float, Int, Bool
            FloatColumn: TypeAlias = Float[np.ndarray | Array, " n_obs"]
            IntColumn: TypeAlias = Int[np.ndarray | Array, " n_obs"]
            BoolColumn: TypeAlias = Bool[np.ndarray | Array, " n_obs"]
        except ImportError:
            FloatColumn: TypeAlias = np.ndarray
            IntColumn: TypeAlias = np.ndarray
            BoolColumn: TypeAlias = np.ndarray

    Pro: Preserves jaxtyping shape annotations, works with both backends.
    Con: Requires jax to be importable for the jaxtyping Array import; the fallback loses shape info. Static checkers do not execute the try/except, so they see both sets of TypeAlias declarations, which some checkers may report as conflicting redeclarations.

  2. Use numpy.ndarray as the base type, accept jax duck-typing

    from typing import TYPE_CHECKING, TypeAlias

    if TYPE_CHECKING:
        from jaxtyping import Float, Int, Bool
        import numpy as np
        FloatColumn: TypeAlias = Float[np.ndarray, " n_obs"]
        IntColumn: TypeAlias = Int[np.ndarray, " n_obs"]
        BoolColumn: TypeAlias = Bool[np.ndarray, " n_obs"]

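    Pro: Works without jax installed; at runtime, jax arrays duck-type as ndarray in many contexts.
    Con: Doesn't reflect that jax arrays are a different type. A checker with jax installed still treats jax.Array as incompatible with np.ndarray, so passing jax arrays is flagged (the current problem in reverse), and jax-specific type errors may be missed.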

  3. Wait for array-api-typing
    Use the protocol-based approach once array-api-typing is stable.
    Pro: Standards-based, future-proof.
    Con: No timeline; currently pre-release with no published versions.

  4. Define a local ArrayLike protocol

    from typing import Any, Protocol

    class ArrayLike(Protocol):
        @property
        def shape(self) -> tuple[int, ...]: ...
        @property
        def dtype(self) -> Any: ...

    Then use Float[ArrayLike, " n_obs"].
    Pro: Backend-agnostic, no external dependencies.
    Con: May not work with jaxtyping's __class_getitem__; loses concrete type info.
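The option 4 protocol can also be checked structurally at runtime by adding `@runtime_checkable`, which gives a quick sanity check that numpy arrays satisfy it while plain lists do not. This is a sketch under that assumption; `describe` is a hypothetical helper, and the example says nothing about whether jaxtyping accepts the protocol inside Float[...], which remains the open question for this option.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ArrayLike(Protocol):
    """Local, backend-agnostic protocol: any object with a shape and a dtype."""

    @property
    def shape(self) -> tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...


def describe(column: ArrayLike) -> str:
    # Only the protocol members are used; numpy and jax arrays both expose them.
    return f"{column.shape} {column.dtype}"


print(isinstance(np.zeros(3), ArrayLike))  # True
print(isinstance([1.0, 2.0, 3.0], ArrayLike))  # False
print(describe(np.zeros(3)))  # (3,) float64
```

Note that NumPy scalars such as np.float64(1.0) also have shape and dtype, so this protocol alone does not exclude 0-d values.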
