Elastic training: resilient multi-GPU training on preemptible partitions #3

@mmshad

Description

Problem

On SLURM preemptible partitions, preemption kills the entire job. If one GPU is reclaimed, all GPUs are lost and the job restarts from the last checkpoint. For long training runs this wastes significant compute.

Goal: if one GPU goes away, the remaining GPUs continue training. New GPUs can join dynamically.

Current state

  • SLURM preemption is job-level, not GPU-level
  • NCCL process groups are static (cannot add/remove ranks)
  • FSDP2 shards assume fixed world size
  • DCP (PyTorch Distributed Checkpoint) supports resharding: save with N GPUs, load with M GPUs
  • Auto-resume from checkpoint is already implemented

Path A: Fast restart with orchestrator

A master process on a non-preemptible partition orchestrates worker jobs on preemptible partitions.

Architecture:

  • Master job on a CPU partition with long walltime (days), lightweight (no GPU needed)
  • Workers submitted as a single multi-GPU job on a preemptible partition
  • Master monitors worker job status by polling SLURM (e.g., squeue or scontrol)

On preemption:

  1. Master detects worker job was preempted (SLURM state change)
  2. Master re-submits workers, possibly with fewer GPUs if resources are scarce
  3. Workers auto-resume from last checkpoint with DCP resharding to new world size
  4. Master can also submit additional workers when more GPUs become available
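
A minimal sketch of the detection step above, assuming the master polls `scontrol show job <id>` and reads the `JobState` field. The function names and the mapping from state to action are illustrative assumptions, not an existing implementation. In particular, whether PREEMPTED needs a manual re-submit depends on whether the partition requeues preempted jobs automatically.

```python
import re
import subprocess

# Assumed policy: states after which the master re-submits the worker job.
RESUBMIT_STATES = {"PREEMPTED", "NODE_FAIL", "TIMEOUT"}
# Assumed policy: states after which the master stops. FAILED is included so
# that a crashing training script is not re-submitted in a loop.
STOP_STATES = {"COMPLETED", "CANCELLED", "FAILED"}


def query_job(job_id: str) -> str:
    """Return the raw text of `scontrol show job <job_id>`."""
    result = subprocess.run(
        ["scontrol", "show", "job", job_id],
        capture_output=True, text=True, check=True,
    )
    return result.stdout


def parse_job_state(scontrol_output: str) -> str:
    """Extract the value of the JobState= field from scontrol output."""
    match = re.search(r"\bJobState=(\S+)", scontrol_output)
    if match is None:
        raise ValueError("no JobState field found in scontrol output")
    return match.group(1)


def next_action(state: str) -> str:
    """Map a SLURM job state to 'resubmit', 'stop', or 'wait'."""
    if state in RESUBMIT_STATES:
        return "resubmit"
    if state in STOP_STATES:
        return "stop"
    return "wait"  # PENDING, RUNNING, COMPLETING, ...
```

The master loop would call `next_action(parse_job_state(query_job(job_id)))` on a timer and re-submit the worker script when the result is "resubmit".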

Pros: Works with current synchronous training, no algorithm changes
Cons: Still loses all workers on preemption (full restart), recovery takes 30-60s for checkpoint load

Items to build:

  • Orchestrator script that monitors SLURM job state and re-submits
  • Config for min/max GPU count and target partition
  • Adaptive world size selection based on available resources
  • Fast checkpoint interval tuning (more frequent saves on preemptible partitions)
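
The min/max GPU config and adaptive world size selection could look roughly like the sketch below. `ElasticConfig`, `choose_world_size`, and the `gpus_per_node` rounding rule are hypothetical names and choices made for illustration. How the number of idle GPUs is obtained (for example from `sinfo`) is left to the caller.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ElasticConfig:
    partition: str          # target preemptible partition
    min_gpus: int           # do not submit below this world size
    max_gpus: int           # never request more than this
    gpus_per_node: int = 1  # round requests down to whole nodes


def choose_world_size(cfg: ElasticConfig, idle_gpus: int) -> Optional[int]:
    """Pick the GPU count for the next worker submission.

    Returns the largest count that fits in idle_gpus, respects max_gpus,
    and is a multiple of gpus_per_node. Returns None if that count would
    fall below min_gpus, meaning the master should wait and retry later.
    """
    usable = min(idle_gpus, cfg.max_gpus)
    usable -= usable % cfg.gpus_per_node
    if usable < cfg.min_gpus:
        return None
    return usable
```

With `min_gpus=8`, `max_gpus=32`, `gpus_per_node=4`, a report of 21 idle GPUs yields a 20-GPU submission, and 7 idle GPUs yields `None`.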

Path B: Local SGD with Redis coordination

Each GPU is an independent single-GPU SLURM job on a preemptible partition. Workers train independently for K steps, then synchronize via a coordination service. If a worker dies, others continue without interruption.

Architecture:

  • Master on a CPU partition (long-lived, manages coordination)
  • Redis server on master node for pub/sub and state
  • Workers are independent single-GPU jobs on a preemptible partition
  • Start with minimum required workers (e.g., 8), add more as GPUs become available

Training algorithm (Local SGD / DiLoCo style):

  1. Each worker pulls current global parameters from Redis (or shared filesystem)
  2. Worker trains independently for K local steps (e.g., K=100)
  3. Worker publishes its local gradient delta (the change in its parameters over the K local steps) to Redis
  4. Master (or workers via pub/sub) averages deltas from all active workers
  5. Global parameters updated, workers pull new params for next round
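
The five steps above can be sketched in plain Python on a toy problem. This is a simplified illustration under stated assumptions: parameters are a flat list of floats, workers run sequentially instead of as separate jobs, Redis is replaced by function arguments, and the averaged delta is added directly to the global parameters. DiLoCo itself applies an outer optimizer to the averaged delta, which is omitted here.

```python
def local_steps(params, grad_fn, k, lr):
    """Run k plain SGD steps starting from a copy of params."""
    local = list(params)
    for _ in range(k):
        grads = grad_fn(local)
        local = [p - lr * g for p, g in zip(local, grads)]
    return local


def sync_round(global_params, worker_grad_fns, k=100, lr=0.1):
    """One round: pull, train K local steps, publish delta, average, update."""
    deltas = []
    for grad_fn in worker_grad_fns:  # real workers would run in parallel
        local = local_steps(global_params, grad_fn, k, lr)
        # Delta = local parameters minus the global parameters pulled at
        # the start of the round.
        deltas.append([l - g for l, g in zip(local, global_params)])
    n = len(deltas)
    avg = [sum(d[i] for d in deltas) / n for i in range(len(global_params))]
    return [g + a for g, a in zip(global_params, avg)]


def make_quadratic_grad(target):
    """Gradient of 0.5 * (p - target)^2 per coordinate (toy data shard)."""
    return lambda params: [p - target for p in params]


# Two workers whose local data pulls toward 1.0 and 3.0 respectively.
workers = [make_quadratic_grad(1.0), make_quadratic_grad(3.0)]
new_global = sync_round([0.0], workers, k=100, lr=0.1)
```

In this toy setup each worker converges close to its own target within the round, so the updated global parameter lands near the midpoint of the two targets.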

Redis role:

  • Pub/sub: workers announce join/leave, master broadcasts sync signals
  • Key-value: global model state, worker heartbeats, training metadata
  • Coordination: barrier-like sync for gradient averaging rounds
  • Worker registry: master tracks active workers, detects failures via heartbeat timeout
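
A minimal in-memory sketch of the worker registry with heartbeat timeout. The class and method names are hypothetical, and a dict stands in for Redis so the logic is runnable on its own. In a Redis-backed version, one option would be a key per worker written with an expiry, so that a missed heartbeat makes the key disappear.

```python
import time


class WorkerRegistry:
    """Tracks the last heartbeat per worker and flags timed-out workers."""

    def __init__(self, timeout_s=30.0, clock=time.monotonic):
        self.timeout_s = timeout_s
        self.clock = clock       # injectable so tests can control time
        self._last_seen = {}     # worker_id -> time of last heartbeat

    def heartbeat(self, worker_id):
        """Register a new worker or refresh an existing one."""
        self._last_seen[worker_id] = self.clock()

    def leave(self, worker_id):
        """Worker announces a clean shutdown."""
        self._last_seen.pop(worker_id, None)

    def active_workers(self):
        """Workers whose last heartbeat is within the timeout."""
        now = self.clock()
        return sorted(
            w for w, t in self._last_seen.items() if now - t <= self.timeout_s
        )

    def reap(self):
        """Remove and return workers whose heartbeat has timed out."""
        now = self.clock()
        dead = sorted(
            w for w, t in self._last_seen.items() if now - t > self.timeout_s
        )
        for w in dead:
            del self._last_seen[w]
        return dead
```

The master would call `reap()` before each sync round and submit a replacement job for every worker ID it returns.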

On worker preemption:

  • Other workers continue their local steps uninterrupted
  • Next sync round proceeds with fewer workers (weighted by steps contributed)
  • Master submits replacement worker job when resources free up
  • New worker pulls latest global params from Redis and starts training immediately
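
One possible reading of "weighted by steps contributed" is a weighted mean of the deltas that arrived, with each worker's weight equal to the number of local steps it completed in the round. The sketch below assumes that reading. The function name and the dict-based inputs are illustrative only.

```python
def weighted_average(deltas, steps):
    """Average worker deltas, weighting each by its completed local steps.

    deltas: {worker_id: list of floats}, only workers that reported
    steps:  {worker_id: int}, local steps completed this round
    """
    if not deltas:
        raise ValueError("no worker deltas to average")
    total = sum(steps[w] for w in deltas)
    if total <= 0:
        raise ValueError("total contributed steps must be positive")
    dim = len(next(iter(deltas.values())))
    return [
        sum(deltas[w][i] * steps[w] for w in deltas) / total
        for i in range(dim)
    ]
```

A preempted worker that never reports is simply absent from `deltas`, so the round proceeds with the remaining workers.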

On adding workers:

  • Master submits new single-GPU jobs as resources become available
  • New worker registers with Redis, pulls current params, joins next sync round
  • No restart or re-initialization needed for existing workers

Pros: True elasticity, no full restarts, workers are independent SLURM jobs (preemption only affects one GPU), naturally fits preemptible GPU scheduling
Cons: Different convergence properties than synchronous SGD (need to validate), requires Redis infrastructure, more complex than standard distributed training, gradient staleness with slow workers

Items to build:

  • Redis-based coordination layer (worker registry, heartbeat, pub/sub)
  • Local SGD training loop (K local steps, then sync)
  • Gradient delta computation and averaging
  • Worker launcher: master submits single-GPU jobs, monitors, replaces on failure
  • Parameter distribution: efficient broadcast of global params to workers via shared filesystem
  • Convergence validation: compare Local SGD loss curves against synchronous baseline
  • Fallback: if only 1 worker remains, continue training (reduced throughput but no stop)

Considerations

  • Checkpoint strategy for Path B: Each worker saves its own local state; global params live on the shared filesystem. Recovery is just "pull latest global params."
  • Gradient compression: For Path B, gradient deltas can be large. Consider TopK sparsification or low-rank compression before publishing to Redis.
  • Hivemind library: Open-source implementation of decentralized training with dynamic peers. Worth evaluating before building from scratch.
  • Mixed approach: Use Path A for large synchronized jobs (when all GPUs must stay tightly coupled), Path B for opportunistic training (maximize utilization across preemptible GPUs).
  • Monitoring: Both paths need a dashboard showing active workers, sync rounds, throughput, and preemption events.
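
For the gradient compression consideration above, TopK sparsification can be illustrated as follows: keep only the k largest-magnitude entries of a delta and send them as (index, value) pairs. This is a pure-Python sketch on a flat list with hypothetical function names. A practical version would operate on tensors and would likely need to handle the dropped entries (for example by carrying them over to the next round), which is not shown.

```python
def topk_sparsify(delta, k):
    """Keep the k largest-magnitude entries; return (indices, values)."""
    if k <= 0:
        return [], []
    by_magnitude = sorted(
        range(len(delta)), key=lambda i: abs(delta[i]), reverse=True
    )
    indices = sorted(by_magnitude[:k])  # ascending index order
    return indices, [delta[i] for i in indices]


def densify(indices, values, size):
    """Rebuild a dense list, filling dropped entries with zero."""
    out = [0.0] * size
    for i, v in zip(indices, values):
        out[i] = v
    return out


delta = [0.1, -3.0, 0.5, 2.0, -0.2]
idx, vals = topk_sparsify(delta, k=2)
restored = densify(idx, vals, len(delta))
```

Here only two of the five entries are published, and the receiver treats the rest as zero.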

Priority

Medium-term. Current checkpoint-and-resume works for jobs that complete within walltime. This becomes critical when running multi-day training on preemptible partitions where preemption frequency is high.
