From 47c977633d2fc8040208b23fcc39563d1fb32575 Mon Sep 17 00:00:00 2001
From: Elias Gomes
Date: Mon, 4 May 2026 22:04:07 +0200
Subject: [PATCH] [FEATURE] Add epoch-level progress bar to full trainer

Add a second Rich progress bar showing overall epoch progress below the
existing per-batch bar. Shows elapsed time, remaining time estimate, and
epoch count using Lightning's N/max-1 convention.

- Extend RichProgressBar with _EpochProgressBar callback
- Add _EpochCountColumn for epoch count display
- Use transient Rich Progress for clean terminal output on both normal
  completion and Ctrl+C interruption
- Remove v_num from displayed metrics
- Add unit tests for both new classes
---
Reviewer revision of the original patch:
- Only create the epoch bar for a positive max_epochs (Lightning uses -1
  for unlimited training) and clamp _EpochCountColumn at zero.
- Add docstrings to the new classes and tests for both cases.
- Blob ids on the index lines are carried over from the original patch.

 nam/train/full.py                      | 119 ++++++++++++++++-
 tests/test_nam/test_train/test_full.py | 176 +++++++++++++++++++++++++
 2 files changed, 293 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_nam/test_train/test_full.py

diff --git a/nam/train/full.py b/nam/train/full.py
index 862bac9c..4f697a26 100644
--- a/nam/train/full.py
+++ b/nam/train/full.py
@@ -13,9 +13,18 @@
 import numpy as _np
 import pytorch_lightning as _pl
 import torch as _torch
+from pytorch_lightning.callbacks import RichProgressBar as _RichProgressBar
 from pytorch_lightning.utilities.warnings import (
     PossibleUserWarning as _PossibleUserWarning,
 )
+from rich.progress import BarColumn as _BarColumn
+from rich.progress import Progress as _Progress
+from rich.progress import ProgressColumn as _ProgressColumn
+from rich.progress import Task as _Task
+from rich.progress import TextColumn as _TextColumn
+from rich.progress import TimeElapsedColumn as _TimeElapsedColumn
+from rich.progress import TimeRemainingColumn as _TimeRemainingColumn
+from rich.text import Text as _Text
 from torch.utils.data import DataLoader as _DataLoader
 
 from nam.data import AbstractDataset as _AbstractDataset
@@ -36,6 +45,103 @@ def _handshake_datasets(model, *datasets: _AbstractDataset) -> None:
         model.net.handshake(dataset)
 
 
+class _EpochCountColumn(_ProgressColumn):
+    """Shows epoch count matching Lightning's 'Epoch N/max-1' convention."""
+
+    def render(self, task: "_Task") -> _Text:
+        """Render the zero-based ``current/last`` epoch indices for ``task``."""
+        # Clamp at zero so a missing or non-positive total never shows a
+        # negative index.
+        total = max(int(task.total) - 1, 0) if task.total else 0
+        current = max(min(int(task.completed), total), 0)
+        return _Text(f"{current}/{total}")
+
+
+class _EpochProgressBar(_RichProgressBar):
+    """Rich progress bar with current epoch on top, total epochs below."""
+
+    def __init__(self):
+        super().__init__(leave=False)
+        # Both are created in on_train_start and stay None until then.
+        self._epoch_progress = None
+        self._epoch_task_id = None
+
+    def on_train_start(self, trainer, pl_module):
+        """Start the batch bar, then the epoch bar if the epoch count is known."""
+        super().on_train_start(trainer, pl_module)
+        if self.progress is not None:
+            self.progress.live.transient = True
+        # Lightning uses max_epochs=-1 for unlimited training, so only a
+        # positive epoch count gives the bar a meaningful total.
+        if trainer.max_epochs is not None and trainer.max_epochs > 0:
+            self._epoch_progress = _Progress(
+                _TextColumn("[progress.description]{task.description}"),
+                _BarColumn(
+                    complete_style=self.theme.progress_bar,
+                    finished_style=self.theme.progress_bar_finished,
+                ),
+                _EpochCountColumn(),
+                _TimeElapsedColumn(),
+                _TextColumn("•"),
+                _TimeRemainingColumn(),
+                transient=True,
+            )
+            self._epoch_task_id = self._epoch_progress.add_task(
+                "Epochs", total=trainer.max_epochs
+            )
+            self._epoch_progress.start()
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        """Advance the epoch bar by the completed fraction of this epoch."""
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
+        if self._epoch_progress is not None and self._epoch_task_id is not None:
+            fraction = (batch_idx + 1) / self.total_train_batches
+            self._epoch_progress.update(
+                self._epoch_task_id,
+                completed=trainer.current_epoch + fraction,
+            )
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        """Snap the epoch bar to the whole number of finished epochs."""
+        super().on_train_epoch_end(trainer, pl_module)
+        if self._epoch_progress is not None and self._epoch_task_id is not None:
+            self._epoch_progress.update(
+                self._epoch_task_id, completed=trainer.current_epoch + 1
+            )
+
+    def _stop_epoch_progress(self, leave=False):
+        """Stop the epoch bar; keep it on screen only when ``leave`` is true."""
+        if self._epoch_progress is not None:
+            if leave:
+                self._epoch_progress.live.transient = False
+            self._epoch_progress.stop()
+
+    def on_train_end(self, trainer, pl_module):
+        """Stop both bars, leaving the finished epoch bar visible."""
+        self._stop_progress()
+        self._stop_epoch_progress(leave=True)
+
+    def on_exception(self, trainer, pl_module, exception):
+        """Stop both bars without leaving the epoch bar behind."""
+        self._stop_progress()
+        self._stop_epoch_progress()
+
+    def teardown(self, trainer, pl_module, stage):
+        """Stop both bars when the trainer tears the callback down."""
+        self._stop_progress()
+        self._stop_epoch_progress()
+
+    def _get_train_description(self, current_epoch):
+        """Label the batch bar with the epoch only, omitting the total."""
+        return f"Epoch {current_epoch}"
+
+    def get_metrics(self, trainer, pl_module):
+        """Return the progress-bar metrics with ``v_num`` removed."""
+        items = super().get_metrics(trainer, pl_module)
+        items.pop("v_num", None)
+        return items
+
+
 def _rms(x: _Union[_np.ndarray, _torch.Tensor]) -> float:
     if isinstance(x, _np.ndarray):
         return _np.sqrt(_np.mean(_np.square(x)))
@@ -244,10 +350,19 @@ def main(
         (c for c in callbacks if isinstance(c, _lightning_module.PackedBestCheckpoint)),
         None,
     )
+    trainer_config = learning_config["trainer"]
+    # Only attach our epoch progress bar when the trainer hasn't disabled
+    # progress bars (Lightning rejects a progress-bar callback alongside
+    # enable_progress_bar=False).
+    progress_bar_callbacks = (
+        [_EpochProgressBar()]
+        if trainer_config.get("enable_progress_bar", True)
+        else []
+    )
     trainer = _pl.Trainer(
-        callbacks=callbacks,
+        callbacks=[*callbacks, *progress_bar_callbacks],
         default_root_dir=outdir,
-        **learning_config["trainer"],
+        **trainer_config,
     )
 
     try:
diff --git a/tests/test_nam/test_train/test_full.py b/tests/test_nam/test_train/test_full.py
new file mode 100644
index 00000000..4366da87
--- /dev/null
+++ b/tests/test_nam/test_train/test_full.py
@@ -0,0 +1,176 @@
+# File: test_full.py
+# Created Date: Sunday May 3rd 2026
+# Author: Elias Gomes
+
+from unittest.mock import MagicMock as _MagicMock
+
+import pytest as _pytest
+
+from nam.train.full import _EpochCountColumn, _EpochProgressBar
+
+
+class TestEpochCountColumn:
+    @staticmethod
+    def _make_task(completed, total):
+        task = _MagicMock()
+        task.completed = completed
+        task.total = total
+        return task
+
+    def test_mid_epoch(self):
+        col = _EpochCountColumn()
+        assert col.render(self._make_task(3.5, 100)).plain == "3/99"
+
+    def test_epoch_end(self):
+        col = _EpochCountColumn()
+        assert col.render(self._make_task(4, 100)).plain == "4/99"
+
+    def test_training_complete(self):
+        col = _EpochCountColumn()
+        assert col.render(self._make_task(100, 100)).plain == "99/99"
+
+    def test_non_positive_total_clamps_to_zero(self):
+        col = _EpochCountColumn()
+        assert col.render(self._make_task(0, -1)).plain == "0/0"
+
+
+class TestEpochProgressBar:
+    def test_init(self):
+        bar = _EpochProgressBar()
+        assert bar._epoch_progress is None
+        assert bar._epoch_task_id is None
+        assert bar._leave is False
+
+    @staticmethod
+    def _make_trainer_mock(**pbar_metrics):
+        trainer = _MagicMock()
+        trainer.progress_bar_metrics = pbar_metrics
+        trainer.state.fn = None
+        trainer.loggers = [_MagicMock(version=0)]
+        return trainer
+
+    def test_get_metrics_removes_v_num(self):
+        bar = _EpochProgressBar()
+        trainer = self._make_trainer_mock(val_loss=0.1)
+        metrics = bar.get_metrics(trainer, _MagicMock())
+        assert "v_num" not in metrics
+
+    def test_get_metrics_preserves_other_keys(self):
+        bar = _EpochProgressBar()
+        trainer = self._make_trainer_mock(val_loss=0.5, ESR=0.01)
+        metrics = bar.get_metrics(trainer, _MagicMock())
+        assert "val_loss" in metrics
+        assert "ESR" in metrics
+
+    def test_train_description_omits_total(self):
+        bar = _EpochProgressBar()
+        assert bar._get_train_description(5) == "Epoch 5"
+
+    def test_on_train_start_creates_epoch_progress(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = 100
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+
+        assert bar._epoch_progress is not None
+        assert bar._epoch_task_id is not None
+
+    def test_on_train_start_skips_when_max_epochs_is_none(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = None
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+
+        assert bar._epoch_progress is None
+        assert bar._epoch_task_id is None
+
+    def test_on_train_start_skips_when_max_epochs_is_unlimited(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = -1
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+
+        assert bar._epoch_progress is None
+        assert bar._epoch_task_id is None
+
+    def test_on_train_epoch_end_updates_completed(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = 10
+        trainer.current_epoch = 3
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+        bar.on_train_epoch_end(trainer, pl_module)
+
+        task = bar._epoch_progress.tasks[bar._epoch_task_id]
+        assert task.completed == 4
+
+    def test_stop_epoch_progress_transient(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = 10
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+        assert bar._epoch_progress.live.is_started
+
+        bar._stop_epoch_progress()
+        assert not bar._epoch_progress.live.is_started
+        assert bar._epoch_progress.live.transient
+
+    def test_stop_epoch_progress_leave(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = 10
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+        bar._stop_epoch_progress(leave=True)
+        assert not bar._epoch_progress.live.transient
+
+    def test_stop_epoch_progress_noop_when_none(self):
+        bar = _EpochProgressBar()
+        bar._stop_epoch_progress()
+
+    def test_on_train_end_leaves_epoch_bar(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = 10
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+        bar.on_train_end(trainer, pl_module)
+        assert not bar._epoch_progress.live.is_started
+        assert not bar._epoch_progress.live.transient
+        assert bar._progress_stopped
+
+    def test_on_exception_stops_all(self):
+        bar = _EpochProgressBar()
+        trainer = _MagicMock()
+        trainer.max_epochs = 10
+        pl_module = _MagicMock()
+
+        bar._init_progress(trainer)
+        bar.on_train_start(trainer, pl_module)
+        bar.on_exception(trainer, pl_module, RuntimeError("test"))
+        assert not bar._epoch_progress.live.is_started
+        assert bar._epoch_progress.live.transient
+        assert bar._progress_stopped
+
+
+if __name__ == "__main__":
+    _pytest.main()