diff --git a/nam/models/wavenet/_packed_wavenet.py b/nam/models/wavenet/_packed_wavenet.py index b023a035..3dcb26c8 100644 --- a/nam/models/wavenet/_packed_wavenet.py +++ b/nam/models/wavenet/_packed_wavenet.py @@ -160,6 +160,7 @@ def export_container( }, "weights": [], } + self._sync_container_metadata_to_highest_quality_submodel(container) if self.sample_rate is not None: container["sample_rate"] = self.sample_rate if user_metadata is not None: @@ -219,6 +220,19 @@ def _container_max_values(self) -> list[float]: raise ValueError("container_max_values must be sorted") return values + @staticmethod + def _sync_container_metadata_to_highest_quality_submodel(container: _Dict) -> None: + submodels = container["config"]["submodels"] + if len(submodels) == 0: + return + highest_quality = max(submodels, key=lambda submodel: submodel["max_value"]) + highest_quality_metadata = highest_quality["model"].get("metadata") + if not isinstance(highest_quality_metadata, dict): + return + for key in ("loudness", "gain"): + if key in highest_quality_metadata: + container["metadata"][key] = highest_quality_metadata[key] + def _normalize_checkpoint_paths(self, paths): if paths is None: return None diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py index 374f2d12..488298f0 100644 --- a/tests/test_nam/test_data.py +++ b/tests/test_nam/test_data.py @@ -352,7 +352,7 @@ def test_slimmable_container_shifts_container_and_submodels(self): scale = 0.5 container = { "architecture": "SlimmableContainer", - "metadata": {"loudness": -18.0, "gain": 0.4}, + "metadata": {"loudness": -17.0, "gain": 0.5}, "config": { "submodels": [ { @@ -378,14 +378,19 @@ def test_slimmable_container_shifts_container_and_submodels(self): } self._hook(scale).apply(container) offset = 20.0 * math.log10(scale) - assert container["metadata"]["loudness"] == pytest.approx(-18.0 + offset) assert container["config"]["submodels"][0]["model"]["metadata"][ "loudness" ] == pytest.approx(-19.0 + offset) assert container["config"]["submodels"][1]["model"]["metadata"][ "loudness" ] == pytest.approx(-17.0 + offset) - assert container["metadata"]["gain"] == 0.4 + assert container["metadata"]["loudness"] == pytest.approx( + container["config"]["submodels"][1]["model"]["metadata"]["loudness"] + ) + assert ( + container["metadata"]["gain"] + == container["config"]["submodels"][1]["model"]["metadata"]["gain"] + ) def test_no_op_when_loudness_metadata_absent(self): """Hook is robust when called on a dict without loudness metadata.""" diff --git a/tests/test_nam/test_models/test_packed_wavenet.py b/tests/test_nam/test_models/test_packed_wavenet.py index ce4c3b60..df9a0e06 100644 --- a/tests/test_nam/test_models/test_packed_wavenet.py +++ b/tests/test_nam/test_models/test_packed_wavenet.py @@ -8,6 +8,8 @@ from nam.models.wavenet import WaveNet as _WaveNet from nam.models.wavenet._packed_conv import PackedConv1dBase as _PackedConv1dBase +_DEFAULT_HEAD_SCALE = 0.25 + def _wavenet_config(channels: int, *, dilations=None, activation="Tanh"): return { @@ -23,7 +25,7 @@ def _wavenet_config(channels: int, *, dilations=None, activation="Tanh"): } ], "head": None, - "head_scale": 0.25, + "head_scale": _DEFAULT_HEAD_SCALE, } @@ -64,7 +66,7 @@ def _two_array_wavenet_config(channels_0: int, channels_1: int): }, ], "head": None, - "head_scale": 0.25, + "head_scale": _DEFAULT_HEAD_SCALE, } @@ -216,6 +218,11 @@ def test_packed_export_writes_slimmable_container(tmp_path): from_disk = _json.load(fp) assert from_disk == container _assert_container_contains_two_wavenets(container) + highest_quality = max( + container["config"]["submodels"], key=lambda entry: entry["max_value"] + )["model"] + assert container["metadata"]["loudness"] == highest_quality["metadata"]["loudness"] + assert container["metadata"]["gain"] == highest_quality["metadata"]["gain"] def test_packed_export_refreshes_loudness_after_head_scale_compensation(tmp_path): @@ -231,11 +238,14 @@ def test_packed_export_refreshes_loudness_after_head_scale_compensation(tmp_path model = _PackedWaveNet.init_from_config({**_packed_config(), "sample_rate": 48_000}) pre_container = model.export_container(tmp_path) - pre_container_loudness = pre_container["metadata"]["loudness"] pre_submodel_loudnesses = [ entry["model"]["metadata"]["loudness"] for entry in pre_container["config"]["submodels"] ] + pre_highest_quality_loudness = max( + pre_container["config"]["submodels"], + key=lambda entry: entry["max_value"], + )["model"]["metadata"]["loudness"] scale = 2.0 model.export_model_dict_post_hooks.append(_data.Dataset._ScaleOutputHook(scale=scale)) @@ -243,7 +253,7 @@ def test_packed_export_refreshes_loudness_after_head_scale_compensation(tmp_path offset_db = 20.0 * _math.log10(scale) assert post_container["metadata"]["loudness"] == _pytest.approx( - pre_container_loudness + offset_db, abs=1e-3 + pre_highest_quality_loudness + offset_db, abs=1e-3 ) for entry, pre_loudness in zip( post_container["config"]["submodels"], pre_submodel_loudnesses @@ -253,7 +263,7 @@ def test_packed_export_refreshes_loudness_after_head_scale_compensation(tmp_path ) # head_scale was actually compensated on disk assert entry["model"]["config"]["head_scale"] == _pytest.approx( - 0.25 * scale + _DEFAULT_HEAD_SCALE * scale )