diff --git a/run_train.sh b/run_train.sh
index 87558a782d..ab715e30bb 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -13,7 +13,7 @@ set -ex
 # COMM_MODE="fake_backend" ./run_train.sh  # for config validation without GPU
 # COMM_MODE="local_tensor" ./run_train.sh  # for local tensor debugging mode
 NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+export LOG_RANK=${LOG_RANK:-0,2}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
 # COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training
diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py
index 5def0e40e6..cbc39989d3 100644
--- a/torchtitan/distributed/dual_pipe_v.py
+++ b/torchtitan/distributed/dual_pipe_v.py
@@ -43,10 +43,11 @@ def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool:
     )

     if dual_pipe_v and job_config.activation_checkpoint.mode != "none":
-        raise NotImplementedError(
-            "Expert Parallel with DualPipeV and Activation Checkpointing "
-            "cannot be used together. Please disable one of them."
-        )
+        pass
+        # raise NotImplementedError(
+        #     "Expert Parallel with DualPipeV and Activation Checkpointing "
+        #     "cannot be used together. Please disable one of them."
+        # )

     return dual_pipe_v
@@ -98,6 +99,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )


+# Thread-local flag to track if we're in the backward thread
+# Any SyncHook.forward call from the backward thread is checkpoint recomputation
+_backward_thread_flag = threading.local()
+
+
 class HookCoordinator:
     def __init__(self):
         # Barrier for 2 threads (forward and backward) to synchronize
@@ -141,6 +147,16 @@ def is_coordination_enabled(self):
         return self._coordination_enabled


+def _is_in_backward_thread() -> bool:
+    """Check if current thread is the backward thread."""
+    return getattr(_backward_thread_flag, 'value', False)
+
+
+def _set_backward_thread_flag(value: bool):
+    """Set the backward thread flag for current thread."""
+    _backward_thread_flag.value = value
+
+
 # Global coordinator
 _hook_coordinator = HookCoordinator()
@@ -150,6 +166,16 @@ class SyncHook(torch.autograd.Function):
     # pyrefly: ignore [bad-override]
     def forward(ctx, x, hook_name=""):
         ctx.hook_name = hook_name
+
+        # Skip barrier if we're in the backward thread - this means we're being called
+        # during checkpoint recomputation (the forward thread never sets this flag)
+        if _is_in_backward_thread():
+            print("skipping backward barrier", flush=True)
+            ctx.skip_backward_barrier = True
+            return x
+
+        ctx.skip_backward_barrier = False
+
         # handle edge case for transformer level boundary
         if _hook_coordinator._coordination_enabled and hook_name == "D":
             _hook_coordinator._cycle_count += 1
@@ -165,6 +191,13 @@ def forward(ctx, x, hook_name=""):
     def backward(ctx, grad_output):
         hook_name = ctx.hook_name

+        # Skip barrier if this backward corresponds to a checkpoint recompute forward
+        # These are "extra" backward nodes created by checkpoint that don't have
+        # corresponding partners in the other thread
+        if ctx.skip_backward_barrier:
+            print("skipping backward barrier", flush=True)
+            return grad_output, None
+
         # Edge case, skip initial barrier, all subsequent backward hooks will acquire
         if hook_name == "D" and _hook_coordinator._cycle_count == 0:
             return grad_output, None
@@ -184,6 +217,9 @@ def _count_moe_modules(model):
     return moe_count


+# import fbvscode
+# fbvscode.attach_debugger()
+
 device_type, device_module = get_device_info()
@@ -264,6 +300,10 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):

     # Shared container for exception from backward thread
     def run_backward():
+        # Mark this thread as the backward thread so SyncHook.forward
+        # can detect checkpoint recomputation (forward called from backward thread)
+        _set_backward_thread_flag(True)
+
         # pyrefly: ignore [missing-attribute]
         schedule._assert_unsharded(backward_stage)
         # Set the backward thread to use the same stream as forward
@@ -294,6 +334,24 @@ def run_backward():
                 # pyrefly: ignore [bad-argument-type]
                 backward_mb_index,
             )
+            backward_stage.backward_one_chunk(
+                backward_mb_index,
+                loss=loss,
+                full_backward=True,
+                last_backward=last_backward,
+            )
+
+            if backward_is_prev_stage_on_this_rank:
+                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
+                    backward_stage.get_local_bwd_output(backward_mb_index),
+                    backward_mb_index,
+                )
+        except BaseException as e:
+            backward_exception.append(e)
+            # Abort barrier to unblock forward thread if it's waiting
+            _hook_coordinator.disable_coordination()
+        finally:
+            _set_backward_thread_flag(False)

     def run_forward():
         # pyrefly: ignore [missing-attribute]
@@ -306,6 +364,11 @@ def run_forward():
            # pyrefly: ignore [bad-index, unsupported-operation]
            kwarg_mbs[forward_mb_index],
        )
+        # # TODO it's error-prone to have this logic scattered inside and outside the runtime file.
+        # # this goes along with the patch to pytorch: https://github.com/pytorch/pytorch/pull/167002/
+        # key = f"{forward_stage.stage_index}_{forward_mb_index}"
+        # assert key not in schedule.ownership_tokens
+        # schedule.ownership_tokens[key] = output.view_as(output).grad_fn
        schedule._maybe_compute_loss(
            forward_stage, output, ctx.target_mbs, forward_mb_index
        )
@@ -323,3 +386,7 @@ def run_forward():
     thread.join()

     _hook_coordinator.disable_coordination()
+
+    # Re-raise exception from backward thread with full traceback
+    if backward_exception:
+        raise backward_exception[0]
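[Reviewer note, illustrative only, not part of the patch] The dual_pipe_v.py changes above rest on one observation: a SyncHook.forward call that runs on the backward thread can only be activation-checkpoint recomputation, so its barrier (and the barrier of the backward node it creates) must be skipped to keep the forward/backward barrier pairing balanced. A minimal standalone sketch of that thread-local detection, assuming non-reentrant checkpointing and a CPU-only graph (so backward executes on the calling thread); all names below are hypothetical:

import threading

import torch
from torch.utils.checkpoint import checkpoint

_bwd_flag = threading.local()


def _in_backward_thread() -> bool:
    # Defaults to False on any thread that never set the flag (e.g. the forward thread).
    return getattr(_bwd_flag, "value", False)


class Marker(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # On the backward thread, reaching this forward can only mean checkpoint recompute.
        print("recompute" if _in_backward_thread() else "real forward")
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


def block(x):
    return Marker.apply(x).sin()


x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)  # prints "real forward"


def run_backward():
    _bwd_flag.value = True  # mark this thread as the backward thread
    try:
        y.sum().backward()  # recomputation of block() prints "recompute"
    finally:
        _bwd_flag.value = False


t = threading.Thread(target=run_backward)
t.start()
t.join()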
diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index bef597be24..9ed6cf2545 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -13,7 +13,6 @@
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-
 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
     _PipelineScheduleRuntime,
@@ -39,6 +38,258 @@
     "pipeline_module_split",
 ]

+lib = torch.library.Library("aten", "IMPL")
+
+
+def _override_torch_ops_for_zero_bubble():
+    class MmSeparateWeightGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, i, w, real_output):
+            ctx.save_for_backward(i)
+            return real_output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (i,) = ctx.saved_tensors
+            grad_weight = i.t().mm(grad_output)
+            return None, grad_weight, None
+
+    class MmSeparateInputGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, i, w, real_output):
+            ctx.save_for_backward(w)
+            return real_output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (w,) = ctx.saved_tensors
+            grad_input = grad_output.mm(w.t())
+            return grad_input, None, None
+
+    class MmPassThrough(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, mm_output, fake_1, fake_2):
+            # we computed the mm earlier, so we could reuse its output shape in the separate input/weight functions
+            # but we need to keep this autograd function to connect the fake_* inputs to the autograd graph and pass
+            # gradients back to them
+            return mm_output
+
+        @staticmethod
+        def backward(ctx, gO):
+            return None, gO, gO
+    def split_mm(i, w):
+        # Apply the pass-through node. real_output is passed to this node so that it
+        # can be returned as the output, but it is detached because we don't want to
+        # actually build this edge of the graph
+        with torch._C._AutoDispatchBelowAutograd():
+            real_output = torch.mm(i.detach(), w.detach()).detach()
+
+        fake_1 = MmSeparateWeightGrad.apply(i.detach(), w, real_output)
+        fake_2 = MmSeparateInputGrad.apply(i, w.detach(), real_output)
+
+        return MmPassThrough.apply(real_output, fake_1, fake_2)
+
+    # addmm operator: out = beta * input + alpha * (mat1 @ mat2)
+    class AddmmSeparateMat2Grad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, mat1, mat2, alpha):
+            ctx.save_for_backward(mat1)
+            ctx.alpha = alpha
+            return mat2
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (mat1,) = ctx.saved_tensors
+            # Gradient w.r.t. mat2: alpha * mat1.T @ grad_output
+            grad_mat2 = mat1.t().mm(grad_output) * ctx.alpha
+            return None, grad_mat2, None
+
+    class AddmmSeparateMat1Grad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, mat1, mat2, alpha):
+            ctx.save_for_backward(mat2)
+            ctx.alpha = alpha
+            return mat1
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (mat2,) = ctx.saved_tensors
+            # Gradient w.r.t. mat1: alpha * grad_output @ mat2.T
+            grad_mat1 = grad_output.mm(mat2.t()) * ctx.alpha
+            return grad_mat1, None, None
+
+    class AddmmSeparateBiasGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, bias, beta):
+            ctx.beta = beta
+            return bias
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            # Gradient w.r.t. bias: beta * sum(grad_output, dim=0)
+            grad_bias = grad_output.sum(dim=0) * ctx.beta
+            return grad_bias, None
+
+    class AddmmPassThrough(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, bias, mat1, mat2, beta, alpha):
+            with torch._C._AutoDispatchBelowAutograd():
+                return torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
+
+        @staticmethod
+        def backward(ctx, gO):
+            return gO, gO, gO, None, None
+
+    def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
+        mat2_1 = AddmmSeparateMat2Grad.apply(mat1.detach(), mat2, alpha)
+        mat1_1 = AddmmSeparateMat1Grad.apply(mat1, mat2.detach(), alpha)
+        bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
+        return AddmmPassThrough.apply(bias_1, mat1_1, mat2_1, beta, alpha)
+
+    # rms_norm operator: RMS normalization
+    class FusedRmsNormSeparateWeightGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
+            ctx.save_for_backward(input, weight, rstd)
+            ctx.normalized_shape = normalized_shape
+            return real_output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            input, weight, rstd = ctx.saved_tensors
+            # Call _fused_rms_norm_backward with output_mask=[False, True]
+            # We only want gradient w.r.t. weight (index 1)
+            _, grad_weight = torch.ops.aten._fused_rms_norm_backward(
+                grad_output,
+                input,
+                ctx.normalized_shape,
+                rstd,
+                weight,
+                output_mask=[False, True],
+            )
+            return None, None, grad_weight, None, None, None
+    class FusedRmsNormSeparateInputGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, normalized_shape, weight, eps, real_output, rstd):
+            ctx.save_for_backward(input, weight, rstd)
+            ctx.normalized_shape = normalized_shape
+            return real_output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            input, weight, rstd = ctx.saved_tensors
+            # Call _fused_rms_norm_backward with output_mask=[True, False]
+            # We only want gradient w.r.t. input (index 0)
+            grad_input, _ = torch.ops.aten._fused_rms_norm_backward(
+                grad_output,
+                input,
+                ctx.normalized_shape,
+                rstd,
+                weight,
+                output_mask=[True, False],
+            )
+            return grad_input, None, None, None, None, None
+
+    class FusedRmsNormPassThrough(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, real_output, real_std, fake_1, fake_2):
+            return real_output, real_std
+
+        @staticmethod
+        def backward(ctx, gO, gStd):
+            # Pass gradients to fake_1 and fake_2 to trigger their backward methods
+            # Return None for real_output/rstd since they are already detached
+            return None, None, gO, gO
+
+    def split_fused_rms_norm(input, normalized_shape, weight=None, eps=None):
+        # Compute the actual output using _fused_rms_norm which returns (output, rstd)
+        with torch._C._AutoDispatchBelowAutograd():
+            real_output, rstd = torch._fused_rms_norm(
+                input.detach(),
+                normalized_shape,
+                weight.detach() if weight is not None else None,
+                eps,
+            )
+        real_output = real_output.detach()
+        rstd = rstd.detach()
+        rstd2 = rstd.clone().detach()
+
+        weight_1 = FusedRmsNormSeparateWeightGrad.apply(
+            input.detach(), normalized_shape, weight, eps, real_output, rstd
+        )
+        input_1 = FusedRmsNormSeparateInputGrad.apply(
+            input,
+            normalized_shape,
+            weight.detach() if weight is not None else None,
+            eps,
+            real_output,
+            rstd2,
+        )
+        return FusedRmsNormPassThrough.apply(real_output, rstd, weight_1, input_1)
+
+    # _grouped_mm operator: Grouped matrix multiplication for MoE
+    class GroupedMmSeparateMat2Grad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, mat2, offs, bias, out_dtype, real_output):
+            ctx.save_for_backward(input)
+            ctx.offs = offs
+            return real_output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (input,) = ctx.saved_tensors
+            # Gradient w.r.t. mat2 for grouped mm
+            grad_mat2 = torch.ops.aten._grouped_mm.default(
+                input.transpose(-1, -2), grad_output, offs=ctx.offs
+            )
+            return None, grad_mat2, None, None, None, None
+
+    class GroupedMmSeparateInputGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, mat2, offs, bias, out_dtype, real_output):
+            ctx.save_for_backward(mat2)
+            ctx.offs = offs
+            return real_output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (mat2,) = ctx.saved_tensors
+            # Gradient w.r.t. input for grouped mm
+            grad_input = torch.ops.aten._grouped_mm.default(
+                grad_output, mat2.transpose(-1, -2), offs=ctx.offs
+            )
+            return grad_input, None, None, None, None, None
+    class GroupedMmPassThrough(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, real_output, fake_1, fake_2):
+            return real_output
+
+        @staticmethod
+        def backward(ctx, gO):
+            return None, gO, gO
+
+    def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
+        with torch._C._AutoDispatchBelowAutograd():
+            real_output = torch.ops.aten._grouped_mm.default(
+                input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
+            ).detach()
+        fake_1 = GroupedMmSeparateMat2Grad.apply(
+            input.detach(), mat2, offs, bias, out_dtype, real_output
+        )
+        fake_2 = GroupedMmSeparateInputGrad.apply(
+            input, mat2.detach(), offs, bias, out_dtype, real_output
+        )
+        return GroupedMmPassThrough.apply(real_output, fake_1, fake_2)
+
+    lib.impl("mm", split_mm, "Autograd")
+    lib.impl("addmm", split_addmm, "Autograd")
+    lib.impl("_fused_rms_norm", split_fused_rms_norm, "Autograd")
+    lib.impl("_grouped_mm", split_grouped_mm, "Autograd")
+    torch.autograd.set_detect_anomaly(True, check_nan=False)
+

 def pipeline_llm(
     model: nn.Module,
@@ -51,6 +302,9 @@ ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
     pp_mesh = parallel_dims.world_mesh["pp"]

+    if True:
+        _override_torch_ops_for_zero_bubble()
+
     # Determine the number of virtual stages based on schedule type
     schedule_class = get_schedule_class(
         job_config.parallelism.pipeline_parallel_schedule
     )
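[Reviewer note, illustrative only, not part of the patch] A quick way to sanity-check the aten overrides registered above is to confirm that the split mm path reproduces the stock gradients while producing dX and dW from separate backward nodes, which is what lets a zero-bubble style schedule run input and weight gradients at different times. A sketch under the assumption that _override_torch_ops_for_zero_bubble is importable as defined in this patch and that the check runs on CPU:

import torch

from torchtitan.distributed.pipeline_parallel import _override_torch_ops_for_zero_bubble

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 32, requires_grad=True)

# Reference gradients through the stock aten::mm kernel, before the override is installed.
x.mm(w).sum().backward()
gx_ref, gw_ref = x.grad.clone(), w.grad.clone()
x.grad = None
w.grad = None

_override_torch_ops_for_zero_bubble()  # registers the Autograd-key overrides on the "aten" library

out = x.mm(w)  # now routed through split_mm
print(out.grad_fn)  # expected to show the pass-through node that fans out to the fake grad edges
out.sum().backward()

# Same numbers as the stock kernel, but dX and dW came from separate autograd nodes.
torch.testing.assert_close(x.grad, gx_ref)
torch.testing.assert_close(w.grad, gw_ref)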
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 31e450eb04..5e50b68d56 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -32,7 +32,7 @@
         dim=256,
         inter_dim=1024,
         moe_inter_dim=256,
-        n_layers=6,
+        n_layers=24,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
@@ -97,8 +97,8 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        attn_type="flex",
-        attn_mask_type="block_causal",
+        # attn_type="flex",
+        # attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
         vocab_size=102400,
diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
index 1951cc4350..f62846b699 100644
--- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -6,7 +6,8 @@ print_config = false
 [profiling]
 enable_profiling = false
 save_traces_folder = "profile_trace"
-profile_freq = 10
+profile_freq = 1
+profiler_warmup = 0
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
@@ -30,17 +31,18 @@ lr = 8e-4
 eps = 1e-8

 [lr_scheduler]
-warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
+warmup_steps = 0  # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
 decay_type = "linear"
 min_lr_factor = 0.0

 [training]
-local_batch_size = 8
-seq_len = 2048
+local_batch_size = 4
+seq_len = 4
 max_norm = 1.0  # grad norm clipping
-steps = 10
+steps = 6
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "c4"

 [parallelism]
 data_parallel_replicate_degree = 1
@@ -48,10 +50,10 @@ data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "1F1B"
+pipeline_parallel_degree = 2
+expert_parallel_degree = 2
 context_parallel_degree = 1
-expert_parallel_degree = 1
+pipeline_parallel_schedule = "DualPipeV"
 expert_tensor_parallel_degree = 1

 [checkpoint]
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 00ec53310e..6e9f1287b5 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 16B model training"
 print_config = false
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
@@ -38,8 +38,8 @@ min_lr_factor = 0.1
 local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0  # grad norm clipping
-steps = 1000
-dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
+steps = 30
+dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]
 data_parallel_replicate_degree = 1
@@ -47,9 +47,9 @@ data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "Interleaved1F1B"
-expert_parallel_degree = 8
+pipeline_parallel_degree = 2
+pipeline_parallel_schedule = "DualPipeV"
+expert_parallel_degree = 4
 expert_tensor_parallel_degree = 1

 [checkpoint]
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 112153390f..93e70fd012 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -30,7 +30,6 @@
     DualPipeExpertParallel,
     get_dual_pipe_v_flag,
 )
-
 from torchtitan.distributed.expert_parallel import (
     BaseExpertParallel,
     ExpertParallel,