From a6860ded72297978a6a239aa5cba9cca62834bfe Mon Sep 17 00:00:00 2001 From: Laurent Courty Date: Wed, 25 Mar 2026 15:09:11 -0600 Subject: [PATCH 1/2] add extra hotstart tests --- tests/core/test_hotstart_integration.py | 213 ++++++++++++++++++++-- tests/core/test_hotstart_state_loading.py | 147 +++++++++++---- 2 files changed, 307 insertions(+), 53 deletions(-) diff --git a/tests/core/test_hotstart_integration.py b/tests/core/test_hotstart_integration.py index a282b6c..aae67c1 100644 --- a/tests/core/test_hotstart_integration.py +++ b/tests/core/test_hotstart_integration.py @@ -70,6 +70,7 @@ def build_simulation( sim_config: SimulationConfig, domain_5by5, hotstart_bytes: bytes | None = None, + raster_output_provider=None, ) -> Simulation: """Build a simulation with optional hotstart. @@ -78,7 +79,9 @@ def build_simulation( domain_5by5: Domain fixture hotstart_bytes: Optional hotstart archive bytes """ - raster_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) + raster_output = raster_output_provider or MemoryRasterOutputProvider( + {"out_map_names": sim_config.output_map_names} + ) builder = ( SimulationBuilder(sim_config, domain_5by5.arr_mask, np.float32) @@ -116,6 +119,41 @@ def build_simulation( return simulation +def run_to_split_and_create_hotstart( + sim_config: SimulationConfig, + domain_5by5, + split_time: datetime, + raster_output_provider=None, +) -> tuple[Simulation, bytes]: + simulation = build_simulation( + sim_config, + domain_5by5, + raster_output_provider=raster_output_provider, + ) + simulation.initialize() + while simulation.sim_time < split_time: + simulation.update() + return simulation, simulation.create_hotstart().getvalue() + + +def assert_final_state_matches(simulation: Simulation, reference: Simulation) -> None: + for key in ["water_depth", "qe", "qs"]: + arr_resumed = simulation.raster_domain.get_array(key) + arr_reference = reference.raster_domain.get_array(key) + np.testing.assert_allclose(arr_resumed, arr_reference, err_msg=f"Final {key} mismatch") + + +def get_unique_record_times( + output_provider: MemoryRasterOutputProvider, + output_key: str = "water_depth", +) -> list[datetime | timedelta]: + unique_times: list[datetime | timedelta] = [] + for sim_time, _ in output_provider.output_maps_dict[output_key]: + if not unique_times or sim_time != unique_times[-1]: + unique_times.append(sim_time) + return unique_times + + @pytest.fixture(scope="module") def base_time() -> datetime: """Base time for all simulations.""" @@ -166,15 +204,11 @@ def test_roundtrip_state_restoration_and_match( end_time=end_time, helpers=helpers, ) - sim_a = build_simulation(sim_a_config, domain_5by5) - sim_a.initialize() - - # Run until split time - while sim_a.sim_time < split_time: - sim_a.update() - - # Create hotstart at split point - hotstart_bytes = sim_a.create_hotstart() + sim_a, hotstart_bytes = run_to_split_and_create_hotstart( + sim_a_config, + domain_5by5, + split_time, + ) # Save state for comparison saved_sim_time = sim_a.sim_time @@ -195,7 +229,7 @@ def test_roundtrip_state_restoration_and_match( end_time=end_time, helpers=helpers, ) - sim_b = build_simulation(sim_b_config, domain_5by5, hotstart_bytes=hotstart_bytes.getvalue()) + sim_b = build_simulation(sim_b_config, domain_5by5, hotstart_bytes=hotstart_bytes) # Verify raster state was restored for key in saved_raster_state: @@ -249,16 +283,157 @@ def test_roundtrip_state_restoration_and_match( # Step 4: Verify final results match uninterrupted reference # Use qe/qs (internal flow arrays) instead of qx/qy (output arrays computed on-the-fly) - for key in ["water_depth", "qe", "qs"]: - arr_resumed = sim_b.raster_domain.get_array(key) - arr_uninterrupted = uninterrupted_simulation.raster_domain.get_array(key) - np.testing.assert_allclose( - arr_resumed, - arr_uninterrupted, - err_msg=f"Final {key} mismatch for split_time={split_seconds}s", - ) + assert_final_state_matches(sim_b, uninterrupted_simulation) # Verify simulation reached end time assert sim_b.sim_time == end_time, ( f"Resumed simulation did not reach end time: {sim_b.sim_time} != {end_time}" ) + + +def test_resume_allows_output_provider_change( + domain_5by5, + helpers, + base_time, + uninterrupted_simulation: Simulation, +) -> None: + split_time = base_time + timedelta(seconds=30) + end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) + + sim_config = create_sim_config( + start_time=base_time, + end_time=end_time, + helpers=helpers, + ) + initial_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) + _, hotstart_bytes = run_to_split_and_create_hotstart( + sim_config, + domain_5by5, + split_time, + raster_output_provider=initial_output, + ) + + resumed_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) + sim_b = build_simulation( + sim_config, + domain_5by5, + hotstart_bytes=hotstart_bytes, + raster_output_provider=resumed_output, + ) + + run_simulation_to_end(sim_b, skip_initialize=True) + + assert initial_output is not resumed_output + assert resumed_output.output_maps_dict["water_depth"] + assert len(resumed_output.output_maps_dict["water_depth"]) >= 1 + assert_final_state_matches(sim_b, uninterrupted_simulation) + assert sim_b.sim_time == end_time + + +def test_resume_allows_output_map_name_change( + domain_5by5, + helpers, + base_time, + uninterrupted_simulation: Simulation, +) -> None: + split_time = base_time + timedelta(seconds=30) + end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) + + sim_a_config = create_sim_config( + start_time=base_time, + end_time=end_time, + helpers=helpers, + ) + _, hotstart_bytes = run_to_split_and_create_hotstart(sim_a_config, domain_5by5, split_time) + + resumed_output_map_names = helpers.make_output_map_names( + "out_resume", + ["water_depth", "qx", "qy", "volume_error"], + ) + sim_b_config = sim_a_config.model_copy(update={"output_map_names": resumed_output_map_names}) + resumed_output = MemoryRasterOutputProvider({"out_map_names": resumed_output_map_names}) + sim_b = build_simulation( + sim_b_config, + domain_5by5, + hotstart_bytes=hotstart_bytes, + raster_output_provider=resumed_output, + ) + + run_simulation_to_end(sim_b, skip_initialize=True) + + assert sim_b.report.out_map_names == resumed_output_map_names + assert resumed_output.out_map_names == resumed_output_map_names + assert resumed_output.output_maps_dict["water_depth"] + assert sim_a_config.output_map_names["water_depth"] != resumed_output_map_names["water_depth"] + assert_final_state_matches(sim_b, uninterrupted_simulation) + + +def test_resume_allows_end_time_extension(domain_5by5, helpers, base_time) -> None: + split_time = base_time + timedelta(seconds=30) + original_end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) + extended_end_time = base_time + timedelta(seconds=90) + + sim_a_config = create_sim_config( + start_time=base_time, + end_time=original_end_time, + helpers=helpers, + ) + _, hotstart_bytes = run_to_split_and_create_hotstart(sim_a_config, domain_5by5, split_time) + + sim_b_config = create_sim_config( + start_time=base_time, + end_time=extended_end_time, + helpers=helpers, + ) + sim_b = build_simulation(sim_b_config, domain_5by5, hotstart_bytes=hotstart_bytes) + + reference = build_simulation(sim_b_config, domain_5by5) + run_simulation_to_end(reference) + run_simulation_to_end(sim_b, skip_initialize=True) + + assert sim_b.sim_time == extended_end_time + assert_final_state_matches(sim_b, reference) + + +def test_resume_applies_new_record_step_cadence(domain_5by5, helpers, base_time) -> None: + split_time = base_time + timedelta(seconds=34.2) + end_time = base_time + timedelta(seconds=70) + original_config = create_sim_config( + start_time=base_time, + end_time=end_time, + helpers=helpers, + ) + sim_a, hotstart_bytes = run_to_split_and_create_hotstart( + original_config, + domain_5by5, + split_time, + ) + + resumed_record_step = timedelta(seconds=10) + resumed_config = original_config.model_copy(update={"record_step": resumed_record_step}) + resumed_output = MemoryRasterOutputProvider({"out_map_names": resumed_config.output_map_names}) + sim_b = build_simulation( + resumed_config, + domain_5by5, + hotstart_bytes=hotstart_bytes, + raster_output_provider=resumed_output, + ) + + run_simulation_to_end(sim_b, skip_initialize=True) + + resumed_record_times = get_unique_record_times(resumed_output) + resumed_offset = sim_a.sim_time - base_time + expected_record_times = [ + resumed_offset + resumed_record_step, + resumed_offset + (2 * resumed_record_step), + resumed_offset + (3 * resumed_record_step), + ] + expected_record_times = [ + time for time in expected_record_times if time < (end_time - base_time) + ] + + assert resumed_record_times[: len(expected_record_times)] == expected_record_times + + reference = build_simulation(original_config, domain_5by5) + run_simulation_to_end(reference) + assert_final_state_matches(sim_b, reference) diff --git a/tests/core/test_hotstart_state_loading.py b/tests/core/test_hotstart_state_loading.py index 7740719..46b2dfc 100644 --- a/tests/core/test_hotstart_state_loading.py +++ b/tests/core/test_hotstart_state_loading.py @@ -128,6 +128,39 @@ def test_load_state_rejects_dtype_mismatch( class TestSimulationBuilderHotstart: """Tests for SimulationBuilder hotstart integration.""" + @staticmethod + def _create_hotstart_bytes(domain_5by5, sim_config: SimulationConfig) -> io.BytesIO: + raster_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) + simulation = ( + SimulationBuilder(sim_config, domain_5by5.arr_mask, np.float32) + .with_domain_data(domain_5by5.domain_data) + .with_raster_output_provider(raster_output) + .with_vector_output_provider(MemoryVectorOutputProvider({})) + .build() + ) + simulation.set_array("dem", domain_5by5.arr_dem_flat.copy()) + simulation.set_array("friction", domain_5by5.arr_n.copy()) + simulation.set_array("water_depth", domain_5by5.arr_start_h.copy()) + simulation.initialize() + return simulation.create_hotstart() + + @staticmethod + def _build_with_hotstart( + domain_5by5, + sim_config: SimulationConfig, + hotstart_bytes: io.BytesIO, + domain_data: DomainData | None = None, + ): + raster_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) + builder = ( + SimulationBuilder(sim_config, domain_5by5.arr_mask, np.float32) + .with_domain_data(domain_data or domain_5by5.domain_data) + .with_raster_output_provider(raster_output) + .with_vector_output_provider(MemoryVectorOutputProvider({})) + .with_hotstart(hotstart_bytes) + ) + return builder.build() + @pytest.fixture def sim_config(self, helpers) -> SimulationConfig: """Create a basic SimulationConfig for testing.""" @@ -156,31 +189,7 @@ def valid_hotstart_bytes( sim_config: SimulationConfig, ) -> io.BytesIO: """Create a valid hotstart archive for testing.""" - raster_domain = RasterDomain( - dtype=np.float32, - arr_mask=domain_5by5.arr_mask, - cell_shape=domain_5by5.domain_data.cell_shape, - ) - raster_domain.update_array("water_depth", domain_5by5.arr_start_h.copy()) - raster_domain.update_array("dem", domain_5by5.arr_dem_flat.copy()) - raster_domain.update_array("friction", domain_5by5.arr_n.copy()) - - raster_state = raster_domain.save_state() - simulation_state = HotstartSimulationState( - sim_time="2000-01-01T00:00:30", - dt=0.5, - next_ts={}, - time_steps_counters={"since_start": 100, "since_last_report": 10}, - accum_update_time={}, - old_domain_volume=100.0, - ) - - return HotstartWriter.create( - domain_data=domain_5by5.domain_data, - simulation_config=sim_config, - simulation_state=simulation_state, - raster_state_bytes=raster_state.getvalue(), - ) + return self._create_hotstart_bytes(domain_5by5, sim_config) def test_with_hotstart_from_bytes( self, @@ -260,18 +269,88 @@ def test_build_rejects_domain_mismatch( kwargs[field] = value mismatched_domain = DomainData(**kwargs) - raster_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) - with pytest.raises(HotstartError, match=expected_error): - ( - SimulationBuilder(sim_config, domain_5by5.arr_mask, np.float32) - .with_domain_data(mismatched_domain) - .with_raster_output_provider(raster_output) - .with_vector_output_provider(MemoryVectorOutputProvider({})) - .with_hotstart(valid_hotstart_bytes) - .build() + self._build_with_hotstart( + domain_5by5, + sim_config, + valid_hotstart_bytes, + domain_data=mismatched_domain, ) + def test_build_rejects_crs_mismatch( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + ) -> None: + """build() should reject domain CRS mismatches.""" + dd = domain_5by5.domain_data + mismatched_domain = DomainData( + north=dd.north, + south=dd.south, + east=dd.east, + west=dd.west, + rows=dd.rows, + cols=dd.cols, + crs_wkt="EPSG:4326", + ) + + with pytest.raises(HotstartError, match="Domain CRS mismatch"): + self._build_with_hotstart( + domain_5by5, + sim_config, + valid_hotstart_bytes, + domain_data=mismatched_domain, + ) + + @pytest.mark.parametrize( + ("archived_model", "resumed_model"), + [ + (InfiltrationModelType.GREEN_AMPT, InfiltrationModelType.NULL), + (InfiltrationModelType.NULL, InfiltrationModelType.GREEN_AMPT), + ], + ) + def test_build_rejects_infiltration_model_mismatch( + self, + domain_5by5, + sim_config: SimulationConfig, + archived_model: InfiltrationModelType, + resumed_model: InfiltrationModelType, + ) -> None: + """build() should reject hotstarts with incompatible infiltration models.""" + archived_config = sim_config.model_copy(update={"infiltration_model": archived_model}) + resumed_config = sim_config.model_copy(update={"infiltration_model": resumed_model}) + hotstart_bytes = self._create_hotstart_bytes(domain_5by5, archived_config) + + with pytest.raises(HotstartError, match="infiltration"): + self._build_with_hotstart(domain_5by5, resumed_config, hotstart_bytes) + + @pytest.mark.parametrize( + ("parameter", "value"), + [ + ("hmin", 0.0002), + ], + ) + def test_build_rejects_surface_flow_parameter_mismatch( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + parameter: str, + value: float, + ) -> None: + """build() should reject solver-affecting surface-flow parameter changes.""" + resumed_config = sim_config.model_copy( + update={ + "surface_flow_parameters": sim_config.surface_flow_parameters.model_copy( + update={parameter: value} + ) + } + ) + + with pytest.raises(HotstartError, match=f"{parameter}|surface"): + self._build_with_hotstart(domain_5by5, resumed_config, valid_hotstart_bytes) + def test_build_rejects_drainage_mismatch_hotstart_has_drainage( self, domain_5by5, From 244e1f763bcc2011dfc3d93055635947a8cb9647 Mon Sep 17 00:00:00 2001 From: Laurent Courty Date: Fri, 27 Mar 2026 16:24:10 -0600 Subject: [PATCH 2/2] extend hotstart test coverage --- src/itzi/data_containers.py | 33 +-- src/itzi/simulation.py | 24 ++ src/itzi/simulation_builder.py | 75 +++++- tests/core/test_hotstart_integration.py | 278 ++++++++++++---------- tests/core/test_hotstart_state_loading.py | 105 ++++++++ 5 files changed, 375 insertions(+), 140 deletions(-) diff --git a/src/itzi/data_containers.py b/src/itzi/data_containers.py index f29f8ec..f479c18 100644 --- a/src/itzi/data_containers.py +++ b/src/itzi/data_containers.py @@ -20,6 +20,7 @@ import numpy as np from pydantic import BaseModel, ConfigDict +from pydantic import PositiveFloat, NonNegativeFloat, NonNegativeInt, Field from itzi.const import DefaultValues, TemporalType, InfiltrationModelType import itzi.messenger as msgr @@ -150,8 +151,8 @@ class SimulationData(BaseModel): continuity_data: ContinuityData | None # Made optional for use in tests raw_arrays: Dict[str, np.ndarray] accumulation_arrays: Dict[str, np.ndarray] - cell_dx: float # cell size in east-west direction - cell_dy: float # cell size in north-south direction + cell_dx: PositiveFloat # cell size in east-west direction + cell_dy: PositiveFloat # cell size in north-south direction drainage_network_data: DrainageNetworkData | None @@ -162,7 +163,7 @@ class MassBalanceData(BaseModel): simulation_time: datetime | timedelta average_timestep: float - timesteps: int + timesteps: NonNegativeInt boundary_volume: float rainfall_volume: float infiltration_volume: float @@ -180,15 +181,15 @@ class SurfaceFlowParameters(BaseModel): model_config = ConfigDict(frozen=True) - hmin: float = DefaultValues.HFMIN - cfl: float = DefaultValues.CFL - theta: float = DefaultValues.THETA - g: float = DefaultValues.G - vrouting: float = DefaultValues.VROUTING - dtmax: float = DefaultValues.DTMAX - slope_threshold: float = DefaultValues.SLOPE_THRESHOLD - max_slope: float = DefaultValues.MAX_SLOPE - max_error: float = DefaultValues.MAX_ERROR + hmin: NonNegativeFloat = DefaultValues.HFMIN + cfl: PositiveFloat = Field(DefaultValues.CFL, ge=0.01, le=1) + theta: NonNegativeFloat = Field(DefaultValues.THETA, ge=0, le=1) + g: NonNegativeFloat = DefaultValues.G + vrouting: NonNegativeFloat = DefaultValues.VROUTING + dtmax: PositiveFloat = DefaultValues.DTMAX + slope_threshold: NonNegativeFloat = DefaultValues.SLOPE_THRESHOLD + max_slope: NonNegativeFloat = DefaultValues.MAX_SLOPE + max_error: PositiveFloat = DefaultValues.MAX_ERROR class SimulationConfig(BaseModel): @@ -209,14 +210,14 @@ class SimulationConfig(BaseModel): # Mass balance file stats_file: str | Path | None = None # Hydrology parameters - dtinf: float = DefaultValues.DTINF + dtinf: PositiveFloat = DefaultValues.DTINF infiltration_model: InfiltrationModelType = InfiltrationModelType.NULL # Drainage parameters swmm_inp: str | None = None drainage_output: str | None = None - orifice_coeff: float = DefaultValues.ORIFICE_COEFF - free_weir_coeff: float = DefaultValues.FREE_WEIR_COEFF - submerged_weir_coeff: float = DefaultValues.SUBMERGED_WEIR_COEFF + orifice_coeff: NonNegativeFloat = Field(DefaultValues.ORIFICE_COEFF, ge=0, le=1) + free_weir_coeff: NonNegativeFloat = Field(DefaultValues.FREE_WEIR_COEFF, ge=0, le=1) + submerged_weir_coeff: NonNegativeFloat = Field(DefaultValues.SUBMERGED_WEIR_COEFF, ge=0, le=1) def as_str_dict(self) -> Dict: """Convert the configuration to a dictionary with string representations.""" diff --git a/src/itzi/simulation.py b/src/itzi/simulation.py index 63341f0..38ad899 100644 --- a/src/itzi/simulation.py +++ b/src/itzi/simulation.py @@ -569,3 +569,27 @@ def restore_state(self, simulation_state: HotstartSimulationState) -> Self: self.nextstep = min(self.next_ts.values()) return self + + def reconcile_hotstart_resume(self, hotstart_config: SimulationConfig) -> Self: + """Apply resume-time config changes allowed after hotstart restoration.""" + schedule_changed = False + + if self.end_time != hotstart_config.end_time: + self.next_ts["end"] = self.end_time + if not self.drainage_model: + self.next_ts["drainage"] = self.end_time + schedule_changed = True + + if self.report.dt != hotstart_config.record_step: + self.next_ts["record"] = min(self.end_time, self.sim_time + self.report.dt) + self.report.last_step = copy.copy(self.sim_time) + schedule_changed = True + + if self.hydrology_model.dt != timedelta(seconds=hotstart_config.dtinf): + self.next_ts["hydrology"] = min(self.end_time, self.sim_time + self.hydrology_model.dt) + schedule_changed = True + + if schedule_changed: + self.nextstep = min(self.next_ts.values()) + + return self diff --git a/src/itzi/simulation_builder.py b/src/itzi/simulation_builder.py index ab0a70b..7f7f06d 100644 --- a/src/itzi/simulation_builder.py +++ b/src/itzi/simulation_builder.py @@ -15,7 +15,7 @@ from __future__ import annotations import tempfile -from datetime import timedelta +from datetime import datetime, timedelta from typing import TYPE_CHECKING import io from pathlib import Path @@ -43,12 +43,24 @@ if TYPE_CHECKING: from itzi.providers.domain_data import DomainData from itzi.providers.base import RasterInputProvider, RasterOutputProvider, VectorOutputProvider - from itzi.data_containers import SimulationConfig + from itzi.data_containers import ( + SimulationConfig, + HotstartSimulationState, + SurfaceFlowParameters, + ) class SimulationBuilder: """Builder for creating Simulation objects with different provider configurations.""" + _ALLOWED_RESUME_SURFACE_FLOW_CHANGES = { + "cfl", + "theta", + "dtmax", + "slope_threshold", + "max_slope", + } + def __init__( self, sim_config: SimulationConfig, @@ -127,6 +139,7 @@ def _validate_hotstart_congruence(self) -> None: hotstart_domain = self.hotstart_loader.get_domain_data() hotstart_config = self.hotstart_loader.get_simulation_config() + hotstart_state = self.hotstart_loader.get_simulation_state() # Validate domain metadata self._validate_domain_congruence(hotstart_domain) @@ -137,6 +150,63 @@ def _validate_hotstart_congruence(self) -> None: # Validate drainage expectations self._validate_drainage_congruence(hotstart_config) + # Validate resume-time configuration compatibility + self._validate_resume_config_congruence(hotstart_config, hotstart_state) + + def _validate_resume_config_congruence( + self, + hotstart_config: SimulationConfig, + hotstart_state: HotstartSimulationState, + ) -> None: + """Validate which runtime settings may change across a hotstart resume.""" + hotstart_sim_time = datetime.fromisoformat(hotstart_state.sim_time) + + # Keep this defensive check here even though SimulationConfig also validates + # user input: model_copy(update=...) can bypass Pydantic validation in tests + # and internal resume flows. + if self.sim_config.record_step <= timedelta(0): + raise HotstartError( + f"Resume record_step must be positive, not {self.sim_config.record_step}" + ) + + if ( + self.sim_config.end_time != hotstart_config.end_time + and self.sim_config.end_time <= hotstart_sim_time + ): + raise HotstartError( + "Resume end_time must be strictly after the hotstart simulation time: " + f"end_time={self.sim_config.end_time}, hotstart_sim_time={hotstart_sim_time}" + ) + + if self.sim_config.infiltration_model != hotstart_config.infiltration_model: + raise HotstartError( + "Hotstart infiltration model mismatch: " + f"current={self.sim_config.infiltration_model}, " + f"hotstart={hotstart_config.infiltration_model}" + ) + + self._validate_surface_flow_parameter_congruence(hotstart_config.surface_flow_parameters) + + def _validate_surface_flow_parameter_congruence( + self, + hotstart_surface_flow_parameters: SurfaceFlowParameters, + ) -> None: + """Validate the subset of surface-flow parameters that must not change.""" + current_surface_flow_parameters = self.sim_config.surface_flow_parameters + + for field_name in type(current_surface_flow_parameters).model_fields: + if field_name in self._ALLOWED_RESUME_SURFACE_FLOW_CHANGES: + continue + + current_value = getattr(current_surface_flow_parameters, field_name) + hotstart_value = getattr(hotstart_surface_flow_parameters, field_name) + + if not np.isclose(current_value, hotstart_value): + raise HotstartError( + "Surface flow parameter mismatch for " + f"{field_name}: current={current_value}, hotstart={hotstart_value}" + ) + def _validate_domain_congruence(self, hotstart_domain: DomainData) -> None: """Validate that domain metadata matches between hotstart and builder.""" assert self.domain_data is not None # Already validated in build() @@ -315,6 +385,7 @@ def build(self) -> Simulation: # Restore simulation runtime/scheduler state simulation_state = self.hotstart_loader.get_simulation_state() simulation.restore_state(simulation_state) + simulation.reconcile_hotstart_resume(self.hotstart_loader.get_simulation_config()) # Restore drainage coupling state from the saved n_drain raster. # DrainageNode.coupling_flow is always 0 after object creation, and diff --git a/tests/core/test_hotstart_integration.py b/tests/core/test_hotstart_integration.py index aae67c1..3db003f 100644 --- a/tests/core/test_hotstart_integration.py +++ b/tests/core/test_hotstart_integration.py @@ -119,21 +119,39 @@ def build_simulation( return simulation -def run_to_split_and_create_hotstart( - sim_config: SimulationConfig, - domain_5by5, - split_time: datetime, - raster_output_provider=None, -) -> tuple[Simulation, bytes]: - simulation = build_simulation( - sim_config, - domain_5by5, - raster_output_provider=raster_output_provider, - ) +def capture_simulation_snapshot(simulation: Simulation) -> dict: + return { + "sim_time": simulation.sim_time, + "dt": simulation.dt, + "time_steps_counters": dict(simulation.time_steps_counters), + "old_domain_volume": simulation.old_domain_volume, + "next_ts": {key: value for key, value in simulation.next_ts.items()}, + "accum_update_time": {key: value for key, value in simulation.accum_update_time.items()}, + "raster_state": { + key: simulation.raster_domain.get_array(key).copy() + for key in simulation.raster_domain.k_all + }, + } + + +def run_with_hotstart_checkpoints( + simulation: Simulation, + checkpoints: list[tuple[str, datetime]], +) -> dict[str, dict]: + pending_checkpoints = sorted(checkpoints, key=lambda item: item[1]) + captured: dict[str, dict] = {} + simulation.initialize() - while simulation.sim_time < split_time: + while simulation.sim_time < simulation.end_time: simulation.update() - return simulation, simulation.create_hotstart().getvalue() + while pending_checkpoints and simulation.sim_time >= pending_checkpoints[0][1]: + name, _ = pending_checkpoints.pop(0) + captured[name] = { + "snapshot": capture_simulation_snapshot(simulation), + "hotstart_bytes": simulation.create_hotstart().getvalue(), + } + simulation.finalize() + return captured def assert_final_state_matches(simulation: Simulation, reference: Simulation) -> None: @@ -143,6 +161,17 @@ def assert_final_state_matches(simulation: Simulation, reference: Simulation) -> np.testing.assert_allclose(arr_resumed, arr_reference, err_msg=f"Final {key} mismatch") +def assert_state_differs(simulation: Simulation, reference: Simulation) -> None: + mismatch_found = False + for key in ["water_depth", "qe", "qs"]: + arr_resumed = simulation.raster_domain.get_array(key) + arr_reference = reference.raster_domain.get_array(key) + if not np.allclose(arr_resumed, arr_reference): + mismatch_found = True + break + assert mismatch_found, "Expected resumed state to differ from the archived-cadence reference" + + def get_unique_record_times( output_provider: MemoryRasterOutputProvider, output_key: str = "water_depth", @@ -161,11 +190,13 @@ def base_time() -> datetime: @pytest.fixture(scope="module") -def uninterrupted_simulation(domain_5by5, helpers, base_time) -> Simulation: - """Run the full simulation once for reference. +def uninterrupted_simulation(baseline_hotstart_run) -> Simulation: + """Reuse the shared 60-second baseline run as the final-state reference.""" + return baseline_hotstart_run["simulation"] - This simulation runs from t=0 to t=60s without interruption. - """ + +@pytest.fixture(scope="module") +def baseline_hotstart_run(domain_5by5, helpers, base_time) -> dict: end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) sim_config = create_sim_config( start_time=base_time, @@ -173,8 +204,51 @@ def uninterrupted_simulation(domain_5by5, helpers, base_time) -> Simulation: helpers=helpers, ) simulation = build_simulation(sim_config, domain_5by5) + checkpoints = run_with_hotstart_checkpoints( + simulation, + [ + ("split_10", base_time + timedelta(seconds=10)), + ("split_30", base_time + timedelta(seconds=30)), + ("split_50", base_time + timedelta(seconds=50)), + ], + ) + return { + "sim_config": sim_config, + "simulation": simulation, + "checkpoints": checkpoints, + } + + +@pytest.fixture(scope="module") +def extended_reference_simulation(domain_5by5, helpers, base_time) -> dict: + sim_config = create_sim_config( + start_time=base_time, + end_time=base_time + timedelta(seconds=90), + helpers=helpers, + ) + simulation = build_simulation(sim_config, domain_5by5) run_simulation_to_end(simulation) - return simulation + return {"sim_config": sim_config, "simulation": simulation} + + +@pytest.fixture(scope="module") +def record_step_hotstart_run(domain_5by5, helpers, base_time) -> dict: + end_time = base_time + timedelta(seconds=70) + sim_config = create_sim_config( + start_time=base_time, + end_time=end_time, + helpers=helpers, + ) + simulation = build_simulation(sim_config, domain_5by5) + checkpoints = run_with_hotstart_checkpoints( + simulation, + [("split_34_2", base_time + timedelta(seconds=34.2))], + ) + return { + "sim_config": sim_config, + "simulation": simulation, + "checkpoints": checkpoints, + } # @pytest.mark.skip(reason="Hotstart final state comparison fails due to final results diverging " @@ -184,6 +258,7 @@ def test_roundtrip_state_restoration_and_match( domain_5by5, helpers, base_time, + baseline_hotstart_run, uninterrupted_simulation: Simulation, split_seconds: int, ) -> None: @@ -195,33 +270,11 @@ def test_roundtrip_state_restoration_and_match( 3. Runs resumed simulation to end 4. Verifies final results match the uninterrupted reference """ - split_time = base_time + timedelta(seconds=split_seconds) end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) - # Step 1: Run simulation A to split point and create hotstart - sim_a_config = create_sim_config( - start_time=base_time, - end_time=end_time, - helpers=helpers, - ) - sim_a, hotstart_bytes = run_to_split_and_create_hotstart( - sim_a_config, - domain_5by5, - split_time, - ) - - # Save state for comparison - saved_sim_time = sim_a.sim_time - saved_dt = sim_a.dt - saved_counters = dict(sim_a.time_steps_counters) - saved_old_domain_volume = sim_a.old_domain_volume - saved_next_ts = {k: v for k, v in sim_a.next_ts.items()} - saved_accum_update_time = {k: v for k, v in sim_a.accum_update_time.items()} - - # Save raster state for comparison - saved_raster_state = {} - for key in sim_a.raster_domain.k_all: - saved_raster_state[key] = sim_a.raster_domain.get_array(key).copy() + checkpoint = baseline_hotstart_run["checkpoints"][f"split_{split_seconds}"] + saved_snapshot = checkpoint["snapshot"] + hotstart_bytes = checkpoint["hotstart_bytes"] # Step 2: Create simulation B with hotstart and verify state restoration sim_b_config = create_sim_config( @@ -232,9 +285,8 @@ def test_roundtrip_state_restoration_and_match( sim_b = build_simulation(sim_b_config, domain_5by5, hotstart_bytes=hotstart_bytes) # Verify raster state was restored - for key in saved_raster_state: + for key, arr_saved in saved_snapshot["raster_state"].items(): arr_restored = sim_b.raster_domain.get_array(key) - arr_saved = saved_raster_state[key] np.testing.assert_allclose( arr_restored, arr_saved, @@ -244,30 +296,33 @@ def test_roundtrip_state_restoration_and_match( ) # Verify scheduler state was restored - assert sim_b.sim_time == saved_sim_time, ( - f"sim_time not restored: {sim_b.sim_time} != {saved_sim_time}" + assert sim_b.sim_time == saved_snapshot["sim_time"], ( + f"sim_time not restored: {sim_b.sim_time} != {saved_snapshot['sim_time']}" + ) + assert sim_b.dt == saved_snapshot["dt"], ( + f"dt not restored: {sim_b.dt} != {saved_snapshot['dt']}" ) - assert sim_b.dt == saved_dt, f"dt not restored: {sim_b.dt} != {saved_dt}" - assert sim_b.time_steps_counters == saved_counters, ( - f"time_steps_counters not restored: {sim_b.time_steps_counters} != {saved_counters}" + assert sim_b.time_steps_counters == saved_snapshot["time_steps_counters"], ( + "time_steps_counters not restored: " + f"{sim_b.time_steps_counters} != {saved_snapshot['time_steps_counters']}" ) - assert np.isclose(sim_b.old_domain_volume, saved_old_domain_volume), ( - f"old_domain_volume not restored: {sim_b.old_domain_volume} != {saved_old_domain_volume}" + assert np.isclose(sim_b.old_domain_volume, saved_snapshot["old_domain_volume"]), ( + "old_domain_volume not restored: " + f"{sim_b.old_domain_volume} != {saved_snapshot['old_domain_volume']}" ) # Verify next_ts schedule was restored - for key in saved_next_ts: + for key, value in saved_snapshot["next_ts"].items(): assert key in sim_b.next_ts, f"next_ts key '{key}' missing in restored simulation" - assert sim_b.next_ts[key] == saved_next_ts[key], ( - f"next_ts[{key}] not restored: {sim_b.next_ts[key]} != {saved_next_ts[key]}" + assert sim_b.next_ts[key] == value, ( + f"next_ts[{key}] not restored: {sim_b.next_ts[key]} != {value}" ) # Verify accum_update_time was restored - for key in saved_accum_update_time: + for key, value in saved_snapshot["accum_update_time"].items(): assert key in sim_b.accum_update_time, f"accum_update_time key '{key}' missing" - assert sim_b.accum_update_time[key] == saved_accum_update_time[key], ( - f"accum_update_time[{key}] not restored: " - f"{sim_b.accum_update_time[key]} != {saved_accum_update_time[key]}" + assert sim_b.accum_update_time[key] == value, ( + f"accum_update_time[{key}] not restored: {sim_b.accum_update_time[key]} != {value}" ) # Verify nextstep is consistent with next_ts (scheduler invariant) @@ -293,25 +348,11 @@ def test_roundtrip_state_restoration_and_match( def test_resume_allows_output_provider_change( domain_5by5, - helpers, - base_time, + baseline_hotstart_run, uninterrupted_simulation: Simulation, ) -> None: - split_time = base_time + timedelta(seconds=30) - end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) - - sim_config = create_sim_config( - start_time=base_time, - end_time=end_time, - helpers=helpers, - ) - initial_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) - _, hotstart_bytes = run_to_split_and_create_hotstart( - sim_config, - domain_5by5, - split_time, - raster_output_provider=initial_output, - ) + sim_config = baseline_hotstart_run["sim_config"] + hotstart_bytes = baseline_hotstart_run["checkpoints"]["split_30"]["hotstart_bytes"] resumed_output = MemoryRasterOutputProvider({"out_map_names": sim_config.output_map_names}) sim_b = build_simulation( @@ -323,28 +364,20 @@ def test_resume_allows_output_provider_change( run_simulation_to_end(sim_b, skip_initialize=True) - assert initial_output is not resumed_output assert resumed_output.output_maps_dict["water_depth"] assert len(resumed_output.output_maps_dict["water_depth"]) >= 1 assert_final_state_matches(sim_b, uninterrupted_simulation) - assert sim_b.sim_time == end_time + assert sim_b.sim_time == sim_config.end_time def test_resume_allows_output_map_name_change( domain_5by5, helpers, - base_time, + baseline_hotstart_run, uninterrupted_simulation: Simulation, ) -> None: - split_time = base_time + timedelta(seconds=30) - end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) - - sim_a_config = create_sim_config( - start_time=base_time, - end_time=end_time, - helpers=helpers, - ) - _, hotstart_bytes = run_to_split_and_create_hotstart(sim_a_config, domain_5by5, split_time) + sim_a_config = baseline_hotstart_run["sim_config"] + hotstart_bytes = baseline_hotstart_run["checkpoints"]["split_30"]["hotstart_bytes"] resumed_output_map_names = helpers.make_output_map_names( "out_resume", @@ -368,46 +401,33 @@ def test_resume_allows_output_map_name_change( assert_final_state_matches(sim_b, uninterrupted_simulation) -def test_resume_allows_end_time_extension(domain_5by5, helpers, base_time) -> None: - split_time = base_time + timedelta(seconds=30) - original_end_time = base_time + timedelta(seconds=TOTAL_DURATION_SECONDS) +def test_resume_allows_end_time_extension( + domain_5by5, + base_time, + baseline_hotstart_run, + extended_reference_simulation, +) -> None: extended_end_time = base_time + timedelta(seconds=90) - sim_a_config = create_sim_config( - start_time=base_time, - end_time=original_end_time, - helpers=helpers, - ) - _, hotstart_bytes = run_to_split_and_create_hotstart(sim_a_config, domain_5by5, split_time) - - sim_b_config = create_sim_config( - start_time=base_time, - end_time=extended_end_time, - helpers=helpers, - ) + hotstart_bytes = baseline_hotstart_run["checkpoints"]["split_30"]["hotstart_bytes"] + sim_b_config = extended_reference_simulation["sim_config"] sim_b = build_simulation(sim_b_config, domain_5by5, hotstart_bytes=hotstart_bytes) - reference = build_simulation(sim_b_config, domain_5by5) - run_simulation_to_end(reference) run_simulation_to_end(sim_b, skip_initialize=True) assert sim_b.sim_time == extended_end_time - assert_final_state_matches(sim_b, reference) + assert_final_state_matches(sim_b, extended_reference_simulation["simulation"]) -def test_resume_applies_new_record_step_cadence(domain_5by5, helpers, base_time) -> None: - split_time = base_time + timedelta(seconds=34.2) +def test_resume_applies_new_record_step_cadence( + domain_5by5, + base_time, + record_step_hotstart_run, +) -> None: end_time = base_time + timedelta(seconds=70) - original_config = create_sim_config( - start_time=base_time, - end_time=end_time, - helpers=helpers, - ) - sim_a, hotstart_bytes = run_to_split_and_create_hotstart( - original_config, - domain_5by5, - split_time, - ) + original_config = record_step_hotstart_run["sim_config"] + checkpoint = record_step_hotstart_run["checkpoints"]["split_34_2"] + hotstart_bytes = checkpoint["hotstart_bytes"] resumed_record_step = timedelta(seconds=10) resumed_config = original_config.model_copy(update={"record_step": resumed_record_step}) @@ -422,7 +442,7 @@ def test_resume_applies_new_record_step_cadence(domain_5by5, helpers, base_time) run_simulation_to_end(sim_b, skip_initialize=True) resumed_record_times = get_unique_record_times(resumed_output) - resumed_offset = sim_a.sim_time - base_time + resumed_offset = checkpoint["snapshot"]["sim_time"] - base_time expected_record_times = [ resumed_offset + resumed_record_step, resumed_offset + (2 * resumed_record_step), @@ -434,6 +454,20 @@ def test_resume_applies_new_record_step_cadence(domain_5by5, helpers, base_time) assert resumed_record_times[: len(expected_record_times)] == expected_record_times - reference = build_simulation(original_config, domain_5by5) - run_simulation_to_end(reference) - assert_final_state_matches(sim_b, reference) + assert_state_differs(sim_b, record_step_hotstart_run["simulation"]) + + +def test_resume_applies_new_dtinf_to_hydrology_schedule( + domain_5by5, + base_time, + baseline_hotstart_run, +) -> None: + hotstart_bytes = baseline_hotstart_run["checkpoints"]["split_30"]["hotstart_bytes"] + original_config = baseline_hotstart_run["sim_config"] + resumed_dtinf = 5.0 + resumed_config = original_config.model_copy(update={"dtinf": resumed_dtinf}) + + sim_b = build_simulation(resumed_config, domain_5by5, hotstart_bytes=hotstart_bytes) + + assert sim_b.hydrology_model.dt == timedelta(seconds=resumed_dtinf) + assert sim_b.next_ts["hydrology"] == sim_b.sim_time + timedelta(seconds=resumed_dtinf) diff --git a/tests/core/test_hotstart_state_loading.py b/tests/core/test_hotstart_state_loading.py index 46b2dfc..bb56418 100644 --- a/tests/core/test_hotstart_state_loading.py +++ b/tests/core/test_hotstart_state_loading.py @@ -325,10 +325,115 @@ def test_build_rejects_infiltration_model_mismatch( with pytest.raises(HotstartError, match="infiltration"): self._build_with_hotstart(domain_5by5, resumed_config, hotstart_bytes) + def test_build_allows_record_step_change( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + ) -> None: + """build() should allow changing the output cadence on resume.""" + resumed_record_step = timedelta(seconds=10) + resumed_config = sim_config.model_copy(update={"record_step": resumed_record_step}) + + simulation = self._build_with_hotstart(domain_5by5, resumed_config, valid_hotstart_bytes) + + assert simulation.report.dt == resumed_record_step + + def test_build_rejects_end_time_not_after_hotstart_time( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + ) -> None: + """build() should reject resumed end times that are not after the checkpoint time.""" + resumed_config = sim_config.model_copy(update={"end_time": sim_config.start_time}) + + with pytest.raises(HotstartError, match="end_time"): + self._build_with_hotstart(domain_5by5, resumed_config, valid_hotstart_bytes) + + def test_build_allows_end_time_change_after_hotstart_time( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + ) -> None: + """build() should allow resumed end times after the checkpoint time.""" + resumed_end_time = sim_config.start_time + timedelta(seconds=45) + resumed_config = sim_config.model_copy(update={"end_time": resumed_end_time}) + + simulation = self._build_with_hotstart(domain_5by5, resumed_config, valid_hotstart_bytes) + + assert simulation.end_time == resumed_end_time + + @pytest.mark.parametrize("record_step", [timedelta(0), timedelta(seconds=-1)]) + def test_build_rejects_non_positive_record_step( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + record_step: timedelta, + ) -> None: + """build() should reject resumed record steps that are not positive. + + model_copy(update=...) bypasses Pydantic validation, so the builder keeps + a defensive runtime check for this invariant. + """ + resumed_config = sim_config.model_copy(update={"record_step": record_step}) + + with pytest.raises(HotstartError, match="record_step"): + self._build_with_hotstart(domain_5by5, resumed_config, valid_hotstart_bytes) + + def test_build_allows_dtinf_change( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + ) -> None: + """build() should allow changing dtinf on resume.""" + resumed_dtinf = 5.0 + resumed_config = sim_config.model_copy(update={"dtinf": resumed_dtinf}) + + simulation = self._build_with_hotstart(domain_5by5, resumed_config, valid_hotstart_bytes) + + assert simulation.hydrology_model.dt == timedelta(seconds=resumed_dtinf) + + @pytest.mark.parametrize( + ("parameter", "value"), + [ + ("cfl", 0.15), + ("theta", 0.9), + ("dtmax", 0.2), + ("slope_threshold", 1e-5), + ("max_slope", 5.0), + ], + ) + def test_build_allows_surface_flow_resume_parameter_change( + self, + domain_5by5, + sim_config: SimulationConfig, + valid_hotstart_bytes: io.BytesIO, + parameter: str, + value: float, + ) -> None: + """build() should allow selected surface-flow tuning changes on resume.""" + resumed_surface_flow_parameters = sim_config.surface_flow_parameters.model_copy( + update={parameter: value} + ) + resumed_config = sim_config.model_copy( + update={"surface_flow_parameters": resumed_surface_flow_parameters} + ) + + simulation = self._build_with_hotstart(domain_5by5, resumed_config, valid_hotstart_bytes) + + assert getattr(simulation.surface_flow, parameter) == value + @pytest.mark.parametrize( ("parameter", "value"), [ ("hmin", 0.0002), + ("g", 9.9), + ("vrouting", 2.0), + ("max_error", 0.5), ], ) def test_build_rejects_surface_flow_parameter_mismatch(