Automatic recovery from training instability #2

@mmshad

Context

KempnerForge detects training instability (NaN/Inf loss) but only halts the run after a consecutive-NaN threshold is reached. There is no automatic recovery path, so every instability event requires manual intervention. For long-running production jobs, that is expensive.

Current behavior:

  • NaNDetector.check_loss in kempnerforge/resilience/health.py syncs a NaN flag across ranks via a small all_reduce before returning.
  • On NaN, the training loop in scripts/train.py zeros grads, skips the optimizer step, and continues. If consecutive_nans >= max_consecutive, it logs "Too many consecutive NaNs — stopping" and breaks out of the loop.
  • The detector is constructed in scripts/train.py with hardcoded action="warn" and max_consecutive=10; neither is exposed via TOML.
  • NaNDetector.should_rollback is a property flag, not an automated action; the loop just stops when it trips.
  • A NaNDetector.check_gradients helper exists but is not called from the training loop (noted in docs/resilience/nan-detection.md).
  • NCCL health check defaults to disabled: nccl_health_check_interval: int = 0 in TrainConfig.
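
For reference, the skip-then-stop flow described above can be sketched roughly as follows. This is an illustrative stand-in, not the actual NaNDetector: the class name `ConsecutiveNaNCounter`, the boolean return convention of `check_loss`, and the reset-on-finite-loss behavior are assumptions, and the cross-rank all_reduce is omitted.

```python
import math


class ConsecutiveNaNCounter:
    """Hypothetical single-process stand-in for the detector (no all_reduce)."""

    def __init__(self, max_consecutive: int = 10):
        self.max_consecutive = max_consecutive
        self.consecutive_nans = 0
        self.nan_steps = []  # step indices where the loss was NaN/Inf

    def check_loss(self, loss: float, step: int) -> bool:
        """Return True if the loss is finite; otherwise record the step."""
        if math.isfinite(loss):
            self.consecutive_nans = 0  # assumed: a good step resets the streak
            return True
        self.consecutive_nans += 1
        self.nan_steps.append(step)
        return False

    @property
    def should_rollback(self) -> bool:
        # A flag only; nothing is rolled back automatically.
        return self.consecutive_nans >= self.max_consecutive


def run(losses, max_consecutive=10):
    """Replay a loss sequence; return (steps_applied, step_where_loop_stopped)."""
    detector = ConsecutiveNaNCounter(max_consecutive)
    applied = []
    for step, loss in enumerate(losses):
        if detector.check_loss(loss, step):
            applied.append(step)  # the optimizer step would happen here
            continue
        # NaN/Inf: grads are zeroed and the optimizer step is skipped.
        if detector.should_rollback:
            return applied, step  # threshold tripped: the loop breaks
    return applied, None
```

With `max_consecutive=3`, a run that sees three bad losses in a row stops on the third one; isolated bad steps are skipped and training continues.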

Items to consider

  • Automatic checkpoint rollback: On should_rollback, drop to the last known-good checkpoint (or N steps back) and resume instead of breaking.
  • LR reduction on rollback: Temporarily reduce LR (e.g., ×0.5) after rollback and ramp back over N steps.
  • Gradient anomaly detection in the hot path: Wire NaNDetector.check_gradients (or a cheaper grad-norm spike check) into the training loop so explosions are caught before they reach loss.
  • OOM recovery with batch size reduction: Catch CUDA OOM, reduce micro-batch size or grad-accum for one step, retry.
  • Configurable recovery policy via TOML: Expose NaN action (warn/skip/raise), max_consecutive, rollback depth, max retries, LR reduction factor.
  • NCCL health check default: Consider a non-zero default for nccl_health_check_interval so production runs detect hung collectives without opt-in.
  • Structured recovery event log: Log each recovery action (rollback step, LR change, skip reason) for post-mortem. NaNState.nan_steps already captures NaN step indices but not recovery actions.
  • on_instability training hook: TrainingHook in kempnerforge/training/hooks.py currently exposes on_train_begin, on_step_end, on_eval_end, on_checkpoint_save, and on_train_end. Add one more for instability so researchers can plug in custom recovery.
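
To make a few of these items concrete, here is a minimal sketch of how rollback, LR reduction with a linear ramp, a retry cap, and a structured event log might fit together. Everything in it is hypothetical: `RecoveryPolicy`, `RecoveryState`, `lr_scale`, and `on_rollback_trigger` do not exist in KempnerForge, the defaults are placeholders, and checkpoint loading itself is left out.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RecoveryPolicy:
    """Hypothetical knobs that a TOML section could populate."""
    max_retries: int = 3       # rollbacks allowed before giving up
    lr_factor: float = 0.5     # LR multiplier applied right after a rollback
    lr_ramp_steps: int = 100   # steps over which the LR returns to full scale


@dataclass
class RecoveryState:
    retries: int = 0
    ramp_start: Optional[int] = None                  # step the last rollback resumed from
    events: List[dict] = field(default_factory=list)  # structured recovery log


def lr_scale(policy: RecoveryPolicy, state: RecoveryState, step: int) -> float:
    """Multiplier on the scheduled LR: lr_factor at resume, ramping linearly to 1.0."""
    if state.ramp_start is None:
        return 1.0
    progress = (step - state.ramp_start) / policy.lr_ramp_steps
    progress = min(max(progress, 0.0), 1.0)
    return policy.lr_factor + (1.0 - policy.lr_factor) * progress


def on_rollback_trigger(policy, state, step, last_good_step):
    """Handle a tripped should_rollback flag.

    Returns the step to resume from, or None when retries are exhausted and
    the run should stop, which matches today's detect-and-stop behavior.
    """
    if state.retries >= policy.max_retries:
        state.events.append(
            {"step": step, "action": "stop", "reason": "max_retries"}
        )
        return None
    state.retries += 1
    state.ramp_start = last_good_step
    state.events.append(
        {
            "step": step,
            "action": "rollback",
            "resume_step": last_good_step,
            "lr_factor": policy.lr_factor,
        }
    )
    return last_good_step
```

Capping retries keeps the failure mode bounded: once the budget is spent, the sketch falls back to stopping, so a run that keeps diverging does not loop on rollbacks indefinitely. The event list is the kind of record a post-mortem would need alongside NaNState.nan_steps.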

Priority

Low for now. Detect-and-stop is safe and predictable; automatic recovery adds complexity and the risk of silently training on bad state. Revisit when multi-day jobs make the cost of operator intervention too high.
