Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions nam/train/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Author: Enrico Schifano (eraz1997@live.it)

import json as _json
import pickle as _pickle
from pathlib import Path as _Path
from time import time as _time
from typing import Optional as _Optional
Expand Down Expand Up @@ -262,13 +263,33 @@ def main(
print("\nTraining interrupted by user.")
finally:
# Always try to export a model, even if training was interrupted
# Go to best checkpoint
# Go to best checkpoint. Fall back to the in-memory model if loading
# fails, e.g. when SIGINT lands while a checkpoint is being written:
# Lightning sets `best_model_path` before the file is finished, so the
# path can reference a missing or partial file (see issue #645).
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint != "":
model = lightning_cls.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
**lightning_cls.parse_config(model_config),
)
try:
model = lightning_cls.load_from_checkpoint(
best_checkpoint,
**lightning_cls.parse_config(model_config),
)
# A SIGINT during the checkpoint write leaves the file missing or
# partially-written, so torch.load can fail a few ways: missing
# file (FileNotFoundError), truncated zip archive (RuntimeError
# from torch's PyTorchStreamReader), empty file (EOFError), or
# garbage pickle data (pickle.UnpicklingError).
except (
FileNotFoundError,
RuntimeError,
EOFError,
_pickle.UnpicklingError,
) as e:
_warn(
f"Failed to load best checkpoint {best_checkpoint!r} "
f"({type(e).__name__}: {e}); exporting the in-memory model "
f"instead."
)
model.cpu()
model.eval()
model.net.sample_rate = dataset_train.sample_rate
Expand Down
60 changes: 60 additions & 0 deletions tests/test_graceful_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,66 @@ def test_graceful_shutdown_generates_model(self):
)


class TestExportFallbackOnBadCheckpoint:
    """
    Regression for issue #645.

    Lightning's ``ModelCheckpoint`` updates ``best_model_path`` BEFORE the
    underlying file write completes. If SIGINT lands between those steps,
    ``best_model_path`` references a missing or partial file. The finally
    block in ``nam.train.full.main`` must still produce a ``.nam`` by
    falling back to the in-memory model.

    The integration test above (``TestGracefulShutdown``) racily reproduces
    this when the runner is slow enough that the first checkpoint save is
    in flight when SIGINT arrives; this test exercises the same fallback
    path deterministically by corrupting every checkpoint write.
    """

    def test_main_exports_nam_when_best_checkpoint_unloadable(
        self, monkeypatch, tmp_path
    ):
        """
        Train for one epoch with every checkpoint write corrupted, then
        check that ``main`` still exports a ``.nam`` file.

        :param monkeypatch: pytest fixture used to replace the checkpoint
            writer for the duration of the test.
        :param tmp_path: pytest fixture providing a per-test scratch
            directory for the data, configs, and outputs.
        """
        from lightning_fabric.plugins.io import torch_io

        from nam.train.full import main as nam_full_main

        x_path, y_path = create_test_data(tmp_path)
        data_config_path, model_config_path, learning_config_path = create_configs(
            tmp_path, x_path, y_path, num_epochs=1
        )
        with open(data_config_path) as fp:
            data_config = json.load(fp)
        with open(model_config_path) as fp:
            model_config = json.load(fp)
        with open(learning_config_path) as fp:
            learning_config = json.load(fp)

        # Record every intercepted write so the test can prove the corruption
        # actually happened instead of passing vacuously.
        corrupted_paths = []

        # Simulate SIGINT-mid-save: Lightning's `best_model_path` ends up
        # set, but the file on disk is unreadable as a checkpoint.
        def corrupt_save(checkpoint, filepath):
            with open(filepath, "wb") as f:
                f.write(b"corrupt")
            corrupted_paths.append(filepath)

        # NOTE(review): `_atomic_save` is a private Lightning helper; this
        # patch presumably intercepts all checkpoint writes -- confirm after
        # Lightning upgrades.
        monkeypatch.setattr(torch_io, "_atomic_save", corrupt_save)

        outdir = tmp_path / "outputs"
        outdir.mkdir()
        nam_full_main(
            data_config,
            model_config,
            learning_config,
            outdir,
            no_show=True,
            make_plots=False,
        )

        # Without this guard, a no-op patch (e.g. the private helper moved or
        # no checkpoint was written) would let the export succeed through the
        # normal path and the fallback would go untested.
        assert corrupted_paths, (
            "No checkpoint write was intercepted, so the bad-checkpoint "
            "fallback path was not exercised."
        )

        nam_files = list(outdir.rglob("*.nam"))
        assert nam_files, (
            f"Expected a .nam exported via in-memory fallback when the best "
            f"checkpoint can't be loaded, but found none in {outdir}."
        )


def main():
"""Run the graceful shutdown test."""
import argparse
Expand Down