diff --git a/a3fe/_version.py b/a3fe/_version.py index 3d26edf..df12433 100644 --- a/a3fe/_version.py +++ b/a3fe/_version.py @@ -1 +1 @@ -__version__ = "0.4.1" +__version__ = "0.4.2" diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index 90d01a2..fc979c9 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -117,9 +117,13 @@ class SomdConfig(_EngineConfig): ### Integrator - ncycles modified as required by a3fe ### timestep: float = _Field(4.0, description="Timestep in femtoseconds(fs)") + nmoves: int = _Field( + 25000, + description="Number of moves per cycle. Default 25000 provides optimal checkpoint frequency.", + ) runtime: _Union[int, float] = _Field( 5.0, - description="Runtime in nanoseconds(ns), and must be a multiple of timestep", + description="Runtime in nanoseconds(ns), must be a multiple of timestep and ncycles will be calculated from runtime and nmoves", ) ### Constraints ### @@ -225,29 +229,37 @@ class SomdConfig(_EngineConfig): ) @property - def nmoves(self) -> int: + def ncycles(self) -> int: """ - Make sure runtime is a multiple of timestep + Calculate number of cycles from runtime, nmoves and timestep. + Formula: runtime = nmoves × ncycles × timestep """ # Convert runtime to femtoseconds (ns -> fs) runtime_fs = _Decimal(str(self.runtime)) * _Decimal("1_000_000") timestep = _Decimal(str(self.timestep)) + nmoves_dec = _Decimal(str(self.nmoves)) # Check if runtime is a multiple of timestep remainder = runtime_fs % timestep if round(float(remainder), 4) != 0: raise ValueError( - ( - "Runtime must be a multiple of the timestep. " - f"Runtime is {self.runtime} ns ({runtime_fs} fs), " - f"and timestep is {self.timestep} fs." - ) + f"Runtime must be a multiple of timestep. " + f"Runtime is {self.runtime} ns ({runtime_fs} fs), " + f"timestep is {self.timestep} fs." ) - # Calculate the number of moves - nmoves = round(float(runtime_fs) / float(timestep)) + # Calculate ncycles + total_steps = runtime_fs / timestep + ncycles = total_steps / nmoves_dec + ncycles_int = round(float(ncycles)) + + if ncycles_int < 1: + raise ValueError( + f"Runtime {self.runtime} ns is too short for nmoves={self.nmoves}. " + f"Decrease nmoves or increase runtime." + ) - return nmoves + return ncycles_int @_model_validator(mode="after") def _check_rf_dielectric(self): @@ -336,6 +348,7 @@ def write_config( config_lines = [ "### Integrator ###", f"timestep = {self.timestep} * femtosecond", + f"ncycles = {self.ncycles}", f"nmoves = {self.nmoves}", f"constraint = {self.constraint}", f"hydrogen mass repartitioning factor = {self.hydrogen_mass_factor}", diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py index 705634d..5a0093a 100644 --- a/a3fe/tests/test_engine_configuration.py +++ b/a3fe/tests/test_engine_configuration.py @@ -106,6 +106,44 @@ def test_charge_cutoff_validation(engine_config, charge, cutoff, should_pass): engine_config(ligand_charge=charge, cutoff_type=cutoff, runtime=1) +@pytest.mark.parametrize( + "runtime,nmoves,timestep,calculated_ncycles", + [ + (5.0, 25000, 4.0, 50), + (10.0, 50000, 2.0, 100), + ], +) +def test_ncycles_calculation( + somd_engine_config, runtime, nmoves, timestep, calculated_ncycles +): + """Test that ncycles is correctly calculated from runtime, nmoves and timestep.""" + config = somd_engine_config(runtime=runtime, nmoves=nmoves, timestep=timestep) + assert config.ncycles == calculated_ncycles + + +def test_ncycles_invalid_runtime(somd_engine_config): + """Test that ValueError is raised when runtime is not a multiple of timestep.""" + config = somd_engine_config(runtime=5.0, nmoves=25000, timestep=3.0) + with pytest.raises(ValueError, match="Runtime must be a multiple of timestep"): + config.ncycles + + +def test_ncycles_too_short(somd_engine_config): + """Test that ValueError is raised when runtime is too short.""" + config = somd_engine_config(runtime=0.01, nmoves=25000, timestep=4.0) + with pytest.raises(ValueError, match="too short"): + config.ncycles + + +def test_ncycles_updates_on_runtime_change(somd_engine_config): + """Test that ncycles updates when runtime is changed (SSOT consistency).""" + config = somd_engine_config(runtime=5.0, nmoves=25000, timestep=4.0) + assert config.ncycles == 50 + + config.runtime = 10.0 + assert config.ncycles == 100 + + def test_ligand_charge_validation(engine_config): """Test that ligand charge validation works correctly.""" diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst index da42034..d627021 100644 --- a/docs/CHANGELOG.rst +++ b/docs/CHANGELOG.rst @@ -2,6 +2,10 @@ Change Log =============== +0.4.2 +==================== +- Added nmoves as a configurable field and changed ncycles to a computed property in engine_config.py to prevent memory overflow from single-cycle runtimes. + 0.4.1 ==================== - Fixed the statistical inefficiency timestep units from femtoseconds to nanoseconds.