Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ releases are available on [Anaconda.org](https://anaconda.org/conda-forge/ttsim)

## Unreleased

- {gh}`130` Split `raw_results.columns` into `raw_results.columns_with_internal_p_ids`
(computable from `processed_data` alone) and `raw_results.columns_with_original_p_ids`
(the reverse-translation of endogenous `p_id_*` columns from {gh}`108`).
({ghuser}`MImmesberger`)
- Fill in the `count_by_p_id`, `mean_by_p_id`, `max_by_p_id`, `min_by_p_id`,
`any_by_p_id`, and `all_by_p_id` aggregations on both the NumPy and JAX backends.
Negative source `p_id` entries are masked out so they cannot influence the result;
Expand Down
4 changes: 4 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ coverage:
project:
default:
target: 85%
ignore:
# Runs on the JAX backend only; the coverage job uses the NumPy backend,
# under which all of its cases are skipped.
- src_mettsim/tests_middle_earth/test_jittability.py
24 changes: 18 additions & 6 deletions src/ttsim/interface_dag_elements/raw_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,36 @@


@interface_function()
def columns(
def columns_with_internal_p_ids(
labels__root_nodes: UnorderedQNames,
processed_data: QNameData,
tt_function: Callable[[QNameData], QNameData],
) -> QNameData:
"""The raw results of the TT function that have been requested as targets.

Arrays are sorted according to the internal sort order. Endogenously-computed
`p_id_*` columns hold internal indices; see `columns_with_original_p_ids` for
the version with user-space `p_id` values.
"""
return tt_function(
{k: v for k, v in processed_data.items() if k in labels__root_nodes},
)


@interface_function()
def columns_with_original_p_ids(
columns_with_internal_p_ids: QNameData,
input_data__flat: FlatData,
input_data__sort_indices: IntColumn,
xnp: ModuleType,
) -> QNameData:
"""The raw results of the TT function that have been requested as targets.
"""Raw results with endogenous `p_id_*` columns remapped to user-space `p_id`s.

Arrays are sorted according to the internal sort order. Endogenously-
computed `p_id_*` columns hold internal indices; those are reverse-
translated to user-space `p_id` values here so consumers downstream see
the original identifiers (with `-1` preserved as the no-link sentinel).
"""
raw = tt_function(
{k: v for k, v in processed_data.items() if k in labels__root_nodes},
)
sorted_orig_p_ids = xnp.asarray(input_data__flat[("p_id",)])[
input_data__sort_indices
]
Expand All @@ -47,7 +59,7 @@ def columns(
)
if _is_endogenous_p_id_pointer(qname)
else value
for qname, value in raw.items()
for qname, value in columns_with_internal_p_ids.items()
}


Expand Down
7 changes: 5 additions & 2 deletions src/ttsim/interface_dag_elements/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

@interface_function()
def tree(
raw_results__columns: QNameData,
raw_results__columns_with_original_p_ids: QNameData,
raw_results__params: QNameResults,
raw_results__from_input_data: QNameData,
input_data__sort_indices: IntColumn,
Expand All @@ -43,7 +43,10 @@ def reorder_arrays(v: Any) -> Any: # noqa: ANN401
{
**raw_results__params,
**raw_results__from_input_data,
**{k: reorder_arrays(v) for k, v in raw_results__columns.items()},
**{
k: reorder_arrays(v)
for k, v in raw_results__columns_with_original_p_ids.items()
},
}
)

Expand Down
27 changes: 23 additions & 4 deletions src/ttsim/main_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ def top_level_namespace(cls, top_level_namespace: UnorderedQNames) -> Labels:

@dataclass(frozen=True)
class RawResults(MainArg):
columns: QNameData | None = None
columns_with_internal_p_ids: QNameData | None = None
columns_with_original_p_ids: QNameData | None = None
params: QNameData | None = None
from_input_data: QNameData | None = None
combined: QNameData | None = None
Expand All @@ -273,9 +274,27 @@ def __post_init__(self) -> None:
_fix_classmethod_namespace_conflicts(self)

@classmethod
def columns(cls, columns: QNameData) -> RawResults:
"""Column results data."""
return _set_single_field(cls=cls, field_name="columns", field_value=columns)
def columns_with_internal_p_ids(
cls, columns_with_internal_p_ids: QNameData
) -> RawResults:
"""Column results data with endogenous `p_id_*` columns holding internal
indices."""
return _set_single_field(
cls=cls,
field_name="columns_with_internal_p_ids",
field_value=columns_with_internal_p_ids,
)

@classmethod
def columns_with_original_p_ids(
cls, columns_with_original_p_ids: QNameData
) -> RawResults:
"""Column results data with endogenous `p_id_*` columns in user space."""
return _set_single_field(
cls=cls,
field_name="columns_with_original_p_ids",
field_value=columns_with_original_p_ids,
)

@classmethod
def params(cls, params: QNameData) -> RawResults:
Expand Down
3 changes: 2 additions & 1 deletion src/ttsim/main_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ class Results(MainTargetABC):

@dataclass(frozen=True)
class RawResults(MainTargetABC):
columns: str = "raw_results__columns"
columns_with_internal_p_ids: str = "raw_results__columns_with_internal_p_ids"
columns_with_original_p_ids: str = "raw_results__columns_with_original_p_ids"
from_input_data: str = "raw_results__from_input_data"
params: str = "raw_results__params"

Expand Down
91 changes: 91 additions & 0 deletions src_mettsim/tests_middle_earth/test_jittability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

import datetime
import functools
import inspect
from typing import TYPE_CHECKING, Literal

import dags.tree as dt
import pytest
from dags import get_free_arguments
from mettsim import middle_earth

from ttsim import (
MainTarget,
OrigPolicyObjects,
SpecializedEnvironment,
TTTargets,
main,
)
from ttsim.tt import ColumnFunction

if TYPE_CHECKING:
from ttsim.typing import SpecEnvWithPartialledParamsAndScalars


def get_orig_mettsim_column_functions() -> list[tuple[tuple[str, ...], ColumnFunction]]:
    """Collect all `ColumnFunction` objects from METTSIM's original policy objects.

    Calls `main` for the `column_objects_and_param_functions` target, rooted at
    `middle_earth.ROOT_PATH`, and keeps only the entries whose value is a
    `ColumnFunction`.

    Returns:
        A list of `(tree_path, column_function)` pairs, in the iteration order
        of the mapping returned by `main`.
    """
    policy_objects = main(
        main_target=MainTarget.orig_policy_objects.column_objects_and_param_functions,
        orig_policy_objects=OrigPolicyObjects.root(middle_earth.ROOT_PATH),
    )
    return [
        (tree_path, obj)
        for tree_path, obj in policy_objects.items()
        if isinstance(obj, ColumnFunction)
    ]


@functools.lru_cache(maxsize=100)
def cached_specialized_environment(
    policy_date: datetime.date,
    backend: Literal["numpy", "jax"],
) -> SpecEnvWithPartialledParamsAndScalars:
    """Return METTSIM's specialized environment for a policy date and backend.

    The result is memoised (up to 100 distinct `(policy_date, backend)` pairs)
    so that parametrised tests sharing a date and backend reuse one `main` call.
    Fail and warn nodes are excluded from the computation.
    """
    target = (
        "specialized_environment_for_plotting_and_templates",
        "with_partialled_params_and_scalars",
    )
    root = OrigPolicyObjects.root(middle_earth.ROOT_PATH)
    return main(
        main_target=target,
        policy_date=policy_date,
        orig_policy_objects=root,
        backend=backend,
        include_fail_nodes=False,
        include_warn_nodes=False,
    )


@pytest.mark.skipif_numpy
@pytest.mark.parametrize(
    ("tree_path", "fun"),
    get_orig_mettsim_column_functions(),
    ids=[str(x[0]) for x in get_orig_mettsim_column_functions()],
)
def test_jittable(tree_path, fun, backend, xnp):
    """Evaluate one METTSIM column function via `main` on dummy one-row inputs.

    For every free argument of the specialized function a length-1 zero array
    is created, with the dtype chosen from the argument's (string) annotation.
    The function is then requested as the only TT target of
    `raw_results.columns_with_internal_p_ids`. Functions that carry a
    `fail_msg_if_included` are not evaluated.
    """
    # Use the function's end date unless that lies in the future.
    policy_date = min(fun.end_date, datetime.date.today())  # noqa: DTZ011
    qname = dt.qname_from_tree_path((*tree_path[:-2], fun.leaf_name))
    column_fn = cached_specialized_environment(
        policy_date=policy_date, backend=backend
    )[qname]
    env = {qname: column_fn}

    # Checked in this order; the first marker found in the annotation wins.
    dtype_by_marker = (
        ("FloatColumn", float),
        ("IntColumn", int),
        ("BoolColumn", bool),
    )
    parameters = inspect.signature(column_fn).parameters
    processed_data = {}
    for arg_name in get_free_arguments(column_fn):
        arg = parameters[arg_name]
        for marker, dtype in dtype_by_marker:
            if marker in arg.annotation:
                processed_data[arg_name] = xnp.zeros(1, dtype=dtype)
                break
        else:
            raise ValueError(f"Unknown column type: {arg.annotation}")

    if fun.fail_msg_if_included:
        return

    main(
        main_target=("raw_results", "columns_with_internal_p_ids"),
        policy_date=policy_date,
        specialized_environment=SpecializedEnvironment.with_partialled_params_and_scalars(
            env
        ),
        processed_data=processed_data,
        tt_targets=TTTargets.qname([qname]),
        backend=backend,
        include_fail_nodes=False,
        include_warn_nodes=False,
    )
2 changes: 1 addition & 1 deletion tests/interface_dag_elements/test_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,7 +1641,7 @@ def test_fail_if_name_of_last_branch_element_is_not_the_functions_leaf_name(
"main_target",
[
MainTarget.tt_function,
MainTarget.raw_results.columns,
MainTarget.raw_results.columns_with_internal_p_ids,
],
)
def test_raise_tt_root_nodes_are_missing_without_input_data(
Expand Down
50 changes: 24 additions & 26 deletions tests/interface_dag_elements/test_raw_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,17 @@
# these objects — to assert on the pristine instances.
_IFACE_DIR = Path(ttsim.interface_dag_elements.__file__).parent
_raw_results = load_module(path=_IFACE_DIR / "raw_results.py", root=_IFACE_DIR)
columns = _raw_results.columns
columns_with_internal_p_ids = _raw_results.columns_with_internal_p_ids
columns_with_original_p_ids = _raw_results.columns_with_original_p_ids
from_input_data = _raw_results.from_input_data
params = _raw_results.params


# =============================================================================
# columns() function tests
# columns_with_internal_p_ids() function tests
# =============================================================================
def test_columns_is_interface_function():
assert isinstance(columns, InterfaceFunction)


def _identity_p_id_inputs(xnp, n: int) -> dict:
"""Minimal `input_data__flat` + sort indices for tests that don't care
about row order — the original `p_id` array is already sorted, so the
sort is the identity permutation.
"""
return {
"input_data__flat": {("p_id",): xnp.arange(n)},
"input_data__sort_indices": xnp.arange(n),
"xnp": xnp,
}
def test_columns_with_internal_p_ids_is_interface_function():
assert isinstance(columns_with_internal_p_ids, InterfaceFunction)


def test_columns_filters_to_root_nodes(xnp):
Expand All @@ -48,11 +37,10 @@ def test_columns_filters_to_root_nodes(xnp):
def tt_function(data):
return data

result = columns(
result = columns_with_internal_p_ids(
labels__root_nodes=root_nodes,
processed_data=processed_data,
tt_function=tt_function,
**_identity_p_id_inputs(xnp, n=3),
)

# Only root_nodes should be passed to tt_function
Expand All @@ -74,11 +62,10 @@ def tt_function(data):
call_args.append(data)
return {"output": xnp.array([1, 2])}

columns(
columns_with_internal_p_ids(
labels__root_nodes=root_nodes,
processed_data=processed_data,
tt_function=tt_function,
**_identity_p_id_inputs(xnp, n=2),
)

# Verify tt_function was called with only root_nodes data
Expand All @@ -96,11 +83,10 @@ def test_columns_returns_tt_function_output(xnp):
def tt_function(_data):
return expected_output

result = columns(
result = columns_with_internal_p_ids(
labels__root_nodes=root_nodes,
processed_data=processed_data,
tt_function=tt_function,
**_identity_p_id_inputs(xnp, n=2),
)

assert result == expected_output
Expand All @@ -113,16 +99,22 @@ def test_columns_with_empty_root_nodes(xnp):
def tt_function(data):
return data

result = columns(
result = columns_with_internal_p_ids(
labels__root_nodes=root_nodes,
processed_data=processed_data,
tt_function=tt_function,
**_identity_p_id_inputs(xnp, n=2),
)

assert result == {}


# =============================================================================
# columns_with_original_p_ids() function tests
# =============================================================================
def test_columns_with_original_p_ids_is_interface_function():
assert isinstance(columns_with_original_p_ids, InterfaceFunction)


# =============================================================================
# from_input_data() function tests
# =============================================================================
Expand Down Expand Up @@ -301,11 +293,17 @@ def test_params_returns_various_value_types():
# =============================================================================
# Dependencies property tests
# =============================================================================
def test_columns_dependencies():
assert columns.dependencies == {
def test_columns_with_internal_p_ids_dependencies():
assert columns_with_internal_p_ids.dependencies == {
"labels__root_nodes",
"processed_data",
"tt_function",
}


def test_columns_with_original_p_ids_dependencies():
Comment thread
hmgaudecker marked this conversation as resolved.
assert columns_with_original_p_ids.dependencies == {
"columns_with_internal_p_ids",
"input_data__flat",
"input_data__sort_indices",
"xnp",
Expand Down
2 changes: 1 addition & 1 deletion tests/interface_dag_elements/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_restore_original_row_order(
):
"""Test that the tree function restores original row order correctly."""
result = tree(
raw_results__columns=raw_results_columns,
raw_results__columns_with_original_p_ids=raw_results_columns,
raw_results__params=raw_results_params,
raw_results__from_input_data=raw_results_from_input_data,
input_data__sort_indices=input_data__sort_indices,
Expand Down
4 changes: 2 additions & 2 deletions tests/interface_dag_elements/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_warn_if_evaluation_date_set_in_multiple_places_implicitly_added(backend
}
with pytest.warns(match="You have specified the evaluation date in more than one"):
main(
main_target=MainTarget.raw_results.columns,
main_target=MainTarget.raw_results.columns_with_internal_p_ids,
policy_environment=policy_environment,
evaluation_date=datetime.date(2025, 1, 1),
processed_data={"p_id": xnp.array([0])},
Expand All @@ -168,7 +168,7 @@ def test_do_not_need_to_warn_if_evaluation_date_is_set_only_once(backend, xnp):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
main(
main_target=MainTarget.raw_results.columns,
main_target=MainTarget.raw_results.columns_with_internal_p_ids,
policy_environment=policy_environment,
evaluation_date=datetime.date(2025, 1, 1),
processed_data={"p_id": xnp.array([0])},
Expand Down
Loading
Loading