Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/tutorials/packed-training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ During training, the packed model returns predictions shaped like
``(batch, submodel, time)``. The trainer computes the configured training loss
for each submodel prediction against the same target audio and sums the losses.

Validation logs aggregate metrics such as ``val_loss``, ``ESR``, and ``MRSTFT``
as well as per-submodel metrics such as ``val_loss_packed_0``, ``ESR_packed_0``,
and ``MRSTFT_packed_0``. When
Validation logs mean aggregate metrics such as ``val_loss``, ``ESR``, and
``MRSTFT`` as well as per-submodel metrics such as ``val_loss_packed_0``,
``ESR_packed_0``, and ``MRSTFT_packed_0``. When
validation is available, the trainer may also save per-submodel best
checkpoints named ``packed_best_submodel_<i>.ckpt``.

Expand Down
10 changes: 4 additions & 6 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,11 +1034,10 @@ def _plot(
for i in range(output.shape[0])
]
esrs = [_esr(_torch.Tensor(output_i), ds.y) for output_i in output]
aggregate_esr = sum(esrs)
validation_esr = esrs[-1]
for label, esr in zip(labels, esrs):
print(f"Error-signal ratio ({label}) = {esr:.4g}")
print(f"Aggregate error-signal ratio = {aggregate_esr:.4g}")
print(_esr_comment(aggregate_esr))
print(_esr_comment(validation_esr))

_plt.figure(figsize=(16, 5))
for label, output_i, esr in zip(labels, output, esrs):
Expand All @@ -1048,8 +1047,7 @@ def _plot(
)
_plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
_plt.title(
"Aggregate ESR="
f"{aggregate_esr:.4g} ("
"ESR ("
+ ", ".join(f"{label}: {esr:.4g}" for label, esr in zip(labels, esrs))
+ ")"
)
Expand All @@ -1058,7 +1056,7 @@ def _plot(
_plt.savefig(filepath + ".png")
if not silent:
_plt.show()
return aggregate_esr
return validation_esr

output = output.flatten()
esr = _esr(_torch.Tensor(output), ds.y)
Expand Down
8 changes: 5 additions & 3 deletions nam/train/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,11 @@ def validation_step(self, batch, batch_idx):
loss_dict[key].value for loss_dict in loss_dicts if key in loss_dict
]
if values:
logs[key] = sum(values)
logs["ESR"] = sum(loss_dict["ESR"].value for loss_dict in loss_dicts)
logs["val_loss"] = sum(val_losses)
logs[key] = sum(values) / len(values)
logs["ESR"] = sum(loss_dict["ESR"].value for loss_dict in loss_dicts) / len(
loss_dicts
)
logs["val_loss"] = sum(val_losses) / len(val_losses)
self.log_dict(logs)
return logs["val_loss"]

Expand Down
11 changes: 9 additions & 2 deletions tests/test_nam/test_train/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def test_plot_reports_and_plots_each_packed_prediction(mocker, capsys):
target = torch.tensor([1.0, -1.0, 2.0, -2.0])
predictions = torch.stack([target, 0.5 * target])
plot_calls = []
titles = []

class FakeDataset:
x = torch.zeros_like(target)
Expand All @@ -376,15 +377,21 @@ def capture_plot(*args, **kwargs):
mocker.patch.object(core, "_time", lambda: next(times))
mocker.patch("matplotlib.pyplot.figure")
mocker.patch("matplotlib.pyplot.plot", capture_plot)
mocker.patch("matplotlib.pyplot.title")
mocker.patch("matplotlib.pyplot.title", lambda title: titles.append(title))
mocker.patch("matplotlib.pyplot.legend")
mocker.patch("matplotlib.pyplot.savefig")
mocker.patch("matplotlib.pyplot.show")

core._plot(FakeModel(), FakeDataset, silent=True)
validation_esr = core._plot(FakeModel(), FakeDataset, silent=True)

stdout = capsys.readouterr().out
assert stdout.count("Error-signal ratio") == 2
assert "Aggregate error-signal ratio" not in stdout
assert validation_esr == 0.25
assert len(titles) == 1
assert "Aggregate ESR" not in titles[0]
assert "small" in titles[0]
assert "large" in titles[0]

def as_numpy(value):
if isinstance(value, torch.Tensor):
Expand Down
8 changes: 6 additions & 2 deletions tests/test_nam/test_train/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,11 @@ def test_packed_lightning_validation_logs_per_submodel_and_aggregate():
assert "ESR_packed_0" in captured
assert "MSE_packed_1" in captured
assert _torch.allclose(
val_loss, captured["val_loss_packed_0"] + captured["val_loss_packed_1"]
val_loss, (captured["val_loss_packed_0"] + captured["val_loss_packed_1"]) / 2
)
assert _torch.allclose(captured["val_loss"], val_loss)
assert _torch.allclose(
captured["MSE"], (captured["MSE_packed_0"] + captured["MSE_packed_1"]) / 2
)


Expand All @@ -229,7 +233,7 @@ def test_packed_lightning_validation_logs_mrstft_per_submodel(mocker):
module.validation_step((x, targets), 0)
assert captured["MRSTFT_packed_0"] == 0.3
assert captured["MRSTFT_packed_1"] == 0.7
assert captured["MRSTFT"] == 1.0
assert captured["MRSTFT"] == 0.5


def test_packed_best_checkpoint_records_distinct_checkpoints(tmp_path):
Expand Down
Loading