Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions nam/train/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@
import numpy as _np
import pytorch_lightning as _pl
import torch as _torch
from pytorch_lightning.callbacks import RichProgressBar as _RichProgressBar
from pytorch_lightning.utilities.warnings import (
PossibleUserWarning as _PossibleUserWarning,
)
from rich.progress import BarColumn as _BarColumn
from rich.progress import Progress as _Progress
from rich.progress import ProgressColumn as _ProgressColumn
from rich.progress import Task as _Task
from rich.progress import TextColumn as _TextColumn
from rich.progress import TimeElapsedColumn as _TimeElapsedColumn
from rich.progress import TimeRemainingColumn as _TimeRemainingColumn
from rich.text import Text as _Text
from torch.utils.data import DataLoader as _DataLoader

from nam.data import AbstractDataset as _AbstractDataset
Expand All @@ -36,6 +45,88 @@ def _handshake_datasets(model, *datasets: _AbstractDataset) -> None:
model.net.handshake(dataset)


class _EpochCountColumn(_ProgressColumn):
    """Progress column that prints the epoch counter as ``current/last``.

    The denominator is the task total minus one so the text lines up with
    Lightning's zero-based 'Epoch N/max-1' convention.
    """

    def render(self, task: "_Task") -> _Text:
        """Return the ``current/last`` counter text for ``task``."""
        if task.total:
            last_epoch = int(task.total) - 1
        else:
            # Missing or zero total: fall back to a zero denominator.
            last_epoch = 0
        # Clamp so a fully completed task shows "last/last" rather than
        # running one past the final zero-based epoch index.
        shown_epoch = min(int(task.completed), last_epoch)
        return _Text(f"{shown_epoch}/{last_epoch}")


class _EpochProgressBar(_RichProgressBar):
    """Rich progress bar with current epoch on top, total epochs below.

    The inherited bar shows progress within the current epoch and is made
    transient. A second ``rich`` progress display counts epochs over the
    whole run; it is kept on screen when training ends normally and stopped
    as transient on exception or teardown.
    """

    def __init__(self):
        super().__init__(leave=False)
        # Both are populated in on_train_start and stay None when no
        # overall-epochs bar is shown.
        self._epoch_progress = None
        self._epoch_task_id = None

    def on_train_start(self, trainer, pl_module):
        """Start the inherited bar and, when applicable, the epochs bar.

        The epochs bar is only created when this progress bar is enabled and
        the trainer has a positive, finite ``max_epochs``.
        """
        super().on_train_start(trainer, pl_module)
        if self.progress is not None:
            self.progress.live.transient = True
        if not self.is_enabled:
            # A disabled bar (e.g. Lightning disables progress bars on
            # non-zero ranks) must not print an epochs bar either.
            return
        max_epochs = trainer.max_epochs
        if max_epochs is None or max_epochs < 1:
            # Lightning uses max_epochs=-1 for unlimited training; there is
            # no meaningful total to draw a bar against in that case.
            return
        self._epoch_progress = _Progress(
            _TextColumn("[progress.description]{task.description}"),
            _BarColumn(
                complete_style=self.theme.progress_bar,
                finished_style=self.theme.progress_bar_finished,
            ),
            _EpochCountColumn(),
            _TimeElapsedColumn(),
            _TextColumn("•"),
            _TimeRemainingColumn(),
            transient=True,
        )
        self._epoch_task_id = self._epoch_progress.add_task(
            "Epochs", total=max_epochs
        )
        self._epoch_progress.start()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Advance the epochs bar by the fraction of the epoch completed."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        if self._epoch_progress is not None and self._epoch_task_id is not None:
            # Fractional progress lets the epochs bar move within an epoch.
            fraction = (batch_idx + 1) / self.total_train_batches
            self._epoch_progress.update(
                self._epoch_task_id,
                completed=trainer.current_epoch + fraction,
            )

    def on_train_epoch_end(self, trainer, pl_module):
        """Snap the epochs bar to the whole number of finished epochs."""
        super().on_train_epoch_end(trainer, pl_module)
        if self._epoch_progress is not None and self._epoch_task_id is not None:
            self._epoch_progress.update(
                self._epoch_task_id, completed=trainer.current_epoch + 1
            )

    def _stop_epoch_progress(self, leave=False):
        """Stop the epochs bar if it exists.

        :param leave: When true, clear the display's ``transient`` flag
            before stopping so the final bar is left on screen.
        """
        if self._epoch_progress is not None:
            if leave:
                self._epoch_progress.live.transient = False
            self._epoch_progress.stop()

    def on_train_end(self, trainer, pl_module):
        """Stop both bars, leaving the epochs bar visible."""
        self._stop_progress()
        self._stop_epoch_progress(leave=True)

    def on_exception(self, trainer, pl_module, exception):
        """Stop both bars when training raises."""
        self._stop_progress()
        self._stop_epoch_progress()

    def teardown(self, trainer, pl_module, stage):
        """Stop both bars when the trainer tears the callback down."""
        self._stop_progress()
        self._stop_epoch_progress()

    def _get_train_description(self, current_epoch):
        """Describe the current epoch without the total epoch count."""
        return f"Epoch {current_epoch}"

    def get_metrics(self, trainer, pl_module):
        """Return the inherited metrics with the ``v_num`` entry removed."""
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items


def _rms(x: _Union[_np.ndarray, _torch.Tensor]) -> float:
if isinstance(x, _np.ndarray):
return _np.sqrt(_np.mean(_np.square(x)))
Expand Down Expand Up @@ -244,10 +335,19 @@ def main(
(c for c in callbacks if isinstance(c, _lightning_module.PackedBestCheckpoint)),
None,
)
trainer_config = learning_config["trainer"]
# Only attach our epoch progress bar when the trainer hasn't disabled
# progress bars (Lightning rejects a progress-bar callback alongside
# enable_progress_bar=False).
progress_bar_callbacks = (
[_EpochProgressBar()]
if trainer_config.get("enable_progress_bar", True)
else []
)
trainer = _pl.Trainer(
callbacks=callbacks,
callbacks=[*callbacks, *progress_bar_callbacks],
default_root_dir=outdir,
**learning_config["trainer"],
**trainer_config,
)

try:
Expand Down
160 changes: 160 additions & 0 deletions tests/test_nam/test_train/test_full.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# File: test_full.py
# Created Date: Sunday May 3rd 2026
# Author: Elias Gomes

from unittest.mock import MagicMock as _MagicMock

import pytest as _pytest

from nam.train.full import _EpochCountColumn, _EpochProgressBar


class TestEpochCountColumn:
    """Rendering checks for ``_EpochCountColumn``."""

    @staticmethod
    def _make_task(completed, total):
        """Build a stand-in task carrying ``completed`` and ``total``."""
        fake_task = _MagicMock()
        fake_task.total = total
        fake_task.completed = completed
        return fake_task

    def test_mid_epoch(self):
        # A fractional completed value is truncated to the epoch index.
        rendered = _EpochCountColumn().render(self._make_task(3.5, 100))
        assert rendered.plain == "3/99"

    def test_epoch_end(self):
        rendered = _EpochCountColumn().render(self._make_task(4, 100))
        assert rendered.plain == "4/99"

    def test_training_complete(self):
        # A finished task is clamped to the last zero-based epoch index.
        rendered = _EpochCountColumn().render(self._make_task(100, 100))
        assert rendered.plain == "99/99"


class TestEpochProgressBar:
    """Behavioural tests for ``_EpochProgressBar``.

    Tests that start the bar go through ``_start_bar`` so that every started
    display is stopped again in ``teardown_method``, even when an assertion
    fails; otherwise live ``rich`` displays would outlive the test.
    """

    def setup_method(self):
        # Bars started by the current test; stopped in teardown_method.
        self._started_bars = []

    def teardown_method(self):
        for bar in self._started_bars:
            # Both stop helpers are safe to call on already-stopped bars.
            bar._stop_progress()
            bar._stop_epoch_progress()

    def _start_bar(self, max_epochs, **trainer_attrs):
        """Create a bar and run it through ``on_train_start``.

        :param max_epochs: Value assigned to the mock trainer's
            ``max_epochs``.
        :param trainer_attrs: Extra attributes to set on the mock trainer
            before training starts.
        :return: ``(bar, trainer, pl_module)``
        """
        bar = _EpochProgressBar()
        trainer = _MagicMock()
        trainer.max_epochs = max_epochs
        for attr_name, attr_value in trainer_attrs.items():
            setattr(trainer, attr_name, attr_value)
        pl_module = _MagicMock()
        # Register before starting so cleanup still runs if start-up raises.
        self._started_bars.append(bar)
        bar._init_progress(trainer)
        bar.on_train_start(trainer, pl_module)
        return bar, trainer, pl_module

    def test_init(self):
        bar = _EpochProgressBar()
        assert bar._epoch_progress is None
        assert bar._epoch_task_id is None
        assert bar._leave is False

    @staticmethod
    def _make_trainer_mock(**pbar_metrics):
        """Build a mock trainer exposing the given progress-bar metrics."""
        trainer = _MagicMock()
        trainer.progress_bar_metrics = pbar_metrics
        trainer.state.fn = None
        trainer.loggers = [_MagicMock(version=0)]
        return trainer

    def test_get_metrics_removes_v_num(self):
        bar = _EpochProgressBar()
        trainer = self._make_trainer_mock(val_loss=0.1)
        metrics = bar.get_metrics(trainer, _MagicMock())
        assert "v_num" not in metrics

    def test_get_metrics_preserves_other_keys(self):
        bar = _EpochProgressBar()
        trainer = self._make_trainer_mock(val_loss=0.5, ESR=0.01)
        metrics = bar.get_metrics(trainer, _MagicMock())
        assert "val_loss" in metrics
        assert "ESR" in metrics

    def test_train_description_omits_total(self):
        bar = _EpochProgressBar()
        assert bar._get_train_description(5) == "Epoch 5"

    def test_on_train_start_creates_epoch_progress(self):
        bar, _, _ = self._start_bar(max_epochs=100)

        assert bar._epoch_progress is not None
        assert bar._epoch_task_id is not None

    def test_on_train_start_skips_when_max_epochs_is_none(self):
        bar, _, _ = self._start_bar(max_epochs=None)

        assert bar._epoch_progress is None
        assert bar._epoch_task_id is None

    def test_on_train_epoch_end_updates_completed(self):
        bar, trainer, pl_module = self._start_bar(max_epochs=10, current_epoch=3)
        bar.on_train_epoch_end(trainer, pl_module)

        task = bar._epoch_progress.tasks[bar._epoch_task_id]
        assert task.completed == 4

    def test_stop_epoch_progress_transient(self):
        bar, _, _ = self._start_bar(max_epochs=10)
        assert bar._epoch_progress.live.is_started

        bar._stop_epoch_progress()
        assert not bar._epoch_progress.live.is_started
        assert bar._epoch_progress.live.transient

    def test_stop_epoch_progress_leave(self):
        bar, _, _ = self._start_bar(max_epochs=10)
        bar._stop_epoch_progress(leave=True)
        assert not bar._epoch_progress.live.transient

    def test_stop_epoch_progress_noop_when_none(self):
        bar = _EpochProgressBar()
        # Must not raise when no epochs bar was ever created.
        bar._stop_epoch_progress()
        assert bar._epoch_progress is None

    def test_on_train_end_leaves_epoch_bar(self):
        bar, trainer, pl_module = self._start_bar(max_epochs=10)
        bar.on_train_end(trainer, pl_module)
        assert not bar._epoch_progress.live.is_started
        assert not bar._epoch_progress.live.transient
        assert bar._progress_stopped

    def test_on_exception_stops_all(self):
        bar, trainer, pl_module = self._start_bar(max_epochs=10)
        bar.on_exception(trainer, pl_module, RuntimeError("test"))
        assert not bar._epoch_progress.live.is_started
        assert bar._epoch_progress.live.transient
        assert bar._progress_stopped


# Allow running this test module directly as a script by handing control to
# pytest's entry point.
if __name__ == "__main__":
    _pytest.main()