From 82c1bfccfea5052bdf0d66453616c148055b01e4 Mon Sep 17 00:00:00 2001 From: Robert Haist Date: Sun, 24 May 2026 14:13:36 +0200 Subject: [PATCH 1/2] [BUGFIX] Fall back to in-memory model when best checkpoint can't be loaded Lightning's ModelCheckpoint updates best_model_path before the underlying file write completes. If training is interrupted between those steps the finally block in nam.train.full.main raises when it tries to load the partial or missing file, so no .nam is exported. Wrap the load and fall back to the in-memory model on failure. Adds a deterministic regression test alongside the existing flaky integration test. Fixes #645. --- nam/train/full.py | 20 ++++++++--- tests/test_graceful_shutdown.py | 60 +++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/nam/train/full.py b/nam/train/full.py index 862bac9c..1535f0d5 100644 --- a/nam/train/full.py +++ b/nam/train/full.py @@ -262,13 +262,23 @@ def main( print("\nTraining interrupted by user.") finally: # Always try to export a model, even if training was interrupted - # Go to best checkpoint + # Go to best checkpoint. Fall back to the in-memory model if loading + # fails, e.g. when SIGINT lands while a checkpoint is being written: + # Lightning sets `best_model_path` before the file is finished, so the + # path can reference a missing or partial file (see issue #645). best_checkpoint = trainer.checkpoint_callback.best_model_path if best_checkpoint != "": - model = lightning_cls.load_from_checkpoint( - trainer.checkpoint_callback.best_model_path, - **lightning_cls.parse_config(model_config), - ) + try: + model = lightning_cls.load_from_checkpoint( + best_checkpoint, + **lightning_cls.parse_config(model_config), + ) + except Exception as e: + _warn( + f"Failed to load best checkpoint {best_checkpoint!r} " + f"({type(e).__name__}: {e}); exporting the in-memory model " + f"instead." + ) model.cpu() model.eval() model.net.sample_rate = dataset_train.sample_rate diff --git a/tests/test_graceful_shutdown.py b/tests/test_graceful_shutdown.py index fb674044..3cf66af2 100644 --- a/tests/test_graceful_shutdown.py +++ b/tests/test_graceful_shutdown.py @@ -502,6 +502,66 @@ def test_graceful_shutdown_generates_model(self): ) +class TestExportFallbackOnBadCheckpoint: + """ + Regression for issue #645. + + Lightning's ``ModelCheckpoint`` updates ``best_model_path`` BEFORE the + underlying file write completes. If SIGINT lands between those steps, + ``best_model_path`` references a missing or partial file. The finally + block in ``nam.train.full.main`` must still produce a ``.nam`` by + falling back to the in-memory model. + + The integration test above (``TestGracefulShutdown``) racily reproduces + this when the runner is slow enough that the first checkpoint save is + in flight when SIGINT arrives; this test exercises the same fallback + path deterministically by corrupting every checkpoint write. + """ + + def test_main_exports_nam_when_best_checkpoint_unloadable( + self, monkeypatch, tmp_path + ): + from lightning_fabric.plugins.io import torch_io + + from nam.train.full import main as nam_full_main + + x_path, y_path = create_test_data(tmp_path) + data_config_path, model_config_path, learning_config_path = create_configs( + tmp_path, x_path, y_path, num_epochs=1 + ) + with open(data_config_path) as fp: + data_config = json.load(fp) + with open(model_config_path) as fp: + model_config = json.load(fp) + with open(learning_config_path) as fp: + learning_config = json.load(fp) + + # Simulate SIGINT-mid-save: Lightning's `best_model_path` ends up + # set, but the file on disk is unreadable as a checkpoint. + def corrupt_save(checkpoint, filepath): + with open(filepath, "wb") as f: + f.write(b"corrupt") + + monkeypatch.setattr(torch_io, "_atomic_save", corrupt_save) + + outdir = tmp_path / "outputs" + outdir.mkdir() + nam_full_main( + data_config, + model_config, + learning_config, + outdir, + no_show=True, + make_plots=False, + ) + + nam_files = list(outdir.rglob("*.nam")) + assert nam_files, ( + f"Expected a .nam exported via in-memory fallback when the best " + f"checkpoint can't be loaded, but found none in {outdir}." + ) + + def main(): """Run the graceful shutdown test.""" import argparse From ce4461586b5c9c3de1ff2e58866bfa360b562217 Mon Sep 17 00:00:00 2001 From: Robert Haist Date: Mon, 15 Jun 2026 10:20:17 +0200 Subject: [PATCH 2/2] Catch specific exceptions when best checkpoint can't be loaded Address PR review: replace bare 'except Exception' with the concrete failure modes torch.load raises on a missing or partially-written checkpoint (FileNotFoundError, RuntimeError, EOFError, pickle.UnpicklingError). Verified empirically against torch 2.12. --- nam/train/full.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/nam/train/full.py b/nam/train/full.py index 1535f0d5..939609da 100644 --- a/nam/train/full.py +++ b/nam/train/full.py @@ -3,6 +3,7 @@ # Author: Enrico Schifano (eraz1997@live.it) import json as _json +import pickle as _pickle from pathlib import Path as _Path from time import time as _time from typing import Optional as _Optional @@ -273,7 +274,17 @@ def main( best_checkpoint, **lightning_cls.parse_config(model_config), ) - except Exception as e: + # A SIGINT during the checkpoint write leaves the file missing or + # partially-written, so torch.load can fail a few ways: missing + # file (FileNotFoundError), truncated zip archive (RuntimeError + # from torch's PyTorchStreamReader), empty file (EOFError), or + # garbage pickle data (pickle.UnpicklingError). + except ( + FileNotFoundError, + RuntimeError, + EOFError, + _pickle.UnpicklingError, + ) as e: _warn( f"Failed to load best checkpoint {best_checkpoint!r} " f"({type(e).__name__}: {e}); exporting the in-memory model "