diff --git a/nam/train/full.py b/nam/train/full.py index 862bac9c..939609da 100644 --- a/nam/train/full.py +++ b/nam/train/full.py @@ -3,6 +3,7 @@ # Author: Enrico Schifano (eraz1997@live.it) import json as _json +import pickle as _pickle from pathlib import Path as _Path from time import time as _time from typing import Optional as _Optional @@ -262,13 +263,33 @@ def main( print("\nTraining interrupted by user.") finally: # Always try to export a model, even if training was interrupted - # Go to best checkpoint + # Go to best checkpoint. Fall back to the in-memory model if loading + # fails, e.g. when SIGINT lands while a checkpoint is being written: + # Lightning sets `best_model_path` before the file is finished, so the + # path can reference a missing or partial file (see issue #645). best_checkpoint = trainer.checkpoint_callback.best_model_path if best_checkpoint != "": - model = lightning_cls.load_from_checkpoint( - trainer.checkpoint_callback.best_model_path, - **lightning_cls.parse_config(model_config), - ) + try: + model = lightning_cls.load_from_checkpoint( + best_checkpoint, + **lightning_cls.parse_config(model_config), + ) + # A SIGINT during the checkpoint write leaves the file missing or + # partially-written, so torch.load can fail a few ways: missing + # file (FileNotFoundError), truncated zip archive (RuntimeError + # from torch's PyTorchStreamReader), empty file (EOFError), or + # garbage pickle data (pickle.UnpicklingError). + except ( + FileNotFoundError, + RuntimeError, + EOFError, + _pickle.UnpicklingError, + ) as e: + _warn( + f"Failed to load best checkpoint {best_checkpoint!r} " + f"({type(e).__name__}: {e}); exporting the in-memory model " + f"instead." + ) model.cpu() model.eval() model.net.sample_rate = dataset_train.sample_rate diff --git a/tests/test_graceful_shutdown.py b/tests/test_graceful_shutdown.py index fb674044..3cf66af2 100644 --- a/tests/test_graceful_shutdown.py +++ b/tests/test_graceful_shutdown.py @@ -502,6 +502,66 @@ def test_graceful_shutdown_generates_model(self): ) +class TestExportFallbackOnBadCheckpoint: + """ + Regression for issue #645. + + Lightning's ``ModelCheckpoint`` updates ``best_model_path`` BEFORE the + underlying file write completes. If SIGINT lands between those steps, + ``best_model_path`` references a missing or partial file. The finally + block in ``nam.train.full.main`` must still produce a ``.nam`` by + falling back to the in-memory model. + + The integration test above (``TestGracefulShutdown``) racily reproduces + this when the runner is slow enough that the first checkpoint save is + in flight when SIGINT arrives; this test exercises the same fallback + path deterministically by corrupting every checkpoint write. + """ + + def test_main_exports_nam_when_best_checkpoint_unloadable( + self, monkeypatch, tmp_path + ): + from lightning_fabric.plugins.io import torch_io + + from nam.train.full import main as nam_full_main + + x_path, y_path = create_test_data(tmp_path) + data_config_path, model_config_path, learning_config_path = create_configs( + tmp_path, x_path, y_path, num_epochs=1 + ) + with open(data_config_path) as fp: + data_config = json.load(fp) + with open(model_config_path) as fp: + model_config = json.load(fp) + with open(learning_config_path) as fp: + learning_config = json.load(fp) + + # Simulate SIGINT-mid-save: Lightning's `best_model_path` ends up + # set, but the file on disk is unreadable as a checkpoint. + def corrupt_save(checkpoint, filepath): + with open(filepath, "wb") as f: + f.write(b"corrupt") + + monkeypatch.setattr(torch_io, "_atomic_save", corrupt_save) + + outdir = tmp_path / "outputs" + outdir.mkdir() + nam_full_main( + data_config, + model_config, + learning_config, + outdir, + no_show=True, + make_plots=False, + ) + + nam_files = list(outdir.rglob("*.nam")) + assert nam_files, ( + f"Expected a .nam exported via in-memory fallback when the best " + f"checkpoint can't be loaded, but found none in {outdir}." + ) + + def main(): """Run the graceful shutdown test.""" import argparse