diff --git a/docs/source/tutorials/packed-training.rst b/docs/source/tutorials/packed-training.rst index 14ad0493..10225b6d 100644 --- a/docs/source/tutorials/packed-training.rst +++ b/docs/source/tutorials/packed-training.rst @@ -151,9 +151,9 @@ During training, the packed model returns predictions shaped like ``(batch, submodel, time)``. The trainer computes the configured training loss for each submodel prediction against the same target audio and sums the losses. -Validation logs aggregate metrics such as ``val_loss``, ``ESR``, and ``MRSTFT`` -as well as per-submodel metrics such as ``val_loss_packed_0``, ``ESR_packed_0``, -and ``MRSTFT_packed_0``. When +Validation logs mean aggregate metrics such as ``val_loss``, ``ESR``, and +``MRSTFT`` as well as per-submodel metrics such as ``val_loss_packed_0``, +``ESR_packed_0``, and ``MRSTFT_packed_0``. When validation is available, the trainer may also save per-submodel best checkpoints named ``packed_best_submodel_.ckpt``. diff --git a/nam/train/core.py b/nam/train/core.py index 3fb88531..1a218152 100644 --- a/nam/train/core.py +++ b/nam/train/core.py @@ -1034,11 +1034,10 @@ def _plot( for i in range(output.shape[0]) ] esrs = [_esr(_torch.Tensor(output_i), ds.y) for output_i in output] - aggregate_esr = sum(esrs) + validation_esr = esrs[-1] for label, esr in zip(labels, esrs): print(f"Error-signal ratio ({label}) = {esr:.4g}") - print(f"Aggregate error-signal ratio = {aggregate_esr:.4g}") - print(_esr_comment(aggregate_esr)) + print(_esr_comment(validation_esr)) _plt.figure(figsize=(16, 5)) for label, output_i, esr in zip(labels, output, esrs): @@ -1048,8 +1047,7 @@ def _plot( ) _plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") _plt.title( - "Aggregate ESR=" - f"{aggregate_esr:.4g} (" + "ESR (" + ", ".join(f"{label}: {esr:.4g}" for label, esr in zip(labels, esrs)) + ")" ) @@ -1058,7 +1056,7 @@ def _plot( _plt.savefig(filepath + ".png") if not silent: _plt.show() - return aggregate_esr + return validation_esr output = output.flatten() esr = _esr(_torch.Tensor(output), ds.y) diff --git a/nam/train/lightning_module.py b/nam/train/lightning_module.py index d31b8f50..ede215ec 100644 --- a/nam/train/lightning_module.py +++ b/nam/train/lightning_module.py @@ -522,9 +522,11 @@ def validation_step(self, batch, batch_idx): loss_dict[key].value for loss_dict in loss_dicts if key in loss_dict ] if values: - logs[key] = sum(values) - logs["ESR"] = sum(loss_dict["ESR"].value for loss_dict in loss_dicts) - logs["val_loss"] = sum(val_losses) + logs[key] = sum(values) / len(values) + logs["ESR"] = sum(loss_dict["ESR"].value for loss_dict in loss_dicts) / len( + loss_dicts + ) + logs["val_loss"] = sum(val_losses) / len(val_losses) self.log_dict(logs) return logs["val_loss"] diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py index ad490fbf..d2e28c02 100644 --- a/tests/test_nam/test_train/test_core.py +++ b/tests/test_nam/test_train/test_core.py @@ -355,6 +355,7 @@ def test_plot_reports_and_plots_each_packed_prediction(mocker, capsys): target = torch.tensor([1.0, -1.0, 2.0, -2.0]) predictions = torch.stack([target, 0.5 * target]) plot_calls = [] + titles = [] class FakeDataset: x = torch.zeros_like(target) @@ -376,15 +377,21 @@ def capture_plot(*args, **kwargs): mocker.patch.object(core, "_time", lambda: next(times)) mocker.patch("matplotlib.pyplot.figure") mocker.patch("matplotlib.pyplot.plot", capture_plot) - mocker.patch("matplotlib.pyplot.title") + mocker.patch("matplotlib.pyplot.title", lambda title: titles.append(title)) mocker.patch("matplotlib.pyplot.legend") mocker.patch("matplotlib.pyplot.savefig") mocker.patch("matplotlib.pyplot.show") - core._plot(FakeModel(), FakeDataset, silent=True) + validation_esr = core._plot(FakeModel(), FakeDataset, silent=True) stdout = capsys.readouterr().out assert stdout.count("Error-signal ratio") == 2 + assert "Aggregate error-signal ratio" not in stdout + assert validation_esr == 0.25 + assert len(titles) == 1 + assert "Aggregate ESR" not in titles[0] + assert "small" in titles[0] + assert "large" in titles[0] def as_numpy(value): if isinstance(value, torch.Tensor): diff --git a/tests/test_nam/test_train/test_lightning_module.py b/tests/test_nam/test_train/test_lightning_module.py index 1ca2e093..a2228791 100644 --- a/tests/test_nam/test_train/test_lightning_module.py +++ b/tests/test_nam/test_train/test_lightning_module.py @@ -207,7 +207,11 @@ def test_packed_lightning_validation_logs_per_submodel_and_aggregate(): assert "ESR_packed_0" in captured assert "MSE_packed_1" in captured assert _torch.allclose( - val_loss, captured["val_loss_packed_0"] + captured["val_loss_packed_1"] + val_loss, (captured["val_loss_packed_0"] + captured["val_loss_packed_1"]) / 2 + ) + assert _torch.allclose(captured["val_loss"], val_loss) + assert _torch.allclose( + captured["MSE"], (captured["MSE_packed_0"] + captured["MSE_packed_1"]) / 2 ) @@ -229,7 +233,7 @@ def test_packed_lightning_validation_logs_mrstft_per_submodel(mocker): module.validation_step((x, targets), 0) assert captured["MRSTFT_packed_0"] == 0.3 assert captured["MRSTFT_packed_1"] == 0.7 - assert captured["MRSTFT"] == 1.0 + assert captured["MRSTFT"] == 0.5 def test_packed_best_checkpoint_records_distinct_checkpoints(tmp_path):