Conversation
Pull request overview
Adds MLPerf-oriented training support for Llama 3.1 8B FP8 in Primus/Megatron, including early-stop/time-to-train logging, reproducible validation sampling, optional TE fused SwiGLU, and runnable MLPerf example artifacts.
Changes:
- Add MLPerf-style early stopping + time-to-train logging driven by an eval-loss target.
- Patch Megatron validation sampling to be fixed/reproducible for MLPerf evaluation.
- Introduce MLPerf example scripts/configs (train entrypoint, run script, profiler handler, platform config, README).
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| primus/modules/trainer/megatron/trainer.py | Adds training start timestamp, early-stop on target eval loss, and time-to-train logging. |
| primus/modules/trainer/megatron/pre_trainer.py | Makes batch tuple return order explicit instead of relying on dict ordering. |
| primus/backends/megatron/training/evaluator.py | Stores eval “lm loss” onto args for downstream early-stop logic. |
| primus/backends/megatron/patches/validation_data_sampling_patches.py | Adds MLPerf patch hooks to fix validation sample counts and loader sampling behavior. |
| primus/backends/megatron/patches/te_patches/fused_bias_swiglu_patches.py | Adds optional TE fused swiglu/dswiglu patching via env flag. |
| examples/mlperf/src/train.py | Adds an MLPerf training entrypoint script. |
| examples/mlperf/src/prof_handler.py | Adds a torch profiler output/handler utility. |
| examples/mlperf/run_and_time.sh | Adds a run-and-time wrapper script for MLPerf timing runs. |
| examples/mlperf/README.md | Adds setup/run instructions for the MLPerf example. |
| examples/mlperf/configs/MI355X/llama3.1_8B-pretrain-FP8.yaml | Adds an MLPerf-oriented MI355X FP8 training config. |
| examples/mlperf/config_MI355X_1x8x1.sh | Adds MI355X 1x8x1 environment configuration for MLPerf runs. |
```python
    enable_forward_pre_hook(model)
    pre_hook_enabled = True

target_eval_loss = float(os.environ.get("TARGET_EVAL_LOSS", "0"))
```
Early-stop threshold is read from TARGET_EVAL_LOSS, but the MLPerf example scripts/configs in this PR set MLLOG_TARGET_EVAL_LOSS. As-is, early stopping won't trigger unless users also export TARGET_EVAL_LOSS. Consider reading MLLOG_TARGET_EVAL_LOSS (or supporting both names with a clear precedence).
Suggested change:
```diff
-target_eval_loss = float(os.environ.get("TARGET_EVAL_LOSS", "0"))
+target_eval_loss_env = os.environ.get("TARGET_EVAL_LOSS")
+if target_eval_loss_env is None:
+    target_eval_loss_env = os.environ.get("MLLOG_TARGET_EVAL_LOSS", "0")
+target_eval_loss = float(target_eval_loss_env)
```
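The precedence suggested above can be sketched as a small standalone helper; the function name is illustrative, and `MLLOG_TARGET_EVAL_LOSS` is the variable name the example configs in this PR export:

```python
import os

def read_target_eval_loss(env=os.environ):
    """Read the early-stop threshold, preferring TARGET_EVAL_LOSS and
    falling back to MLLOG_TARGET_EVAL_LOSS; 0.0 disables early stopping."""
    raw = env.get("TARGET_EVAL_LOSS")
    if raw is None:
        raw = env.get("MLLOG_TARGET_EVAL_LOSS", "0")
    return float(raw)

# TARGET_EVAL_LOSS wins when both are set; otherwise the MLLOG_ name applies.
print(read_target_eval_loss({"MLLOG_TARGET_EVAL_LOSS": "3.3"}))   # 3.3
print(read_target_eval_loss({"TARGET_EVAL_LOSS": "3.1",
                             "MLLOG_TARGET_EVAL_LOSS": "3.3"}))   # 3.1
print(read_target_eval_loss({}))                                  # 0.0
```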
```python
if target_eval_loss > 0 and hasattr(args, "_eval_val_loss"):
    if args._eval_val_loss <= target_eval_loss:
        run_duration = time.time() - train_start_time
        log_rank_0(
            f"[EarlyStop] Reached target! Stopping training. "
            f"eval_loss: {args._eval_val_loss:.6f} (target: {target_eval_loss}) | "
```
The early-stop decision is based on args._eval_val_loss, but that attribute is only set on the pipeline stage that computes loss (see primus/backends/megatron/training/evaluator.py). With pipeline parallelism > 1, only last-stage ranks will update args.train_iters/args.do_valid, while other ranks keep training, which can deadlock. Please synchronize the stop condition across all ranks (e.g., broadcast the eval loss/stop flag and apply the same train_iters update everywhere).
Suggested change:
```diff
-if target_eval_loss > 0 and hasattr(args, "_eval_val_loss"):
-    if args._eval_val_loss <= target_eval_loss:
-        run_duration = time.time() - train_start_time
-        log_rank_0(
-            f"[EarlyStop] Reached target! Stopping training. "
-            f"eval_loss: {args._eval_val_loss:.6f} (target: {target_eval_loss}) | "
+if target_eval_loss > 0:
+    local_eval_val_loss = getattr(args, "_eval_val_loss", None)
+    stop_training = (
+        local_eval_val_loss is not None and local_eval_val_loss <= target_eval_loss
+    )
+    synced_eval_val_loss = local_eval_val_loss
+    if dist.is_available() and dist.is_initialized():
+        sync_device = (
+            torch.device("cuda", torch.cuda.current_device())
+            if torch.cuda.is_available()
+            else torch.device("cpu")
+        )
+        stop_tensor = torch.tensor(
+            [1 if stop_training else 0], device=sync_device, dtype=torch.int32
+        )
+        eval_loss_tensor = torch.tensor(
+            [
+                local_eval_val_loss
+                if local_eval_val_loss is not None
+                else float("inf")
+            ],
+            device=sync_device,
+            dtype=torch.float32,
+        )
+        dist.all_reduce(stop_tensor, op=dist.ReduceOp.MAX)
+        dist.all_reduce(eval_loss_tensor, op=dist.ReduceOp.MIN)
+        stop_training = bool(stop_tensor.item())
+        synced_eval_val_loss = eval_loss_tensor.item()
+        if synced_eval_val_loss == float("inf"):
+            synced_eval_val_loss = None
+    if stop_training:
+        run_duration = time.time() - train_start_time
+        eval_loss_for_log = (
+            synced_eval_val_loss
+            if synced_eval_val_loss is not None
+            else float("nan")
+        )
+        log_rank_0(
+            f"[EarlyStop] Reached target! Stopping training. "
+            f"eval_loss: {eval_loss_for_log:.6f} (target: {target_eval_loss}) | "
```
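As a sanity check of the reduce semantics (a pure-Python simulation, not runnable Megatron code), here is how the MAX/MIN all-reduce makes every rank agree when only the last pipeline stage observes a loss; the rank count and loss values are made up:

```python
# Simulate per-rank local eval losses: only the last-stage rank has a value.
local_losses = [None, None, None, 3.25]   # ranks 0..3; target is 3.3
target = 3.3

# Each rank contributes a stop vote (MAX-reduce) and its loss (MIN-reduce),
# substituting +inf where no loss was computed on that rank.
stop_votes = [1 if (l is not None and l <= target) else 0 for l in local_losses]
loss_contrib = [l if l is not None else float("inf") for l in local_losses]

stop_training = bool(max(stop_votes))   # MAX all-reduce: any rank's vote wins
synced_loss = min(loss_contrib)         # MIN all-reduce: the real loss survives

# All ranks now reach the same decision with the same loss value.
print(stop_training, synced_loss)  # True 3.25
```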
```python
run_duration = time.time() - train_start_time
target_eval_loss = float(os.environ.get("TARGET_EVAL_LOSS", "0"))
final_eval_loss = getattr(args, "_eval_val_loss", None)
status = "success" if (final_eval_loss and target_eval_loss > 0 and final_eval_loss <= target_eval_loss) else "aborted"
```
status uses final_eval_loss truthiness (final_eval_loss and ...). If the final loss is 0.0 (or any falsy float), this will incorrectly report aborted even when the target is met. Use an explicit final_eval_loss is not None check (and consider handling NaN separately if needed).
Suggested change:
```diff
-status = "success" if (final_eval_loss and target_eval_loss > 0 and final_eval_loss <= target_eval_loss) else "aborted"
+status = (
+    "success"
+    if (
+        final_eval_loss is not None
+        and target_eval_loss > 0
+        and final_eval_loss <= target_eval_loss
+    )
+    else "aborted"
+)
```
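The truthiness pitfall is easy to reproduce in isolation: a final loss of exactly `0.0` is falsy, so the `and`-chain short-circuits even though the target is met.

```python
final_eval_loss = 0.0
target_eval_loss = 3.3

# Truthiness check: 0.0 is falsy, so this incorrectly reports "aborted".
buggy = "success" if (final_eval_loss and target_eval_loss > 0
                      and final_eval_loss <= target_eval_loss) else "aborted"

# Explicit None check: 0.0 <= 3.3 is actually evaluated and reports "success".
fixed = "success" if (final_eval_loss is not None and target_eval_loss > 0
                      and final_eval_loss <= target_eval_loss) else "aborted"

print(buggy, fixed)  # aborted success
```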
```python
if "lm loss" in total_loss_dict:
    val = total_loss_dict["lm loss"]
    args._eval_val_loss = val.item() if hasattr(val, "item") else float(val)
```
args._eval_val_loss is only set on ranks where is_pipeline_stage_containing_loss() is true. Any logic that relies on this value (e.g., early stopping) must broadcast/synchronize it to all ranks; otherwise different ranks will make different control-flow decisions and can hang with pipeline parallelism > 1.
```python
phase_transition_samples = (
    [0]
    + [t * args.global_batch_size for t in args.phase_transition_iterations]
    + [args.train_samples]
```
In the phase-transition branch, phase_transition_samples appends args.train_samples, but train_samples may have been computed from args.train_iters * global_batch_size when args.train_samples is unset/None. Appending None here will break comparisons/arithmetic. Use the computed train_samples variable (or args.train_iters * args.global_batch_size) instead of args.train_samples.
Suggested change:
```diff
-    + [args.train_samples]
+    + [train_samples]
```
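A minimal illustration of the failure mode, using made-up values for `train_iters` and `global_batch_size`:

```python
train_iters = 1000
global_batch_size = 512
args_train_samples = None  # train_samples was not set in the config

# train_samples is derived from iterations when args.train_samples is None.
train_samples = (args_train_samples
                 if args_train_samples is not None
                 else train_iters * global_batch_size)

# Appending args.train_samples puts None in the schedule, and any later
# comparison against it raises TypeError.
try:
    sorted([0, args_train_samples])
except TypeError:
    print("broken: comparing int with None raises TypeError")

# Appending the computed value keeps the schedule numeric.
print("ok:", [0, train_samples])  # ok: [0, 512000]
```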
```yaml
train_data_path: "/data/mlperf/data/c4-train.en_6_text_document"
valid_data_path: "/data/mlperf/data/c4-validation-91205-samples.en_text_document"
test_data_path: null
seq_length: 8192
```
seq_length is defined again here, duplicating the earlier seq_length setting. Please remove one of the duplicate keys to avoid ambiguity and YAML-parser incompatibilities.
Suggested change:
```diff
-seq_length: 8192
```
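Duplicate-key behavior is parser-dependent: many YAML loaders silently keep the last occurrence, while stricter ones reject the document. Python dict literals show the same silent last-key-wins behavior:

```python
# A dict literal with a duplicate key keeps only the last value, silently.
config = {
    "seq_length": 4096,
    "seq_length": 8192,  # shadows the earlier value without any error
}
print(config["seq_length"])  # 8192
print(len(config))           # 1
```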
```bash
set -e

mkdir -p /results
```
With set -e, a non-zero torchrun exit will terminate the script immediately, so ret_code=$? and the explicit failure handling below never run. Consider disabling -e around torchrun (or using torchrun ...; ret_code=$? with set +e/set -e guards) so timing/result logging works on failures too.
```bash
torchrun \
    --nproc_per_node=${GPUS_PER_NODE} \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    src/train.py

ret_code=$?
```
This ret_code=$? check is ineffective when the script runs with set -e (the script exits immediately if torchrun fails). If you want to handle failures explicitly, capture the exit code by temporarily disabling -e or by using torchrun ... || ret_code=$? and then re-enable -e.
Suggested change:
```diff
-torchrun \
-    --nproc_per_node=${GPUS_PER_NODE} \
-    --nnodes=${NNODES} \
-    --node_rank=${NODE_RANK} \
-    --master_addr=${MASTER_ADDR} \
-    --master_port=${MASTER_PORT} \
-    src/train.py
-
-ret_code=$?
+ret_code=0
+torchrun \
+    --nproc_per_node=${GPUS_PER_NODE} \
+    --nnodes=${NNODES} \
+    --node_rank=${NODE_RANK} \
+    --master_addr=${MASTER_ADDR} \
+    --master_port=${MASTER_PORT} \
+    src/train.py || ret_code=$?
```
```python
TORCHPROF_VERBOSE = os.getenv("TORCHPROF_VERBOSE", 1)
TORCHPROF_DEVICES = os.getenv("TORCHPROF_DEVICES", "GPU")
TORCHPROF_MAXROWS = os.getenv("TORCHPROF_MAXROWS", 100)
```
TORCHPROF_VERBOSE and TORCHPROF_MAXROWS are read via os.getenv without casting, so they become strings. TORCHPROF_MAXROWS is later passed to table(row_limit=...), which expects an int. Please cast these with int(...) (and then to bool for verbose).
Suggested change:
```diff
-TORCHPROF_VERBOSE = os.getenv("TORCHPROF_VERBOSE", 1)
+TORCHPROF_VERBOSE = bool(int(os.getenv("TORCHPROF_VERBOSE", 1)))
 TORCHPROF_DEVICES = os.getenv("TORCHPROF_DEVICES", "GPU")
-TORCHPROF_MAXROWS = os.getenv("TORCHPROF_MAXROWS", 100)
+TORCHPROF_MAXROWS = int(os.getenv("TORCHPROF_MAXROWS", 100))
```
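`os.getenv` returns a string whenever the variable is set, regardless of the default's type, so the raw value must be cast before being passed to an API that expects an `int`:

```python
import os

os.environ["TORCHPROF_MAXROWS"] = "100"

raw = os.getenv("TORCHPROF_MAXROWS", 100)
print(type(raw).__name__)                  # str: the int default is ignored once set

max_rows = int(os.getenv("TORCHPROF_MAXROWS", 100))
print(type(max_rows).__name__, max_rows)   # int 100
```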
```python
TORCHPROF_PROFILE_MEMORY = bool(os.getenv("TORCHPROF_PROFILE_MEMORY", 1))
TORCHPROF_WITH_STACK = bool(os.getenv("TORCHPROF_WITH_STACK", 0))
TORCHPROF_RECORD_SHAPES = bool(os.getenv("TORCHPROF_RECORD_SHAPES", 1))
TORCHPROF_WITH_FLOPS = bool(os.getenv("TORCHPROF_WITH_FLOPS", 1))
```
bool(os.getenv(...)) treats any non-empty string as True, so values like '0' will incorrectly enable options such as TORCHPROF_PROFILE_MEMORY/TORCHPROF_WITH_STACK. Parse these as integers first (e.g., bool(int(os.getenv(..., '0')))), or implement a small strtobool helper.
Suggested change:
```diff
-TORCHPROF_PROFILE_MEMORY = bool(os.getenv("TORCHPROF_PROFILE_MEMORY", 1))
-TORCHPROF_WITH_STACK = bool(os.getenv("TORCHPROF_WITH_STACK", 0))
-TORCHPROF_RECORD_SHAPES = bool(os.getenv("TORCHPROF_RECORD_SHAPES", 1))
-TORCHPROF_WITH_FLOPS = bool(os.getenv("TORCHPROF_WITH_FLOPS", 1))
+TORCHPROF_PROFILE_MEMORY = bool(int(os.getenv("TORCHPROF_PROFILE_MEMORY", "1")))
+TORCHPROF_WITH_STACK = bool(int(os.getenv("TORCHPROF_WITH_STACK", "0")))
+TORCHPROF_RECORD_SHAPES = bool(int(os.getenv("TORCHPROF_RECORD_SHAPES", "1")))
+TORCHPROF_WITH_FLOPS = bool(int(os.getenv("TORCHPROF_WITH_FLOPS", "1")))
```
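The non-empty-string pitfall in one line: `bool("0")` is `True`, while `bool(int("0"))` is `False`. A small strtobool-style helper (the name `env_flag` is illustrative) makes the intended parsing explicit:

```python
# Any non-empty string is truthy, so an env value of "0" still enables a flag
# when parsed with a bare bool().
print(bool("0"))        # True  (buggy interpretation)
print(bool(int("0")))   # False (intended interpretation)

def env_flag(value, default="0"):
    """Interpret an environment-style string as a boolean via int()."""
    return bool(int(value if value is not None else default))

print(env_flag("1"), env_flag("0"), env_flag(None))  # True False False
```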
This PR adds MLPerf support for Llama 3.1 8B FP8.
Time to train: 98.0 mins; eval_loss: 3.292395 (target: 3.3)
Full Log:
mlperf_primus_llama3.18b.log