Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/itzi/data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
from pydantic import BaseModel, ConfigDict
from pydantic import PositiveFloat, NonNegativeFloat, NonNegativeInt, Field

from itzi.const import DefaultValues, TemporalType, InfiltrationModelType
import itzi.messenger as msgr
Expand Down Expand Up @@ -150,8 +151,8 @@ class SimulationData(BaseModel):
continuity_data: ContinuityData | None # Made optional for use in tests
raw_arrays: Dict[str, np.ndarray]
accumulation_arrays: Dict[str, np.ndarray]
cell_dx: float # cell size in east-west direction
cell_dy: float # cell size in north-south direction
cell_dx: PositiveFloat # cell size in east-west direction
cell_dy: PositiveFloat # cell size in north-south direction
drainage_network_data: DrainageNetworkData | None


Expand All @@ -162,7 +163,7 @@ class MassBalanceData(BaseModel):

simulation_time: datetime | timedelta
average_timestep: float
timesteps: int
timesteps: NonNegativeInt
boundary_volume: float
rainfall_volume: float
infiltration_volume: float
Expand All @@ -180,15 +181,15 @@ class SurfaceFlowParameters(BaseModel):

model_config = ConfigDict(frozen=True)

hmin: float = DefaultValues.HFMIN
cfl: float = DefaultValues.CFL
theta: float = DefaultValues.THETA
g: float = DefaultValues.G
vrouting: float = DefaultValues.VROUTING
dtmax: float = DefaultValues.DTMAX
slope_threshold: float = DefaultValues.SLOPE_THRESHOLD
max_slope: float = DefaultValues.MAX_SLOPE
max_error: float = DefaultValues.MAX_ERROR
hmin: NonNegativeFloat = DefaultValues.HFMIN
cfl: PositiveFloat = Field(DefaultValues.CFL, ge=0.01, le=1)
theta: NonNegativeFloat = Field(DefaultValues.THETA, ge=0, le=1)
g: NonNegativeFloat = DefaultValues.G
vrouting: NonNegativeFloat = DefaultValues.VROUTING
dtmax: PositiveFloat = DefaultValues.DTMAX
slope_threshold: NonNegativeFloat = DefaultValues.SLOPE_THRESHOLD
max_slope: NonNegativeFloat = DefaultValues.MAX_SLOPE
max_error: PositiveFloat = DefaultValues.MAX_ERROR


class SimulationConfig(BaseModel):
Expand All @@ -209,14 +210,14 @@ class SimulationConfig(BaseModel):
# Mass balance file
stats_file: str | Path | None = None
# Hydrology parameters
dtinf: float = DefaultValues.DTINF
dtinf: PositiveFloat = DefaultValues.DTINF
infiltration_model: InfiltrationModelType = InfiltrationModelType.NULL
# Drainage parameters
swmm_inp: str | None = None
drainage_output: str | None = None
orifice_coeff: float = DefaultValues.ORIFICE_COEFF
free_weir_coeff: float = DefaultValues.FREE_WEIR_COEFF
submerged_weir_coeff: float = DefaultValues.SUBMERGED_WEIR_COEFF
orifice_coeff: NonNegativeFloat = Field(DefaultValues.ORIFICE_COEFF, ge=0, le=1)
free_weir_coeff: NonNegativeFloat = Field(DefaultValues.FREE_WEIR_COEFF, ge=0, le=1)
submerged_weir_coeff: NonNegativeFloat = Field(DefaultValues.SUBMERGED_WEIR_COEFF, ge=0, le=1)

def as_str_dict(self) -> Dict:
"""Convert the configuration to a dictionary with string representations."""
Expand Down
24 changes: 24 additions & 0 deletions src/itzi/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,27 @@ def restore_state(self, simulation_state: HotstartSimulationState) -> Self:
self.nextstep = min(self.next_ts.values())

return self

def reconcile_hotstart_resume(self, hotstart_config: SimulationConfig) -> Self:
"""Apply resume-time config changes allowed after hotstart restoration."""
schedule_changed = False

if self.end_time != hotstart_config.end_time:
self.next_ts["end"] = self.end_time
if not self.drainage_model:
self.next_ts["drainage"] = self.end_time
schedule_changed = True

if self.report.dt != hotstart_config.record_step:
self.next_ts["record"] = min(self.end_time, self.sim_time + self.report.dt)
self.report.last_step = copy.copy(self.sim_time)
schedule_changed = True

if self.hydrology_model.dt != timedelta(seconds=hotstart_config.dtinf):
self.next_ts["hydrology"] = min(self.end_time, self.sim_time + self.hydrology_model.dt)
schedule_changed = True

if schedule_changed:
self.nextstep = min(self.next_ts.values())

return self
75 changes: 73 additions & 2 deletions src/itzi/simulation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import tempfile
from datetime import timedelta
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
import io
from pathlib import Path
Expand Down Expand Up @@ -43,12 +43,24 @@
if TYPE_CHECKING:
from itzi.providers.domain_data import DomainData
from itzi.providers.base import RasterInputProvider, RasterOutputProvider, VectorOutputProvider
from itzi.data_containers import SimulationConfig
from itzi.data_containers import (
SimulationConfig,
HotstartSimulationState,
SurfaceFlowParameters,
)


class SimulationBuilder:
"""Builder for creating Simulation objects with different provider configurations."""

_ALLOWED_RESUME_SURFACE_FLOW_CHANGES = {
"cfl",
"theta",
"dtmax",
"slope_threshold",
"max_slope",
}

def __init__(
self,
sim_config: SimulationConfig,
Expand Down Expand Up @@ -127,6 +139,7 @@ def _validate_hotstart_congruence(self) -> None:

hotstart_domain = self.hotstart_loader.get_domain_data()
hotstart_config = self.hotstart_loader.get_simulation_config()
hotstart_state = self.hotstart_loader.get_simulation_state()

# Validate domain metadata
self._validate_domain_congruence(hotstart_domain)
Expand All @@ -137,6 +150,63 @@ def _validate_hotstart_congruence(self) -> None:
# Validate drainage expectations
self._validate_drainage_congruence(hotstart_config)

# Validate resume-time configuration compatibility
self._validate_resume_config_congruence(hotstart_config, hotstart_state)

def _validate_resume_config_congruence(
self,
hotstart_config: SimulationConfig,
hotstart_state: HotstartSimulationState,
) -> None:
"""Validate which runtime settings may change across a hotstart resume."""
hotstart_sim_time = datetime.fromisoformat(hotstart_state.sim_time)

# Keep this defensive check here even though SimulationConfig also validates
# user input: model_copy(update=...) can bypass Pydantic validation in tests
# and internal resume flows.
if self.sim_config.record_step <= timedelta(0):
raise HotstartError(
f"Resume record_step must be positive, not {self.sim_config.record_step}"
)

if (
self.sim_config.end_time != hotstart_config.end_time
and self.sim_config.end_time <= hotstart_sim_time
):
raise HotstartError(
"Resume end_time must be strictly after the hotstart simulation time: "
f"end_time={self.sim_config.end_time}, hotstart_sim_time={hotstart_sim_time}"
)

if self.sim_config.infiltration_model != hotstart_config.infiltration_model:
raise HotstartError(
"Hotstart infiltration model mismatch: "
f"current={self.sim_config.infiltration_model}, "
f"hotstart={hotstart_config.infiltration_model}"
)

self._validate_surface_flow_parameter_congruence(hotstart_config.surface_flow_parameters)

def _validate_surface_flow_parameter_congruence(
self,
hotstart_surface_flow_parameters: SurfaceFlowParameters,
) -> None:
"""Validate the subset of surface-flow parameters that must not change."""
current_surface_flow_parameters = self.sim_config.surface_flow_parameters

for field_name in type(current_surface_flow_parameters).model_fields:
if field_name in self._ALLOWED_RESUME_SURFACE_FLOW_CHANGES:
continue

current_value = getattr(current_surface_flow_parameters, field_name)
hotstart_value = getattr(hotstart_surface_flow_parameters, field_name)

if not np.isclose(current_value, hotstart_value):
raise HotstartError(
"Surface flow parameter mismatch for "
f"{field_name}: current={current_value}, hotstart={hotstart_value}"
)

def _validate_domain_congruence(self, hotstart_domain: DomainData) -> None:
"""Validate that domain metadata matches between hotstart and builder."""
assert self.domain_data is not None # Already validated in build()
Expand Down Expand Up @@ -315,6 +385,7 @@ def build(self) -> Simulation:
# Restore simulation runtime/scheduler state
simulation_state = self.hotstart_loader.get_simulation_state()
simulation.restore_state(simulation_state)
simulation.reconcile_hotstart_resume(self.hotstart_loader.get_simulation_config())

# Restore drainage coupling state from the saved n_drain raster.
# DrainageNode.coupling_flow is always 0 after object creation, and
Expand Down
Loading
Loading