diff --git a/examples/megatron/configs/MI300X/deepseek_v2-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v2-BF16-pretrain.yaml index 24f042f4c..234ad2953 100644 --- a/examples/megatron/configs/MI300X/deepseek_v2-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/deepseek_v2-BF16-pretrain.yaml @@ -95,5 +95,5 @@ modules: turbo_sync_free_moe_stage: 2 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/deepseek_v2-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v2-FP8-pretrain.yaml index 8f4d0570b..42dcc754f 100644 --- a/examples/megatron/configs/MI300X/deepseek_v2-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/deepseek_v2-FP8-pretrain.yaml @@ -94,8 +94,8 @@ modules: turbo_sync_free_moe_stage: 2 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml index a43928a3b..d97c22d1e 100644 --- a/examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml @@ -90,5 +90,5 @@ modules: turbo_sync_free_moe_stage: 2 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/deepseek_v2_lite-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v2_lite-FP8-pretrain.yaml index 794a305e2..950fe9154 100644 --- a/examples/megatron/configs/MI300X/deepseek_v2_lite-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/deepseek_v2_lite-FP8-pretrain.yaml @@ -89,8 +89,8 @@ modules: turbo_sync_free_moe_stage: 2 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/deepseek_v3-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v3-BF16-pretrain.yaml index 825f131e0..df75699f3 100644 --- a/examples/megatron/configs/MI300X/deepseek_v3-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/deepseek_v3-BF16-pretrain.yaml @@ -80,5 +80,5 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/deepseek_v3-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/deepseek_v3-FP8-pretrain.yaml index b7cfe474f..68ca7064a 100644 --- a/examples/megatron/configs/MI300X/deepseek_v3-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/deepseek_v3-FP8-pretrain.yaml @@ -79,8 +79,8 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/llama2_70B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/llama2_70B-BF16-pretrain.yaml index 1493fd25f..e0f48d9a4 100755 --- 
a/examples/megatron/configs/MI300X/llama2_70B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama2_70B-BF16-pretrain.yaml @@ -77,5 +77,5 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/llama2_70B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/llama2_70B-FP8-pretrain.yaml index 8f83cfd8d..77d4b52bc 100755 --- a/examples/megatron/configs/MI300X/llama2_70B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama2_70B-FP8-pretrain.yaml @@ -77,8 +77,8 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml index 19cb8465f..0ceeeae27 100755 --- a/examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama2_7B-BF16-pretrain.yaml @@ -80,5 +80,5 @@ modules: # sequence_parallel: 1 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/llama2_7B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/llama2_7B-FP8-pretrain.yaml index 73a79b00b..347a3430a 100755 --- a/examples/megatron/configs/MI300X/llama2_7B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama2_7B-FP8-pretrain.yaml @@ -80,8 +80,8 @@ modules: # sequence_parallel: 1 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml index bb2454f1e..5b67458b8 100644 --- a/examples/megatron/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml @@ -76,5 +76,5 @@ modules: use_turbo_attention: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml index 9013b0c96..a94a66c2e 100644 --- a/examples/megatron/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml @@ -76,8 +76,8 @@ modules: use_turbo_attention: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml index 4d49c5c09..6a4976ada 100644 --- a/examples/megatron/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml @@ -72,5 +72,5 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + 
cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml index 9e4b65560..148be586a 100644 --- a/examples/megatron/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml @@ -72,8 +72,8 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/llama3_70B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/llama3_70B-BF16-pretrain.yaml index 468c56827..68d035a2b 100755 --- a/examples/megatron/configs/MI300X/llama3_70B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3_70B-BF16-pretrain.yaml @@ -77,5 +77,5 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/llama3_70B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/llama3_70B-FP8-pretrain.yaml index 8e4ee47a2..23c157f62 100755 --- a/examples/megatron/configs/MI300X/llama3_70B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3_70B-FP8-pretrain.yaml @@ -77,8 +77,8 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/llama3_8B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/llama3_8B-BF16-pretrain.yaml index 7d2588f95..6af89e06b 100644 --- a/examples/megatron/configs/MI300X/llama3_8B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3_8B-BF16-pretrain.yaml @@ -73,5 +73,5 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/llama3_8B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/llama3_8B-FP8-pretrain.yaml index ef9149801..0ad398267 100644 --- a/examples/megatron/configs/MI300X/llama3_8B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/llama3_8B-FP8-pretrain.yaml @@ -73,8 +73,8 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml new file mode 100644 index 000000000..469913761 --- /dev/null +++ b/examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml @@ -0,0 +1,85 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:mamba_370M-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: mamba_370M.yaml + overrides: + # log + wandb_project: "Primus_Mamba_Pretrain" + # disable_wandb: false + # disable_tensorboard: false + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + 
log_avg_reset_interval: 50 + + train_iters: 50 + micro_batch_size: 4 + global_batch_size: 256 + + seq_length: 2048 + max_position_embeddings: 2048 + + lr: 3.0e-4 + min_lr: 0.0 + lr_warmup_iters: 50000 + lr_decay_iters: 73192188 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.02 + norm_epsilon: 1.0e-5 + + # Mamba-specific: must provide spec + spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: EleutherAI/gpt-neox-20b + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + + # Turbo - may need to disable for Mamba if not supported + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false + + # Cross entropy flags + # cross_entropy_fusion_impl: "native" + # cross_entropy_loss_fusion: false diff --git a/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-BF16-pretrain.yaml index 0efc4ee0b..73e949215 100644 --- a/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-BF16-pretrain.yaml @@ -73,5 +73,5 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-FP8-pretrain.yaml index 520177a76..14a554c5e 100644 --- a/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mixtral_8x22B_v0.1-FP8-pretrain.yaml @@ -72,8 +72,8 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-BF16-pretrain.yaml index eb5ce376f..28a41749e 100644 --- a/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-BF16-pretrain.yaml @@ -68,5 +68,5 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-FP8-pretrain.yaml index 35fa84415..c92eea95f 100644 --- a/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/mixtral_8x7B_v0.1-FP8-pretrain.yaml @@ -67,8 +67,8 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: 
true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/qwen2.5_72B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/qwen2.5_72B-BF16-pretrain.yaml index e02c3a233..ab300f580 100644 --- a/examples/megatron/configs/MI300X/qwen2.5_72B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/qwen2.5_72B-BF16-pretrain.yaml @@ -81,5 +81,5 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/qwen2.5_72B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/qwen2.5_72B-FP8-pretrain.yaml index 43f19a513..84317c0d5 100644 --- a/examples/megatron/configs/MI300X/qwen2.5_72B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/qwen2.5_72B-FP8-pretrain.yaml @@ -81,8 +81,8 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/qwen2.5_7B-BF16-pretrain.yaml b/examples/megatron/configs/MI300X/qwen2.5_7B-BF16-pretrain.yaml index 8ab17cbdd..b9d9fbe63 100644 --- a/examples/megatron/configs/MI300X/qwen2.5_7B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI300X/qwen2.5_7B-BF16-pretrain.yaml @@ -74,5 +74,5 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI300X/qwen2.5_7B-FP8-pretrain.yaml b/examples/megatron/configs/MI300X/qwen2.5_7B-FP8-pretrain.yaml index 4f331233e..93c2fcf68 100644 --- a/examples/megatron/configs/MI300X/qwen2.5_7B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI300X/qwen2.5_7B-FP8-pretrain.yaml @@ -74,8 +74,8 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml new file mode 100644 index 000000000..ca65bb754 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_1B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_1B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 8 + global_batch_size: 64 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 
'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml new file mode 100644 index 000000000..b8019be76 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_3B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_3B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_3B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 4 + global_batch_size: 32 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-3B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml new file mode 100644 index 000000000..a7083c069 --- /dev/null +++ b/examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_8B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_8B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 2 + global_batch_size: 16 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use 
custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.1-8B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/deepseek_v2-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v2-BF16-pretrain.yaml index 987b0d019..5f7d1299c 100644 --- a/examples/megatron/configs/MI355X/deepseek_v2-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/deepseek_v2-BF16-pretrain.yaml @@ -94,5 +94,5 @@ modules: turbo_sync_free_moe_stage: 1 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/deepseek_v2-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v2-FP8-pretrain.yaml index 328eb3216..8c91eabb0 100644 --- a/examples/megatron/configs/MI355X/deepseek_v2-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/deepseek_v2-FP8-pretrain.yaml @@ -93,8 +93,8 @@ modules: turbo_sync_free_moe_stage: 1 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/deepseek_v2_lite-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v2_lite-BF16-pretrain.yaml index b706c159f..54051c9c7 100644 --- a/examples/megatron/configs/MI355X/deepseek_v2_lite-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/deepseek_v2_lite-BF16-pretrain.yaml @@ -73,7 +73,7 @@ modules: # Turbo enable_primus_turbo: true use_turbo_attention: false - use_turbo_grouped_mlp: false + use_turbo_grouped_mlp: true # deepep use_turbo_deepep: true @@ -87,7 +87,8 @@ modules: # sync-free moe support stage 1-2, 0 means not use sync-free moe # stage 2 is recommended for better performance turbo_sync_free_moe_stage: 1 + moe_use_fused_router_with_aux_score: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/deepseek_v2_lite-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v2_lite-FP8-pretrain.yaml index 4f3a9f679..715d05551 100644 --- a/examples/megatron/configs/MI355X/deepseek_v2_lite-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/deepseek_v2_lite-FP8-pretrain.yaml @@ -88,8 +88,8 @@ modules: turbo_sync_free_moe_stage: 1 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/deepseek_v3-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v3-BF16-pretrain.yaml index 51dd8affa..61bd8aa5d 100644 --- a/examples/megatron/configs/MI355X/deepseek_v3-BF16-pretrain.yaml +++ 
b/examples/megatron/configs/MI355X/deepseek_v3-BF16-pretrain.yaml @@ -71,6 +71,19 @@ modules: ckpt_format: torch eval_iters: 0 + # deepep + use_turbo_deepep: true + moe_shared_expert_overlap: false + moe_router_dtype: fp32 + + # 64 or 80 for ep8, 32 for ep16-64 is best practice + turbo_deepep_num_cu: 64 + turbo_deepep_use_comm_stream: false + + # sync-free moe support stage 1-2, 0 means not use sync-free moe + # stage 2 is recommended for better performance + turbo_sync_free_moe_stage: 1 + # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/deepseek_v3-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/deepseek_v3-FP8-pretrain.yaml index a40c24449..08e109c7b 100644 --- a/examples/megatron/configs/MI355X/deepseek_v3-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/deepseek_v3-FP8-pretrain.yaml @@ -70,9 +70,22 @@ modules: ckpt_format: torch eval_iters: 0 + # deepep + use_turbo_deepep: true + moe_shared_expert_overlap: false + moe_router_dtype: fp32 + + # 64 or 80 for ep8, 32 for ep16-64 is best practice + turbo_deepep_num_cu: 64 + turbo_deepep_use_comm_stream: false + + # sync-free moe support stage 1-2, 0 means not use sync-free moe + # stage 2 is recommended for better performance + turbo_sync_free_moe_stage: 1 + # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/llama2_70B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/llama2_70B-BF16-pretrain.yaml index 87b933fdd..2d07839f9 100755 --- a/examples/megatron/configs/MI355X/llama2_70B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama2_70B-BF16-pretrain.yaml @@ -77,5 +77,5 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/llama2_70B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/llama2_70B-FP8-pretrain.yaml index 903bd2ad6..923c872df 100755 --- a/examples/megatron/configs/MI355X/llama2_70B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama2_70B-FP8-pretrain.yaml @@ -77,8 +77,8 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/llama2_7B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/llama2_7B-BF16-pretrain.yaml index 641dbbaf1..607fe5477 100755 --- a/examples/megatron/configs/MI355X/llama2_7B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama2_7B-BF16-pretrain.yaml @@ -80,5 +80,5 @@ modules: # sequence_parallel: 1 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/llama2_7B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/llama2_7B-FP8-pretrain.yaml index 43548e58e..3c390c575 100755 --- a/examples/megatron/configs/MI355X/llama2_7B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama2_7B-FP8-pretrain.yaml @@ -47,7 +47,7 @@ modules: 
expert_model_parallel_size: 1 overlap_grad_reduce: true overlap_param_gather: true - gradient_accumulation_fusion: false + gradient_accumulation_fusion: true # data mock_data: true @@ -80,8 +80,8 @@ modules: # sequence_parallel: 1 # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml index f03d9c4c1..c707572da 100644 --- a/examples/megatron/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml @@ -17,8 +17,8 @@ modules: log_avg_reset_interval: 50 train_iters: 50 - micro_batch_size: 4 - global_batch_size: 32 + micro_batch_size: 7 + global_batch_size: 56 seq_length: 8192 max_position_embeddings: 8192 @@ -72,5 +72,5 @@ modules: recompute_num_layers: 80 # int # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml index ae82330d5..9975256f7 100644 --- a/examples/megatron/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml @@ -17,8 +17,8 @@ modules: log_avg_reset_interval: 50 train_iters: 50 - micro_batch_size: 3 - global_batch_size: 24 + micro_batch_size: 4 + global_batch_size: 32 seq_length: 8192 max_position_embeddings: 8192 @@ -72,8 +72,8 @@ modules: recompute_num_layers: 80 # int # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml index 59a7f5b8d..83722456d 100644 --- a/examples/megatron/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3.1_8B-BF16-pretrain.yaml @@ -67,5 +67,5 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml index 5cba8ee79..1d8dcd9fc 100644 --- a/examples/megatron/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3.1_8B-FP8-pretrain.yaml @@ -21,8 +21,8 @@ modules: log_avg_reset_interval: 50 train_iters: 50 - micro_batch_size: 4 - global_batch_size: 256 + micro_batch_size: 6 + global_batch_size: 384 seq_length: 8192 max_position_embeddings: 8192 @@ -45,7 +45,7 @@ modules: expert_model_parallel_size: 1 overlap_grad_reduce: true overlap_param_gather: true - gradient_accumulation_fusion: false + gradient_accumulation_fusion: true # data mock_data: true @@ -67,8 +67,8 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/llama3.3_70B-BF16-pretrain.yaml 
b/examples/megatron/configs/MI355X/llama3.3_70B-BF16-pretrain.yaml index 3316fbd4d..fa860a013 100644 --- a/examples/megatron/configs/MI355X/llama3.3_70B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3.3_70B-BF16-pretrain.yaml @@ -77,5 +77,5 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/llama3.3_70B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/llama3.3_70B-FP8-pretrain.yaml index f79ef9bbd..a7727049e 100644 --- a/examples/megatron/configs/MI355X/llama3.3_70B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3.3_70B-FP8-pretrain.yaml @@ -77,8 +77,8 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/llama3_70B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/llama3_70B-BF16-pretrain.yaml index ca25b05d2..00554eee9 100755 --- a/examples/megatron/configs/MI355X/llama3_70B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3_70B-BF16-pretrain.yaml @@ -77,5 +77,5 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/llama3_70B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/llama3_70B-FP8-pretrain.yaml index 22bb90699..3de4437eb 100755 --- a/examples/megatron/configs/MI355X/llama3_70B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3_70B-FP8-pretrain.yaml @@ -77,8 +77,8 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/llama3_8B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/llama3_8B-BF16-pretrain.yaml index 2093e168e..7bfc7b1bb 100644 --- a/examples/megatron/configs/MI355X/llama3_8B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3_8B-BF16-pretrain.yaml @@ -73,5 +73,5 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/llama3_8B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/llama3_8B-FP8-pretrain.yaml index 1ec470ad5..096345c92 100644 --- a/examples/megatron/configs/MI355X/llama3_8B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/llama3_8B-FP8-pretrain.yaml @@ -73,8 +73,8 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-BF16-pretrain.yaml index 188465761..0b30cadf5 100644 --- a/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-BF16-pretrain.yaml @@ 
-73,8 +73,8 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # Turbo enable_primus_turbo: true diff --git a/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-FP8-pretrain.yaml index ea9beca43..712f4a022 100644 --- a/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/mixtral_8x22B_v0.1-FP8-pretrain.yaml @@ -72,8 +72,8 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # Turbo enable_primus_turbo: true diff --git a/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-BF16-pretrain.yaml index 16d8ffbcb..709373f3d 100644 --- a/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-BF16-pretrain.yaml @@ -68,8 +68,8 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # Turbo enable_primus_turbo: true diff --git a/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-FP8-pretrain.yaml index 4951e3d66..ee2684c3c 100644 --- a/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/mixtral_8x7B_v0.1-FP8-pretrain.yaml @@ -67,8 +67,8 @@ modules: ckpt_format: torch # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # Turbo enable_primus_turbo: true diff --git a/examples/megatron/configs/MI355X/qwen2.5_72B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/qwen2.5_72B-BF16-pretrain.yaml index aae072840..812617737 100644 --- a/examples/megatron/configs/MI355X/qwen2.5_72B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/qwen2.5_72B-BF16-pretrain.yaml @@ -81,5 +81,5 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/qwen2.5_72B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/qwen2.5_72B-FP8-pretrain.yaml index 8cf56be9e..747ff24cf 100644 --- a/examples/megatron/configs/MI355X/qwen2.5_72B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/qwen2.5_72B-FP8-pretrain.yaml @@ -81,8 +81,8 @@ modules: use_turbo_grouped_mlp: false # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/qwen2.5_7B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/qwen2.5_7B-BF16-pretrain.yaml index 0aa928e9b..f82a51992 100644 --- a/examples/megatron/configs/MI355X/qwen2.5_7B-BF16-pretrain.yaml +++ b/examples/megatron/configs/MI355X/qwen2.5_7B-BF16-pretrain.yaml @@ -74,5 +74,5 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + 
cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true diff --git a/examples/megatron/configs/MI355X/qwen2.5_7B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/qwen2.5_7B-FP8-pretrain.yaml index 6e9f44515..55c0762f9 100644 --- a/examples/megatron/configs/MI355X/qwen2.5_7B-FP8-pretrain.yaml +++ b/examples/megatron/configs/MI355X/qwen2.5_7B-FP8-pretrain.yaml @@ -46,7 +46,7 @@ modules: expert_model_parallel_size: 1 overlap_grad_reduce: true overlap_param_gather: true - gradient_accumulation_fusion: false + gradient_accumulation_fusion: true # data mock_data: true @@ -74,8 +74,8 @@ modules: use_turbo_grouped_mlp: true # Cross entropy flags - # cross_entropy_fusion_impl: "te" - # cross_entropy_loss_fusion: true + cross_entropy_fusion_impl: "te" + cross_entropy_loss_fusion: true # enable fp8 training fp8: hybrid diff --git a/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml new file mode 100644 index 000000000..67647da43 --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_1B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_1B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_1B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 12 + global_batch_size: 96 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-1B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml new file mode 100644 index 000000000..19006bd72 --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_3B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_3B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_3B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_3B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 7 + global_batch_size: 56 + + seq_length: 8192 + max_position_embeddings: 8192 + 
original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.2-3B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml b/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml new file mode 100644 index 000000000..136ac9d3d --- /dev/null +++ b/examples/megatron/configs/MI355X/zebra_llama_8B-pretrain.yaml @@ -0,0 +1,70 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B-pretrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: zebra_llama_8B.yaml + overrides: + # log + wandb_project: "Primus_Zebra_Llama_8B_Pretrain" + stderr_sink_level: DEBUG + + eval_iters: 0 + + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + train_iters: 100 + micro_batch_size: 4 + global_batch_size: 32 + + seq_length: 8192 + max_position_embeddings: 8192 + original_max_position_embeddings: 8192 + + lr: 2.0e-4 + min_lr: 2.0e-5 + lr_warmup_iters: 200 + lr_decay_iters: 10000 + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + + # Mamba-specific: must provide spec + # Use custom hybrid Mamba+MLA spec + spec: ['primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs', 'hybrid_stack_spec'] + + # Tokenizer + tokenizer_type: HuggingFaceTokenizer + tokenizer_model: meta-llama/Llama-3.1-8B + + # parallel + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + overlap_grad_reduce: true + overlap_param_gather: true + gradient_accumulation_fusion: false + + # data + mock_data: true + train_data_path: null + valid_data_path: null + test_data_path: null + + # ckpt + finetune: false + auto_continue_train: false + load: null + save: null + save_interval: 10000 + disable_last_saving: true + ckpt_format: torch diff --git a/examples/megatron_bridge/configs/MI300X/mamba_370M_sft_posttrain.yaml b/examples/megatron_bridge/configs/MI300X/mamba_370M_sft_posttrain.yaml new file mode 100644 index 000000000..42e077143 --- /dev/null +++ b/examples/megatron_bridge/configs/MI300X/mamba_370M_sft_posttrain.yaml @@ -0,0 +1,57 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:mamba_370M_sft_posttrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + post_trainer: + framework: megatron_bridge + config: sft_trainer.yaml + + # Model to run + model: mamba_370M.yaml + + overrides: + stderr_sink_level: DEBUG + + # Parallelism configuration + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + 
pipeline_dtype: null + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: false + use_megatron_fsdp: false + + # Finetuning-specific params + #pretrained_checkpoint: null + peft: "none" + packed_sequence: false + + # Training configuration + train_iters: 200 + global_batch_size: 128 + micro_batch_size: 4 + seq_length: 2048 + eval_interval: 30 + save_interval: 50 + + # Optimizer configuration + finetune_lr: 5.0e-6 + min_lr: 0.0 + lr_warmup_iters: 50 + lr_decay_iters: null + + # W&B logging + wandb_project: null + wandb_entity: null + wandb_exp_name: null + + # Precision + precision_config: bf16_mixed + comm_overlap_config: null + + # Turbo - disabled for Mamba (not supported) + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false diff --git a/examples/megatron_bridge/configs/MI300X/zebra_llama_1B_sft_posttrain.yaml b/examples/megatron_bridge/configs/MI300X/zebra_llama_1B_sft_posttrain.yaml new file mode 100644 index 000000000..106ef0216 --- /dev/null +++ b/examples/megatron_bridge/configs/MI300X/zebra_llama_1B_sft_posttrain.yaml @@ -0,0 +1,57 @@ +work_group: ${PRIMUS_TEAM:amd} +user_name: ${PRIMUS_USER:root} +exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B_sft_posttrain} +workspace: ${PRIMUS_WORKSPACE:./output} + +modules: + post_trainer: + framework: megatron_bridge + config: sft_trainer.yaml + + # Model to run + model: zebra_llama_1B.yaml + + overrides: + stderr_sink_level: DEBUG + + # Parallelism configuration + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_dtype: null + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: false + use_megatron_fsdp: false + + # Finetuning-specific params + #pretrained_checkpoint: null + peft: "none" + packed_sequence: false + + # Training configuration + train_iters: 200 + global_batch_size: 64 + micro_batch_size: 8 + seq_length: 8192 + eval_interval: 30 + save_interval: 50 + + # Optimizer configuration + finetune_lr: 5.0e-6 + min_lr: 0.0 + lr_warmup_iters: 50 + lr_decay_iters: null + + # W&B logging + wandb_project: null + wandb_entity: null + wandb_exp_name: null + + # Precision + precision_config: bf16_mixed + comm_overlap_config: null + + # Turbo - disabled for hybrid Mamba+MLA (not supported) + enable_primus_turbo: false + use_turbo_attention: false + use_turbo_grouped_mlp: false diff --git a/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml index da0831030..21970087a 100644 --- a/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml @@ -23,7 +23,7 @@ modules: log_freq: 1 training: - local_batch_size: 4 + local_batch_size: 2 seq_len: 8192 mock_data: false steps: 50 diff --git a/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml index 93e0b458b..6e09f2dcf 100644 --- a/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml +++ b/examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml @@ -23,7 +23,7 @@ modules: log_freq: 1 training: - local_batch_size: 3 + local_batch_size: 1 seq_len: 8192 mock_data: false steps: 50 diff --git a/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml index 23de79da8..1e4465d38 100644 --- 
a/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/llama3.1_70B-BF16-pretrain.yaml @@ -23,7 +23,7 @@ modules: log_freq: 1 training: - local_batch_size: 8 + local_batch_size: 3 seq_len: 8192 mock_data: false steps: 10 diff --git a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml index 034d8523a..954b3826d 100644 --- a/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml +++ b/examples/torchtitan/configs/MI355X/llama3.1_70B-FP8-pretrain.yaml @@ -28,7 +28,8 @@ modules: log_freq: 1 training: - local_batch_size: 6 + local_batch_size: 3 + global_batch_size: 96 seq_len: 8192 mock_data: false steps: 10 diff --git a/primus/backends/megatron/core/models/hybrid/__init__.py b/primus/backends/megatron/core/models/hybrid/__init__.py new file mode 100644 index 000000000..05d2d673a --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +"""Hybrid Mamba+MLA layer specifications for Megatron-LM.""" + +from .hybrid_mamba_mla_layer_specs import hybrid_stack_spec + +__all__ = ["hybrid_stack_spec"] diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_block.py b/primus/backends/megatron/core/models/hybrid/hybrid_block.py new file mode 100644 index 000000000..3c4d7def8 --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/hybrid_block.py @@ -0,0 +1,403 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + +from contextlib import nullcontext +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.enums import Fp8Recipe +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols +from megatron.core.transformer import TransformerConfig +from torch import Tensor, nn + +# CudaGraphScope is not available in older Megatron versions +try: + from megatron.core.transformer.enums import CudaGraphScope + + HAS_CUDA_GRAPH_SCOPE = True +except ImportError: + CudaGraphScope = None + HAS_CUDA_GRAPH_SCOPE = False + +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.transformer.utils import sharded_state_dict_default +from megatron.core.utils import ( + WrappedTensor, + deprecate_inference_params, + make_viewless_tensor, +) + + +@dataclass +class HybridStackSubmodules: + """ + A class for the module specs for the MambaStack. 
+ """ + + mamba_layer: Union[ModuleSpec, type] = IdentityOp + attention_layer: Union[ModuleSpec, type] = IdentityOp + mlp_layer: Union[ModuleSpec, type] = IdentityOp + moe_layer: Union[ModuleSpec, type] = IdentityOp + + +class HybridStack(MegatronModule): + """ + Constructor for the HybridStack class. + + Args: + config (TransformerConfig): the model configuration + submodules (MambaStackSubmodules): the submodules for the stack + residual_in_fp32 (bool, optional): whether to do residual connections + in fp32. Defaults to False. + pre_process (bool, optional): whether to include an embedding layer. + Defaults to True. + hybrid_attention_ratio (float, optional): the target ratio of attention layers to + total layers. Defaults to 0.0. + hybrid_mlp_ratio (float, optional): the target ratio of mlp layers to total + layers. Defaults to 0.0. + hybrid_override_pattern (str, optional): the hybrid layer pattern to override + with. Defaults to None. + post_layer_norm (bool, optional): whether to include a final layer norm. + Defaults to True. + post_process (bool, optional): whether to include an output layer. + Defaults to True. + device (optional): the device to use. Defaults to None. + dtype (optional): the data type to use. Defaults to None. + pg_collection (ProcessGroupCollection): the required model communication + process groups to use. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: HybridStackSubmodules, + residual_in_fp32=False, + pre_process: bool = True, + hybrid_attention_ratio: float = 0.0, + hybrid_mlp_ratio: float = 0.0, + hybrid_override_pattern: str = None, + post_layer_norm: bool = True, + post_process: bool = True, + device=None, + dtype=None, + pg_collection: ProcessGroupCollection = None, + ) -> None: + super().__init__(config=config) + self.residual_in_fp32 = residual_in_fp32 + self.pre_process = pre_process + self.post_layer_norm = post_layer_norm + self.post_process = post_process + + assert pg_collection is not None, "pg_collection must be provided for MambaStack" + + self.pp_group = pg_collection.pp + self.tp_group = pg_collection.tp + + # Required for pipeline parallel schedules + self.input_tensor = None + + self.hybrid_attention_ratio = hybrid_attention_ratio + self.hybrid_mlp_ratio = hybrid_mlp_ratio + self.hybrid_override_pattern = hybrid_override_pattern + + # Customized layer allocation + # hybrid_mlp_ratio is not used in this hybrid stack. + # It is by default to be always followed by mamba or mla (i.e., mamba + MLP or MLA + MLP) + # By setting hybrid_attention_ratio, attention layers are by default to be distributed uniformly. 
+ self.layer_type_list = self.allocate_layers( + self.config.num_layers, + self.hybrid_attention_ratio, + ) + + pp_layer_offset = 0 + if self.pp_group.size() > 1: + pp_layer_offset, self.layer_type_list = self._select_layers_for_pipeline_parallel( + self.layer_type_list + ) + + print(f"layer_type_list: {self.layer_type_list}") + + self.layers = nn.ModuleList() + for i, layer_type in enumerate(self.layer_type_list): + fp8_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) + with fp8_init_context: + if layer_type == LayerSymbols.MAMBA: + layer = build_module( + submodules.mamba_layer, + config=self.config, + residual_in_fp32=residual_in_fp32, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.ATTENTION: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.attention_layer, + config=self.config, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.MLP: + # Transformer layers apply their own pp_layer_offset + layer = build_module( + submodules.mlp_layer, + config=self.config, + layer_number=i + 1, + pg_collection=pg_collection, + ) + elif layer_type == LayerSymbols.MOE: + # Transformer layers apply their own pp_layer_offset + layer = build_module(submodules.moe_layer, config=self.config, layer_number=i + 1) + else: + assert False, "unexpected layer_type" + self.layers.append(layer) + + # Required for activation recomputation + self.num_layers_per_pipeline_rank = len(self.layers) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_norm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + def allocate_layers(self, num_layers, hybrid_attention_ratio): + layer_type_list = [] + num_attention_layers = int(num_layers // 2 * hybrid_attention_ratio) + num_mamba_layers = num_layers // 2 - num_attention_layers + num_mamba_per_attention_layer = num_mamba_layers // num_attention_layers + + if hybrid_attention_ratio <= 0.5: + base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [ + LayerSymbols.MAMBA, + LayerSymbols.MLP, + ] * num_mamba_per_attention_layer + layer_type_list += base_block * num_attention_layers + layer_type_list += [LayerSymbols.MAMBA, LayerSymbols.MLP] * ( + num_mamba_layers % num_attention_layers + ) + else: + base_block = [LayerSymbols.ATTENTION, LayerSymbols.MLP] + [LayerSymbols.MAMBA, LayerSymbols.MLP] + layer_type_list += [LayerSymbols.ATTENTION, LayerSymbols.MLP] * ( + num_attention_layers - num_mamba_layers + ) + layer_type_list += base_block * num_mamba_layers + return layer_type_list + + def _select_layers_for_pipeline_parallel(self, layer_type_list): + num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size() + + assert self.config.virtual_pipeline_model_parallel_size is None, ( + "The Mamba hybrid model does not currently support " "virtual/interleaved pipeline parallelism" + ) + + offset = self.pp_group.rank() * num_layers_per_pipeline_rank + selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank] + + return offset, selected_list + + def set_input_tensor(self, input_tensor: Tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. 
This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int]]]: + """ + Returns the Mamba conv and ssm states shapes per input sequence + if this block contains Mamba layers (this may not be the case with PP > 1). + """ + for layer_type, layer in zip(self.layer_type_list, self.layers): + if layer_type == LayerSymbols.MAMBA: + return layer.mamba_state_shapes_per_request() + return None + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Tensor, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Forward function of the MambaStack class. + + It either returns the Loss values if labels are given or the + final hidden units + + Args: + hidden_states (Union[Tensor, WrappedTensor]): the input tensor. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): the attention mask. + inference_context (BaseInferenceContext): the inference parameters. + rotary_pos_emb (Tensor, optional): the rotary positional embeddings. + Defaults to None. + Returns: + Tensor: the output tensor. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if inference_context and inference_context.is_static_batching(): + # NOTE(bnorick): match BaseInferenceContext attributes for + # mamba_ssm.utils.generation.BaseInferenceContext, + # this hack supports eval + inference_context.max_seqlen = inference_context.max_sequence_length + inference_context.seqlen_offset = inference_context.sequence_len_offset + + if ( + ( + ( + HAS_CUDA_GRAPH_SCOPE + and self.config.cuda_graph_impl == "local" + and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope + ) + or self.config.flash_decode + ) + and inference_context + and inference_context.is_static_batching() + and not self.training + ): + current_batch_size = hidden_states.shape[1] + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device="cuda", + ) + else: + sequence_len_offset = None + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with outer_fp8_context: + for layer in self.layers: + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) + if use_inner_fp8_context + else nullcontext() + ) + with inner_fp8_context: + if isinstance(layer, TransformerLayer): + hidden_states, _ = layer( + hidden_states=hidden_states, + 
attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + sequence_len_offset=sequence_len_offset, + ) + else: # MambaLayer + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + ) + + # The attention layer (currently a simplified transformer layer) + # outputs a tuple of (hidden_states, context). Context is intended + # for cross-attention, and is not needed in our model. + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_norm(hidden_states) + + # Ensure that the tensor passed between pipeline parallel stages is + # viewless. See related notes in TransformerBlock and TransformerLayer + return make_viewless_tensor( + inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True + ) + + def sharded_state_dict( + self, + prefix: str = "", + sharded_offsets: Optional[tuple] = None, + metadata: Optional[dict] = None, + ) -> ShardedStateDict: + """ + Returns a sharded state dictionary for the current object. + + This function constructs a sharded state dictionary by iterating over the layers + in the current object, computing the sharded state dictionary for each layer, + and combining the results into a single dictionary. + + Parameters: + prefix (str): The prefix to use for the state dictionary keys. + sharded_offsets (tuple): The sharded offsets to use for the state dictionary. + metadata (dict): Additional metadata to use when computing the sharded state dictionary. + + Returns: + dict: The sharded state dictionary for the current object. + """ + + sharded_state_dict = {} + layer_prefix = f"{prefix}layers." + + for local_layer_idx, layer in enumerate(self.layers): + + global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 + state_dict_prefix = f"{layer_prefix}{local_layer_idx}." # module list index in MambaBlock + + sharded_prefix = f"{layer_prefix}{global_layer_offset}." + sharded_pp_offset = [] + + layer_sharded_state_dict = layer.sharded_state_dict( + state_dict_prefix, sharded_pp_offset, metadata + ) + + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Add modules other than self.layers + for name, module in self.named_children(): + if not module is self.layers: + sharded_state_dict.update( + sharded_state_dict_default( + module, + f"{prefix}{name}.", + sharded_offsets, + metadata, + tp_group=self.tp_group, + ) + ) + + return sharded_state_dict diff --git a/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py new file mode 100644 index 000000000..cb801809d --- /dev/null +++ b/primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
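+
+# This module defines the layer spec consumed by HybridStack: a Mamba mixer layer, an
+# MLA self-attention layer, a dense MLP layer and an MoE layer, all built from
+# Transformer Engine primitives. The Zebra Llama recipes import `hybrid_stack_spec`
+# from here at model-build time.
+#
+# Illustrative sketch (not part of this file's API): HybridStack instantiates each
+# entry with build_module, e.g.
+#     layer = build_module(
+#         hybrid_stack_spec.submodules.attention_layer,
+#         config=config, layer_number=1, pg_collection=pg_collection,
+#     )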
+ +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TELinear, + TENorm, + TERowParallelLinear, +) +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec +from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules +from megatron.core.ssm.mlp_layer import MLPLayer +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.multi_latent_attention import ( + MLASelfAttention, + MLASelfAttentionSubmodules, +) + +# Import HybridStack from relative path +from primus.backends.megatron.core.models.hybrid.hybrid_block import ( + HybridStack, + HybridStackSubmodules, +) + +# Inference layers may not be available in older Megatron versions +# They're only used in hybrid_inference_stack_spec, not the training spec +try: + from megatron.core.tensor_parallel import ( + InferenceLayerNormColumnParallelLinear, + InferenceRowParallelLinear, + ) + + HAS_INFERENCE_LAYERS = True +except ImportError: + # Fallback to regular layers for inference spec + InferenceLayerNormColumnParallelLinear = TELayerNormColumnParallelLinear + InferenceRowParallelLinear = TERowParallelLinear + HAS_INFERENCE_LAYERS = False + +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, +) + +moe = get_moe_module_spec( + use_te=True, + num_experts=8, # Can be any positive integer (must not be None). + moe_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, +) + +hybrid_stack_spec = ModuleSpec( + module=HybridStack, + submodules=HybridStackSubmodules( + mamba_layer=ModuleSpec( + module=MambaLayer, + submodules=MambaLayerSubmodules( + mixer=ModuleSpec( + module=MambaMixer, + params={ + "expand": 1, + "d_conv": 4, + }, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), + ), + mamba_bda=get_bias_dropout_add, + ), + ), + attention_layer=ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=MLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=MLASelfAttentionSubmodules( + linear_q_proj=TEColumnParallelLinear, + linear_q_down_proj=TELinear, + linear_q_up_proj=TELayerNormColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + ), + ), + mlp_layer=ModuleSpec( + module=MLPLayer, + submodules=TransformerLayerSubmodules( + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + moe_layer=ModuleSpec( + # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? 
+ module=TransformerLayer, + submodules=TransformerLayerSubmodules( + pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + ), + ), + ), +) diff --git a/primus/backends/megatron/megatron_pretrain_trainer.py b/primus/backends/megatron/megatron_pretrain_trainer.py index 02f97e626..97ce5e869 100644 --- a/primus/backends/megatron/megatron_pretrain_trainer.py +++ b/primus/backends/megatron/megatron_pretrain_trainer.py @@ -19,13 +19,30 @@ def train(self): from megatron.core.enums import ModelType from megatron.training import pretrain # type: ignore - from pretrain_gpt import ( # type: ignore - forward_step, - train_valid_test_datasets_provider, - ) from primus.core.utils.import_utils import get_model_provider + # Determine model type (gpt or mamba) from backend_args + model_type = getattr(self.backend_args, "model_type", "gpt") + log_rank_0(f"-detected model_type: {model_type}") + + # Import the appropriate training components based on model_type + if model_type == "mamba": + from pretrain_mamba import ( # type: ignore + forward_step, + train_valid_test_datasets_provider, + ) + + log_rank_0("Using Mamba model provider and training components") + else: + from pretrain_gpt import ( # type: ignore + forward_step, + train_valid_test_datasets_provider, + ) + + log_rank_0("Using GPT model provider and training components") + + # Configure training components if hasattr(train_valid_test_datasets_provider, "is_distributed"): train_valid_test_datasets_provider.is_distributed = True @@ -49,9 +66,17 @@ def train(self): if "store" in sig.parameters: kwargs["store"] = store + # Get model provider with correct model_type + # Only pass model_type if it's not the default to maintain compatibility + if model_type != "gpt": + model_provider = get_model_provider(model_type=model_type) + else: + model_provider = get_model_provider() + log_rank_0(f"-model_provider: {model_provider}") + wrapped_pretrain( train_valid_test_datasets_provider, - get_model_provider(), + model_provider, ModelType.encoder_or_decoder, forward_step, **kwargs, diff --git a/primus/backends/megatron/patches/mamba_rocm_patches.py b/primus/backends/megatron/patches/mamba_rocm_patches.py new file mode 100644 index 000000000..9d6243247 --- /dev/null +++ b/primus/backends/megatron/patches/mamba_rocm_patches.py @@ -0,0 +1,79 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Mamba ROCm Patches + +Patches for Mamba model compatibility on AMD ROCm GPUs. +Disables Triton buffer_ops in the chunk_state backward pass to avoid +ROCm-specific correctness issues. 
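+
+The patch is registered through ``register_patch`` with ``condition=_is_rocm`` and the
+``before_train`` phase, so it only takes effect when ``torch.version.hip`` is set.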
+""" + +import torch + +from primus.core.patches import PatchContext, register_patch +from primus.modules.module_utils import log_rank_0 + + +def _is_rocm(ctx: PatchContext) -> bool: + """Return True when running on an AMD ROCm platform.""" + return getattr(torch.version, "hip", None) is not None + + +def _make_triton_wrapper(original_fn): + from triton import knobs as _triton_knobs + + def _chunk_state_bwd_db_no_buffer_ops(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): + with _triton_knobs.amd.scope(): + _triton_knobs.amd.use_buffer_ops = False + return original_fn(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups) + + return _chunk_state_bwd_db_no_buffer_ops + + +@register_patch( + "megatron.mamba.rocm_chunk_state_bwd_db", + backend="megatron", + phase="before_train", + description=( + "Disable Triton buffer_ops in Mamba _chunk_state_bwd_db backward pass " + "to work around ROCm-specific correctness issues." + ), + condition=_is_rocm, + tags=["rocm", "mamba"], +) +def patch_mamba_rocm_chunk_state_bwd_db(ctx: PatchContext): + """ + Patch mamba_ssm _chunk_state_bwd_db to disable Triton buffer_ops on ROCm. + + The Triton buffer_ops feature can cause correctness issues on AMD GPUs + during the backward pass of the Mamba chunk state computation. This patch + wraps the original function to set use_buffer_ops = False within an AMD + Triton knobs scope. + + Both ``ssd_chunk_state`` (definition) and ``ssd_combined`` (import-time + binding) module namespaces are patched so every call-site picks up the + wrapper regardless of which module it was imported from. + """ + import mamba_ssm.ops.triton.ssd_chunk_state as ssd_chunk_state + import mamba_ssm.ops.triton.ssd_combined as ssd_combined + + original_fn = ssd_chunk_state._chunk_state_bwd_db + wrapped_fn = _make_triton_wrapper(original_fn) + + # Patch the canonical definition in ssd_chunk_state + ssd_chunk_state._chunk_state_bwd_db = wrapped_fn + log_rank_0( + "[Patch:megatron.mamba.rocm_chunk_state_bwd_db] " + "Patched mamba_ssm.ops.triton.ssd_chunk_state._chunk_state_bwd_db" + ) + + # Patch the import-time binding in ssd_combined + ssd_combined._chunk_state_bwd_db = wrapped_fn + log_rank_0( + "[Patch:megatron.mamba.rocm_chunk_state_bwd_db] " + "Patched mamba_ssm.ops.triton.ssd_combined._chunk_state_bwd_db" + ) diff --git a/primus/backends/megatron_bridge/config_utils.py b/primus/backends/megatron_bridge/config_utils.py index 495a74c5f..75c08bf10 100644 --- a/primus/backends/megatron_bridge/config_utils.py +++ b/primus/backends/megatron_bridge/config_utils.py @@ -185,6 +185,43 @@ def _merge_dict_to_dataclass(target: Any, source_dict: dict, path: str = "") -> ) +def _resolve_recipe(recipe: str, flavor: str): + """ + Resolve a recipe module and function by searching multiple namespaces. + + Search order: + 1. primus.backends.megatron_bridge.recipes.{recipe} (Primus-side extensions) + 2. megatron.bridge.recipes.{recipe} (upstream Megatron-Bridge) + + Returns: + Tuple of (module, full_module_path) for the first namespace that + contains the requested *flavor* function. + + Raises: + AssertionError if the recipe cannot be found in any namespace. 
+ """ + search_prefixes = [ + "primus.backends.megatron_bridge.recipes", + "megatron.bridge.recipes", + ] + + for prefix in search_prefixes: + full_module_path = f"{prefix}.{recipe}" + try: + module = importlib.import_module(full_module_path) + except ImportError: + continue + if hasattr(module, flavor): + return module, full_module_path + + # Build a helpful error message listing all paths that were tried. + tried = [f"{p}.{recipe}" for p in search_prefixes] + assert False, ( + f"Recipe loading failed: Function '{flavor}' not found. " + f"Searched modules: {tried}" + ) + + def load_recipe_config(backend_args: SimpleNamespace) -> Any: recipe = backend_args.recipe flavor = backend_args.flavor @@ -193,21 +230,13 @@ def load_recipe_config(backend_args: SimpleNamespace) -> Any: assert recipe, "Recipe must be specified for Megatron-Bridge backend" assert flavor, "Flavor must be specified for Megatron-Bridge backend" - # Construct full module path and function name - full_module_path = f"megatron.bridge.recipes.{recipe}" function_name = flavor - log_rank_0(f"Loading recipe: {full_module_path}.{function_name}()") + # Resolve recipe module from Primus-side extensions or upstream Megatron-Bridge + module, full_module_path = _resolve_recipe(recipe, function_name) - # Import module and get function - try: - module = importlib.import_module(full_module_path) - except ImportError as e: - assert False, f"Recipe loading failed: Cannot import '{full_module_path}': {e}" + log_rank_0(f"Loading recipe: {full_module_path}.{function_name}()") - assert hasattr( - module, function_name - ), f"Recipe loading failed: Function '{function_name}' not found in '{full_module_path}'" recipe_func = getattr(module, function_name) # Convert backend_args to dict once (used for both recipe call and config override) diff --git a/primus/backends/megatron_bridge/megatron_bridge_adapter.py b/primus/backends/megatron_bridge/megatron_bridge_adapter.py index 01bbeb419..747482006 100644 --- a/primus/backends/megatron_bridge/megatron_bridge_adapter.py +++ b/primus/backends/megatron_bridge/megatron_bridge_adapter.py @@ -39,12 +39,13 @@ def __init__(self, framework: str = "megatron_bridge"): super().__init__(framework) self.third_party_dir_name = "Megatron-Bridge" - def load_trainer_class(self, stage: str = "pretrain"): + def load_trainer_class(self, stage: str = "sft"): """ Return the Megatron-Bridge Trainer class for the specified training stage. Args: - stage: Training stage ("sft" for supervised fine-tuning) + stage: Training stage ("sft" for supervised fine-tuning, + "pretrain" also routes to SFT trainer) Returns: Trainer class for the specified stage @@ -52,7 +53,7 @@ def load_trainer_class(self, stage: str = "pretrain"): Raises: ValueError: If stage is not supported """ - if stage == "sft": + if stage in ("pretrain", "sft"): from primus.backends.megatron_bridge.megatron_bridge_posttrain_trainer import ( MegatronBridgePosttrainTrainer, ) diff --git a/primus/backends/megatron_bridge/recipes/__init__.py b/primus/backends/megatron_bridge/recipes/__init__.py new file mode 100644 index 000000000..addc4218f --- /dev/null +++ b/primus/backends/megatron_bridge/recipes/__init__.py @@ -0,0 +1,13 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Primus-side recipes for Megatron-Bridge. 
+ +This package contains recipe functions that extend Megatron-Bridge's built-in +recipes. The recipe loader in config_utils.py searches this namespace first +before falling back to megatron.bridge.recipes. +""" diff --git a/primus/backends/megatron_bridge/recipes/mamba/__init__.py b/primus/backends/megatron_bridge/recipes/mamba/__init__.py new file mode 100644 index 000000000..051d6cf86 --- /dev/null +++ b/primus/backends/megatron_bridge/recipes/mamba/__init__.py @@ -0,0 +1,7 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +"""Mamba recipe extensions for Megatron-Bridge.""" diff --git a/primus/backends/megatron_bridge/recipes/mamba/mamba2.py b/primus/backends/megatron_bridge/recipes/mamba/mamba2.py new file mode 100644 index 000000000..9aa6f2153 --- /dev/null +++ b/primus/backends/megatron_bridge/recipes/mamba/mamba2.py @@ -0,0 +1,340 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Mamba2 finetuning recipes for Megatron-Bridge. + +These recipes extend the upstream Mamba2 pretrain-only recipes with +finetuning support (SFT, LoRA/DoRA). They reuse the model providers +and utility helpers shipped with Megatron-Bridge so that no third-party +code needs to be modified. +""" + +import os + +import torch +from typing_extensions import TypedDict, Unpack + +from megatron.bridge.models.mamba import ( + MambaModelProvider130M, + MambaModelProvider370M, + MambaModelProvider780M, + MambaModelProvider1P3B, + MambaModelProvider2P7B, + NVIDIAMambaHybridProvider8B, + NVIDIAMambaModelProvider8B, +) +from megatron.bridge.recipes.utils.finetune_utils import ( + default_peft_config, + default_squad_config, +) +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig + + +# --------------------------------------------------------------------------- +# Typed kwargs for Mamba2 finetuning recipes +# --------------------------------------------------------------------------- + +class Mamba2FinetuneKwargs(TypedDict, total=False): + """Typed options accepted by Mamba2 finetuning recipe helpers.""" + + # Core identifiers + dir: str | None + name: str + + # Finetuning-specific + pretrained_checkpoint: str | None + peft: str | None + packed_sequence: bool + + # Model configuration + tensor_model_parallel_size: int + pipeline_model_parallel_size: int + pipeline_dtype: torch.dtype | None + virtual_pipeline_model_parallel_size: int | None + context_parallel_size: int + sequence_parallel: bool + + # Training hyperparameters + train_iters: int + global_batch_size: int + micro_batch_size: int + seq_length: int + eval_interval: int + save_interval: int + + # Optimizer + finetune_lr: float + min_lr: float + lr_warmup_iters: int + lr_decay_iters: int | None + + # W&B logging + 
wandb_project: str | None + wandb_entity: str | None + wandb_exp_name: str | None + + # Precision / overlap configs + precision_config: MixedPrecisionConfig | str | None + comm_overlap_config: CommOverlapConfig | None + + +# --------------------------------------------------------------------------- +# Public finetune recipe helpers +# --------------------------------------------------------------------------- + +def mamba2_130m_finetune_config(**user_kwargs: Unpack[Mamba2FinetuneKwargs]) -> ConfigContainer: + """Return a finetuning config for Mamba2 130M.""" + recommended: Mamba2FinetuneKwargs = { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: Mamba2FinetuneKwargs = {**recommended, **user_kwargs} + return _mamba2_finetune_common(model_provider=MambaModelProvider130M, **kwargs) + + +def mamba2_370m_finetune_config(**user_kwargs: Unpack[Mamba2FinetuneKwargs]) -> ConfigContainer: + """Return a finetuning config for Mamba2 370M.""" + recommended: Mamba2FinetuneKwargs = { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: Mamba2FinetuneKwargs = {**recommended, **user_kwargs} + return _mamba2_finetune_common(model_provider=MambaModelProvider370M, **kwargs) + + +def mamba2_780m_finetune_config(**user_kwargs: Unpack[Mamba2FinetuneKwargs]) -> ConfigContainer: + """Return a finetuning config for Mamba2 780M.""" + recommended: Mamba2FinetuneKwargs = { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: Mamba2FinetuneKwargs = {**recommended, **user_kwargs} + return _mamba2_finetune_common(model_provider=MambaModelProvider780M, **kwargs) + + +def mamba2_1p3b_finetune_config(**user_kwargs: Unpack[Mamba2FinetuneKwargs]) -> ConfigContainer: + """Return a finetuning config for Mamba2 1.3B.""" + recommended: Mamba2FinetuneKwargs = { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: Mamba2FinetuneKwargs = {**recommended, **user_kwargs} + return _mamba2_finetune_common(model_provider=MambaModelProvider1P3B, **kwargs) + + +def mamba2_2p7b_finetune_config(**user_kwargs: Unpack[Mamba2FinetuneKwargs]) -> ConfigContainer: + """Return a finetuning config for Mamba2 2.7B.""" + recommended: Mamba2FinetuneKwargs = { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: Mamba2FinetuneKwargs = {**recommended, **user_kwargs} + return _mamba2_finetune_common(model_provider=MambaModelProvider2P7B, **kwargs) + + +def mamba2_8b_finetune_config(**user_kwargs: Unpack[Mamba2FinetuneKwargs]) -> ConfigContainer: + """Return a finetuning config for Mamba2 8B.""" + recommended: Mamba2FinetuneKwargs = { + "tensor_model_parallel_size": 8, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: Mamba2FinetuneKwargs = {**recommended, **user_kwargs} + return _mamba2_finetune_common(model_provider=NVIDIAMambaModelProvider8B, **kwargs) + + +def mamba2_hybrid_8b_finetune_config(**user_kwargs: Unpack[Mamba2FinetuneKwargs]) -> ConfigContainer: + """Return a finetuning config for Mamba2 Hybrid 8B.""" + recommended: Mamba2FinetuneKwargs = { + "tensor_model_parallel_size": 8, + 
"pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: Mamba2FinetuneKwargs = {**recommended, **user_kwargs} + return _mamba2_finetune_common(model_provider=NVIDIAMambaHybridProvider8B, **kwargs) + + +# --------------------------------------------------------------------------- +# Common finetuning configuration builder +# --------------------------------------------------------------------------- + +def _mamba2_finetune_common( + model_provider: ( + type[MambaModelProvider130M] + | type[MambaModelProvider370M] + | type[MambaModelProvider780M] + | type[MambaModelProvider1P3B] + | type[MambaModelProvider2P7B] + | type[NVIDIAMambaModelProvider8B] + | type[NVIDIAMambaHybridProvider8B] + ), + dir: str | None = None, + name: str = "default", + # Model configuration + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + pipeline_dtype: torch.dtype | None = None, + virtual_pipeline_model_parallel_size: int | None = None, + context_parallel_size: int = 1, + sequence_parallel: bool = False, + # Finetuning-specific params + pretrained_checkpoint: str | None = None, + peft: str | None = "none", + packed_sequence: bool = False, + # Training hyperparameters + train_iters: int = 1000, + global_batch_size: int = 128, + micro_batch_size: int = 4, + seq_length: int = 2048, + eval_interval: int = 30, + save_interval: int = 50, + # Optimizer + finetune_lr: float = 5.0e-6, + min_lr: float = 0.0, + lr_warmup_iters: int = 50, + lr_decay_iters: int | None = None, + # W&B logging + wandb_project: str | None = None, + wandb_entity: str | None = None, + wandb_exp_name: str | None = None, + # Precision / overlap configs + precision_config: MixedPrecisionConfig | str | None = "bf16_mixed", + comm_overlap_config: CommOverlapConfig | None = None, +) -> ConfigContainer: + """ + Create a finetuning configuration for Mamba 2.x models. + + This mirrors the pretrain ``_mamba2_common`` helper but replaces the + GPT-dataset configuration with a SQuAD-based HFDatasetConfig and adds + support for PEFT adapters (LoRA / DoRA) and pretrained-checkpoint + loading — following the same pattern used by the Llama-3 and NemotronH + finetuning recipes in Megatron-Bridge. 
+ """ + # Setup directories + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + # Model + model_cfg = model_provider( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + pipeline_dtype=pipeline_dtype, + virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + sequence_parallel=sequence_parallel, + ) + + # Optimizer & scheduler (lower LR for finetuning) + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=lr_decay_iters, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + weight_decay=0.1, + max_lr=finetune_lr, + min_lr=min_lr, + ) + + # PEFT (LoRA / DoRA / None for full SFT) + mamba_target_modules = [ + "linear_qkv", "linear_proj", "linear_fc1", "linear_fc2", + "in_proj", "out_proj", + ] + peft_config = default_peft_config(peft, target_modules=mamba_target_modules) + + # Logger + logger_cfg = LoggerConfig( + log_interval=1, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + wandb_project=wandb_project, + wandb_entity=wandb_entity, + wandb_exp_name=wandb_exp_name, + ) + + # Tokenizer — use HuggingFace tokenizer (same default as pretrain recipe) + tokenizer_cfg = TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model="EleutherAI/gpt-neox-20b", + hf_tokenizer_kwargs={"use_fast": True}, + ) + + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=eval_interval, + eval_iters=10, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + optimizer=opt_cfg, + scheduler=scheduler_cfg, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=False, + use_distributed_optimizer=True, + ), + dataset=default_squad_config(seq_length, packed_sequence), + logger=logger_cfg, + tokenizer=tokenizer_cfg, + checkpoint=CheckpointConfig( + save_interval=save_interval, + save=checkpoint_dir, + load=checkpoint_dir, + pretrained_checkpoint=pretrained_checkpoint, + ckpt_format="torch_dist", + dist_ckpt_strictness="log_all", + ), + rng=RNGConfig(seed=5678), + peft=peft_config, + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg + + +__all__ = [ + "mamba2_130m_finetune_config", + "mamba2_370m_finetune_config", + "mamba2_780m_finetune_config", + "mamba2_1p3b_finetune_config", + "mamba2_2p7b_finetune_config", + "mamba2_8b_finetune_config", + "mamba2_hybrid_8b_finetune_config", +] diff --git a/primus/backends/megatron_bridge/recipes/zebra_llama/__init__.py b/primus/backends/megatron_bridge/recipes/zebra_llama/__init__.py new file mode 100644 index 000000000..52a0984b0 --- /dev/null +++ b/primus/backends/megatron_bridge/recipes/zebra_llama/__init__.py @@ -0,0 +1,7 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. 
+############################################################################### + +"""Zebra Llama (hybrid Mamba+MLA) recipe extensions for Megatron-Bridge.""" diff --git a/primus/backends/megatron_bridge/recipes/zebra_llama/zebra_llama.py b/primus/backends/megatron_bridge/recipes/zebra_llama/zebra_llama.py new file mode 100644 index 000000000..dccac30a5 --- /dev/null +++ b/primus/backends/megatron_bridge/recipes/zebra_llama/zebra_llama.py @@ -0,0 +1,627 @@ +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Zebra Llama (hybrid Mamba+MLA) recipes for Megatron-Bridge. + +These recipes define model providers and training configurations for the +Zebra Llama family of hybrid models that combine Mamba SSM layers with +Multi-Latent Attention (MLA). The model providers live on the Primus side +so that no third-party code needs to be modified. +""" + +import os +from dataclasses import dataclass +from typing import Callable, Literal, Optional, Union + +import torch +from typing_extensions import TypedDict, Unpack + +from megatron.core.models.mamba import MambaModel as MCoreMambaModel +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer import ModuleSpec +from megatron.core.transformer.enums import AttnBackend + +from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.bridge.models.transformer_config import MLATransformerConfig +from megatron.bridge.recipes.utils.dataset_utils import get_blend_fields_from_data_paths +from megatron.bridge.recipes.utils.finetune_utils import ( + default_peft_config, + default_squad_config, +) +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + GPTDatasetConfig, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig +from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size + + +# --------------------------------------------------------------------------- +# Hybrid Mamba+MLA Model Provider (extends MLATransformerConfig) +# --------------------------------------------------------------------------- + + +def _get_hybrid_mamba_mla_stack_spec(config: "ZebraLlamaMambaMLAProvider") -> ModuleSpec: + """Return the Primus hybrid Mamba+MLA stack spec. + + The import is deferred so that the heavyweight module-level construction + inside ``hybrid_mamba_mla_layer_specs`` (TE layers, MoE spec, etc.) only + runs at model-build time, not when the recipe module is first loaded. + """ + from primus.backends.megatron.core.models.hybrid.hybrid_mamba_mla_layer_specs import ( + hybrid_stack_spec, + ) + + return hybrid_stack_spec + + +@dataclass +class ZebraLlamaMambaMLAProvider(MLATransformerConfig, ModelProviderMixin[MCoreMambaModel]): + """Configuration and provider for Zebra Llama hybrid Mamba+MLA models. 
+ + This class combines MLATransformerConfig (for MLA attention parameters) + with MCoreMambaModel to create a hybrid model that interleaves Mamba SSM + layers with Multi-Latent Attention layers. + """ + + # ---- Model configuration ---- + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + share_embeddings_and_output_weights: bool = False + params_dtype: torch.dtype = torch.bfloat16 + fp16: bool = False + bf16: bool = True + is_hybrid_model: bool = True + + # ---- Mamba-specific parameters ---- + mamba_num_groups: int = 8 + hybrid_attention_ratio: float = 0.25 + hybrid_mlp_ratio: float = 0.0 + hybrid_override_pattern: Optional[str] = None + position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none" + rotary_percent: float = 1.0 + seq_len_interpolation_factor: Optional[float] = None + apply_rope_fusion: bool = True + make_vocab_size_divisible_by: int = 128 + add_bias_linear: bool = False + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + attention_backend: AttnBackend = AttnBackend.auto + deallocate_pipeline_outputs: bool = True + bias_dropout_fusion: bool = True + cross_entropy_loss_fusion: bool = True + + mamba_stack_spec: Union[ + ModuleSpec, + Callable[[], ModuleSpec], + Callable[["ZebraLlamaMambaMLAProvider"], ModuleSpec], + ] = _get_hybrid_mamba_mla_stack_spec + + vocab_size: Optional[int] = None + should_pad_vocab: bool = False + hf_model_id: Optional[str] = None + _pg_collection: Optional[ProcessGroupCollection] = None + + restore_modelopt_state: bool = False + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreMambaModel: + """Instantiate a Megatron Core Mamba model with MLA attention support.""" + mamba_stack_spec = self.mamba_stack_spec + if not isinstance(mamba_stack_spec, ModuleSpec): + import inspect + + if len(inspect.signature(mamba_stack_spec).parameters) > 0: + mamba_stack_spec = mamba_stack_spec(self) + else: + mamba_stack_spec = mamba_stack_spec() + + assert getattr(self, "virtual_pipeline_model_parallel_size", None) is None and vp_stage is None, ( + "Virtual pipeline model parallelism is temporarily unsupported in SSM/Mamba " + "models due to upstream MCore MambaModel API dependency" + ) + + assert self.vocab_size is not None, "vocab_size must be configured before calling provide()" + if self.should_pad_vocab: + padded_vocab_size = calculate_padded_vocab_size( + self.vocab_size, self.make_vocab_size_divisible_by, self.tensor_model_parallel_size + ) + else: + padded_vocab_size = self.vocab_size + + return MCoreMambaModel( + self, + mamba_stack_spec=mamba_stack_spec, + vocab_size=padded_vocab_size, + max_sequence_length=self.seq_length, + hybrid_attention_ratio=self.hybrid_attention_ratio, + hybrid_mlp_ratio=self.hybrid_mlp_ratio, + hybrid_override_pattern=self.hybrid_override_pattern, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=pre_process or is_pp_first_stage(self._pg_collection.pp), + post_process=post_process or is_pp_last_stage(self._pg_collection.pp), + pg_collection=self._pg_collection, + ) + + +# --------------------------------------------------------------------------- +# Zebra Llama 1B Provider +# 
--------------------------------------------------------------------------- + + +@dataclass +class ZebraLlama1BModelProvider(ZebraLlamaMambaMLAProvider): + """Configuration for Zebra Llama 1B (hybrid Mamba+MLA). + + Architecture summary: + - 32 layers with 25% attention ratio (8 MLA + 24 Mamba layers) + - hidden_size=2048, ffn_hidden_size=8192 + - Multi-Latent Attention with q_lora_rank=1344, kv_lora_rank=128 + - Mamba SSM with state_dim=64, head_dim=64, 8 groups + - SwiGLU activation in MLP layers + - Tokenizer: meta-llama/Llama-3.2-1B + """ + + # Layer counts and sizes + num_layers: int = 32 + hidden_size: int = 2048 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 32 + seq_length: int = 8192 + + # Hybrid Mamba parameters + hybrid_attention_ratio: float = 0.25 + mamba_num_groups: int = 8 + + # MLA parameters + multi_latent_attention: bool = True + q_lora_rank: int = 1344 + kv_lora_rank: int = 128 + qk_head_dim: int = 32 + qk_pos_emb_head_dim: int = 32 + v_head_dim: int = 64 + rotary_scaling_factor: float = 1.0 + mscale: float = 1.0 + mscale_all_dim: float = 1.0 + + # SwiGLU activation + gated_linear_unit: bool = True + + # Position embedding — MLA uses its own internal positional encoding + position_embedding_type: Literal["learned_absolute", "rope", "none"] = "none" + rotary_base: float = 500000 + normalization: str = "RMSNorm" + layernorm_epsilon: float = 1e-5 + + +# --------------------------------------------------------------------------- +# Typed kwargs for recipe helpers +# --------------------------------------------------------------------------- + + +class ZebraLlamaPretrainKwargs(TypedDict, total=False): + """Typed options accepted by Zebra Llama pretrain recipe helpers.""" + + dir: str | None + name: str + + # Dataset + data_paths: list[str] | None + data_args_path: str | None + train_data_path: list[str] | None + valid_data_path: list[str] | None + test_data_path: list[str] | None + per_split_data_args_path: str | None + mock: bool + + # Model parallelism + tensor_model_parallel_size: int + pipeline_model_parallel_size: int + pipeline_dtype: torch.dtype | None + virtual_pipeline_model_parallel_size: int | None + context_parallel_size: int + sequence_parallel: bool + + # Training hyperparameters + train_iters: int + global_batch_size: int + micro_batch_size: int + seq_length: int + lr: float + min_lr: float + lr_warmup_iters: int + lr_decay_iters: int | None + + # Tokenizer + use_null_tokenizer: bool + + # Precision / overlap + precision_config: MixedPrecisionConfig | str | None + comm_overlap_config: CommOverlapConfig | None + + +class ZebraLlamaFinetuneKwargs(TypedDict, total=False): + """Typed options accepted by Zebra Llama finetune recipe helpers.""" + + dir: str | None + name: str + + # Finetuning-specific + pretrained_checkpoint: str | None + peft: str | None + packed_sequence: bool + + # Model parallelism + tensor_model_parallel_size: int + pipeline_model_parallel_size: int + pipeline_dtype: torch.dtype | None + virtual_pipeline_model_parallel_size: int | None + context_parallel_size: int + sequence_parallel: bool + + # Training hyperparameters + train_iters: int + global_batch_size: int + micro_batch_size: int + seq_length: int + eval_interval: int + save_interval: int + + # Optimizer + finetune_lr: float + min_lr: float + lr_warmup_iters: int + lr_decay_iters: int | None + + # W&B logging + wandb_project: str | None + wandb_entity: str | None + wandb_exp_name: str | None + + # Precision / overlap + precision_config: MixedPrecisionConfig | str | None + 
comm_overlap_config: CommOverlapConfig | None + + +# --------------------------------------------------------------------------- +# Public pretrain recipe +# --------------------------------------------------------------------------- + + +def zebra_llama_1b_pretrain_config( + **user_kwargs: Unpack[ZebraLlamaPretrainKwargs], +) -> ConfigContainer: + """Return a pre-training config for Zebra Llama 1B (hybrid Mamba+MLA).""" + recommended: ZebraLlamaPretrainKwargs = { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + "use_null_tokenizer": False, + } + kwargs: ZebraLlamaPretrainKwargs = {**recommended, **user_kwargs} + return _zebra_llama_pretrain_common( + model_provider=ZebraLlama1BModelProvider, + tokenizer_model="meta-llama/Llama-3.2-1B", + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Public finetune recipe +# --------------------------------------------------------------------------- + + +def zebra_llama_1b_finetune_config( + **user_kwargs: Unpack[ZebraLlamaFinetuneKwargs], +) -> ConfigContainer: + """Return a finetuning config for Zebra Llama 1B (hybrid Mamba+MLA).""" + recommended: ZebraLlamaFinetuneKwargs = { + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "sequence_parallel": False, + "precision_config": "bf16_mixed", + } + kwargs: ZebraLlamaFinetuneKwargs = {**recommended, **user_kwargs} + return _zebra_llama_finetune_common( + model_provider=ZebraLlama1BModelProvider, + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Common pretrain builder +# --------------------------------------------------------------------------- + + +def _zebra_llama_pretrain_common( + model_provider: type[ZebraLlamaMambaMLAProvider], + tokenizer_model: str | None = None, + dir: str | None = None, + name: str = "default", + # Dataset + data_paths: list[str] | None = None, + data_args_path: str | None = None, + train_data_path: list[str] | None = None, + valid_data_path: list[str] | None = None, + test_data_path: list[str] | None = None, + per_split_data_args_path: str | None = None, + mock: bool = False, + # Model parallelism + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + pipeline_dtype: torch.dtype | None = None, + virtual_pipeline_model_parallel_size: int | None = None, + context_parallel_size: int = 1, + sequence_parallel: bool = False, + # Training hyperparameters + train_iters: int = 100, + global_batch_size: int = 64, + micro_batch_size: int = 8, + seq_length: int = 8192, + lr: float = 2.0e-4, + min_lr: float = 2.0e-5, + lr_warmup_iters: int = 200, + lr_decay_iters: int | None = 10000, + # Tokenizer + use_null_tokenizer: bool = False, + # Precision / overlap + precision_config: MixedPrecisionConfig | str | None = "bf16_mixed", + comm_overlap_config: CommOverlapConfig | None = None, +) -> ConfigContainer: + """Create a pre-training configuration for Zebra Llama hybrid models.""" + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + blend, blend_per_split, split = get_blend_fields_from_data_paths( + data_paths, data_args_path, train_data_path, valid_data_path, test_data_path, per_split_data_args_path, mock + ) + + model_cfg = 
model_provider( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + pipeline_dtype=pipeline_dtype, + virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + sequence_parallel=sequence_parallel, + ) + + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=lr_decay_iters, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + weight_decay=0.1, + max_lr=lr, + min_lr=min_lr, + ) + + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=100, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + optimizer=opt_config, + scheduler=scheduler, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + use_distributed_optimizer=True, + ), + dataset=GPTDatasetConfig( + random_seed=1234, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=True, + seq_length=seq_length, + num_dataset_builder_threads=1, + blend=blend, + blend_per_split=blend_per_split, + split=split, + data_sharding=True, + dataloader_type="single", + num_workers=8, + skip_getting_attention_mask_from_dataset=True, + ), + logger=LoggerConfig( + log_interval=1, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + ), + tokenizer=( + TokenizerConfig( + tokenizer_type="NullTokenizer", + tokenizer_model=None, + vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE, + ) + if use_null_tokenizer + else TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model=tokenizer_model or "meta-llama/Llama-3.2-1B", + hf_tokenizer_kwargs={"use_fast": True}, + ) + ), + checkpoint=CheckpointConfig( + save_interval=10000, + save=checkpoint_dir, + load=checkpoint_dir, + ckpt_format="torch", + ), + rng=RNGConfig(seed=1234), + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg + + +# --------------------------------------------------------------------------- +# Common finetune builder +# --------------------------------------------------------------------------- + + +def _zebra_llama_finetune_common( + model_provider: type[ZebraLlamaMambaMLAProvider], + dir: str | None = None, + name: str = "default", + # Model parallelism + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + pipeline_dtype: torch.dtype | None = None, + virtual_pipeline_model_parallel_size: int | None = None, + context_parallel_size: int = 1, + sequence_parallel: bool = False, + # Finetuning-specific params + pretrained_checkpoint: str | None = None, + peft: str | None = "none", + packed_sequence: bool = False, + # Training hyperparameters + train_iters: int = 1000, + global_batch_size: int = 128, + micro_batch_size: int = 4, + seq_length: int = 2048, + eval_interval: int = 30, + save_interval: int = 50, + # Optimizer + finetune_lr: float = 5.0e-6, + min_lr: float = 0.0, + lr_warmup_iters: int = 50, + lr_decay_iters: int | None = None, + # W&B logging + wandb_project: str | None = None, + wandb_entity: str | None = None, + wandb_exp_name: str | None = None, + # Precision / overlap + precision_config: MixedPrecisionConfig | str | None = "bf16_mixed", + comm_overlap_config: CommOverlapConfig | None = None, +) -> ConfigContainer: + """Create a finetuning configuration for Zebra Llama hybrid 
models.""" + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + model_cfg = model_provider( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + pipeline_dtype=pipeline_dtype, + virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + sequence_parallel=sequence_parallel, + ) + + opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=lr_decay_iters, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + weight_decay=0.1, + max_lr=finetune_lr, + min_lr=min_lr, + ) + + # PEFT target modules (Mamba + attention + MLP projections) + mamba_mla_target_modules = [ + "linear_q_proj", + "linear_q_down_proj", + "linear_q_up_proj", + "linear_kv_down_proj", + "linear_kv_up_proj", + "linear_proj", + "linear_fc1", + "linear_fc2", + "in_proj", + "out_proj", + ] + peft_config = default_peft_config(peft, target_modules=mamba_mla_target_modules) + + logger_cfg = LoggerConfig( + log_interval=1, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + wandb_project=wandb_project, + wandb_entity=wandb_entity, + wandb_exp_name=wandb_exp_name, + ) + + tokenizer_cfg = TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model="meta-llama/Llama-3.2-1B", + hf_tokenizer_kwargs={"use_fast": True}, + ) + + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=eval_interval, + eval_iters=10, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + ), + optimizer=opt_cfg, + scheduler=scheduler_cfg, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=False, + use_distributed_optimizer=True, + ), + dataset=default_squad_config(seq_length, packed_sequence), + logger=logger_cfg, + tokenizer=tokenizer_cfg, + checkpoint=CheckpointConfig( + save_interval=save_interval, + save=checkpoint_dir, + load=checkpoint_dir, + pretrained_checkpoint=pretrained_checkpoint, + ckpt_format="torch_dist", + dist_ckpt_strictness="log_all", + ), + rng=RNGConfig(seed=5678), + peft=peft_config, + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg + + +__all__ = [ + "ZebraLlamaMambaMLAProvider", + "ZebraLlama1BModelProvider", + "zebra_llama_1b_pretrain_config", + "zebra_llama_1b_finetune_config", +] diff --git a/primus/configs/models/megatron/deepseek_v2_base.yaml b/primus/configs/models/megatron/deepseek_v2_base.yaml index 6d402635d..f4d645c17 100755 --- a/primus/configs/models/megatron/deepseek_v2_base.yaml +++ b/primus/configs/models/megatron/deepseek_v2_base.yaml @@ -34,6 +34,7 @@ moe_aux_loss_coeff: 0.001 # aux_loss_alpha # rotary rotary_base: 10000 rotary_scaling_factor: 40.0 # float +original_max_position_embeddings: 4096 # int, original context length before YaRN scaling mscale: 0.707 # float mscale_all_dim: 0.707 # float diff --git a/primus/configs/models/megatron/deepseek_v3_base.yaml b/primus/configs/models/megatron/deepseek_v3_base.yaml index 421b99361..c0a5c77af 100755 --- a/primus/configs/models/megatron/deepseek_v3_base.yaml +++ b/primus/configs/models/megatron/deepseek_v3_base.yaml 
@@ -37,6 +37,7 @@ moe_aux_loss_coeff: 0.001 # aux_loss_alpha # rotary rotary_base: 10000 rotary_scaling_factor: 40.0 # float +original_max_position_embeddings: 4096 # int, original context length before YaRN scaling mscale: 1.0 # float mscale_all_dim: 1.0 # float diff --git a/primus/configs/models/megatron/hybrid_model_base.yaml b/primus/configs/models/megatron/hybrid_model_base.yaml new file mode 100644 index 000000000..f6e9ee113 --- /dev/null +++ b/primus/configs/models/megatron/hybrid_model_base.yaml @@ -0,0 +1,21 @@ +extends: + - language_model.yaml + +# Mamba layer configuration +mamba_state_dim: 128 +mamba_head_dim: 64 +mamba_num_groups: 8 +mamba_num_heads: null +mamba_expand: 2 +mamba_d_conv: 4 +disable_mamba_mem_eff_path: false + +# Hybrid model configuration +is_hybrid_model: false # bool +hybrid_attention_ratio: 0.0 # float range [0,0, 1.0] +hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0] +hybrid_override_pattern: null # str + +# MTP +mtp_num_layers: null # int +mtp_loss_scaling_factor: 0.1 # float diff --git a/primus/configs/models/megatron/language_model.yaml b/primus/configs/models/megatron/language_model.yaml index 8c6f24e48..9d2929c2a 100755 --- a/primus/configs/models/megatron/language_model.yaml +++ b/primus/configs/models/megatron/language_model.yaml @@ -8,6 +8,7 @@ extends: # model architecture use_legacy_models: false deprecated_use_mcore_models: false +model_type: gpt # gpt or mamba num_layers: 24 encoder_num_layers: null decoder_num_layers: null @@ -21,6 +22,7 @@ num_query_groups: null add_position_embedding: false position_embedding_type: learned_absolute max_position_embeddings: null +original_max_position_embeddings: null untie_embeddings_and_output_weights: true ffn_hidden_size: null @@ -99,10 +101,6 @@ rotary_scaling_factor: 1.0 # float mscale: 1.0 # float mscale_all_dim: 1.0 # float -# MTP -mtp_num_layers: null # int -mtp_loss_scaling_factor: 0.1 # float - # MoE related num_experts: null moe_layer_freq: 1 # int diff --git a/primus/configs/models/megatron/mamba_370M.yaml b/primus/configs/models/megatron/mamba_370M.yaml new file mode 100644 index 000000000..6665da3af --- /dev/null +++ b/primus/configs/models/megatron/mamba_370M.yaml @@ -0,0 +1,16 @@ +bases: + - mamba_base.yaml + +# Mamba 370M configuration +model_type: mamba # CRITICAL: Mamba models must use mamba model type +tokenizer_type: GPT2BPETokenizer +vocab_size: 50257 + +# Model size parameters +num_layers: 48 +hidden_size: 1024 +num_attention_heads: 16 # Required by Megatron validation, even for pure Mamba models +ffn_hidden_size: null +mamba_state_dim: 16 +mamba_head_dim: 64 +mamba_num_groups: 8 diff --git a/primus/configs/models/megatron/mamba_base.yaml b/primus/configs/models/megatron/mamba_base.yaml new file mode 100644 index 000000000..d52fe6db2 --- /dev/null +++ b/primus/configs/models/megatron/mamba_base.yaml @@ -0,0 +1,36 @@ +bases: + - language_model.yaml + +# Mamba-specific configuration +# Note: Mamba-specific parameters (spec, is_hybrid_model, mamba_state_dim, etc.) 
+# must be set in the pretrain config overrides, not here + +model_type: mamba +use_legacy_models: false + +# Position embeddings - Mamba typically doesn't use position embeddings +position_embedding_type: rope +use_rotary_position_embeddings: false + +# Tokenizer (should be set in specific model configs) +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: null + +# Standard transformer settings that may be used by hybrid models +is_hybrid_model: false +attention_dropout: 0.0 +hidden_dropout: 0.0 + +# Embeddings +untie_embeddings_and_output_weights: false + +# Other settings +apply_residual_connection_post_layernorm: false +add_bias_linear: false +swiglu: false + +# Normalization +norm_epsilon: 1.0e-5 + +# Initialization +init_method_std: 0.02 diff --git a/primus/configs/models/megatron/zebra_llama_1B.yaml b/primus/configs/models/megatron/zebra_llama_1B.yaml new file mode 100644 index 000000000..d1afa9531 --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_1B.yaml @@ -0,0 +1,42 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 1B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 32 +hidden_size: 2048 +ffn_hidden_size: 8192 + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 64 +mamba_head_dim: 64 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null +multi_latent_attention: true +num_attention_heads: 32 +q_lora_rank: 1344 # Query LoRA rank +kv_lora_rank: 128 # Key-Value LoRA rank +qk_head_dim: 32 # Query-Key head dimension +qk_pos_emb_head_dim: 32 # Positional embedding head dimension +v_head_dim: 64 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false +max_position_embeddings: 131072 diff --git a/primus/configs/models/megatron/zebra_llama_3B.yaml b/primus/configs/models/megatron/zebra_llama_3B.yaml new file mode 100644 index 000000000..23090841f --- /dev/null +++ b/primus/configs/models/megatron/zebra_llama_3B.yaml @@ -0,0 +1,43 @@ +bases: + - mamba_base.yaml + +# Zebra Llama 3B configuration +model_type: mamba # CRITICAL: Hybrid models must use mamba model type +tokenizer_type: HuggingFaceTokenizer +tokenizer_model: meta-llama/Llama-3.2-1B + +# Model size parameters +num_layers: 56 +hidden_size: 3072 +ffn_hidden_size: 8192 +normalization: "RMSNorm" + +# Mamba parameters +is_hybrid_model: true +hybrid_attention_ratio: 0.25 +mamba_state_dim: 128 +mamba_head_dim: 128 +mamba_num_groups: 8 + +# MLA parameters +# Disable standard GQA - MLA uses its own compression via LoRA +group_query_attention: false +swiglu: true +num_query_groups: null +multi_latent_attention: true +num_attention_heads: 24 +q_lora_rank: 1536 # Query LoRA rank +kv_lora_rank: 128 # Key-Value LoRA rank +qk_head_dim: 64 # Query-Key head dimension +qk_pos_emb_head_dim: 64 # Positional embedding head dimension +v_head_dim: 128 # Value head dimension +rotary_scaling_factor: 1.0 +mscale: 1.0 +mscale_all_dim: 1.0 + +# MLA uses its own internal positional encoding +rotary_base: 500000 +position_embedding_type: none +add_position_embedding: true +use_rotary_position_embeddings: false
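+# int, original context length before YaRN scaling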
+original_max_position_embeddings: 131072
diff --git a/primus/configs/models/megatron/zebra_llama_8B.yaml b/primus/configs/models/megatron/zebra_llama_8B.yaml
new file mode 100644
index 000000000..0237a652d
--- /dev/null
+++ b/primus/configs/models/megatron/zebra_llama_8B.yaml
@@ -0,0 +1,43 @@
+bases:
+  - mamba_base.yaml
+
+# Zebra Llama 8B configuration
+model_type: mamba # CRITICAL: Hybrid models must use mamba model type
+tokenizer_type: HuggingFaceTokenizer
+tokenizer_model: meta-llama/Llama-3.2-1B
+
+# Model size parameters
+num_layers: 64
+hidden_size: 4096
+ffn_hidden_size: 14436
+normalization: "RMSNorm"
+
+# Mamba parameters
+is_hybrid_model: true
+hybrid_attention_ratio: 0.25
+mamba_state_dim: 128
+mamba_head_dim: 128
+mamba_num_groups: 8
+
+# MLA parameters
+# Disable standard GQA - MLA uses its own compression via LoRA
+group_query_attention: false
+swiglu: true
+num_query_groups: null
+multi_latent_attention: true
+num_attention_heads: 32
+q_lora_rank: 2048 # Query LoRA rank
+kv_lora_rank: 160 # Key-Value LoRA rank
+qk_head_dim: 64 # Query-Key head dimension
+qk_pos_emb_head_dim: 64 # Positional embedding head dimension
+v_head_dim: 128 # Value head dimension
+rotary_scaling_factor: 1.0
+mscale: 1.0
+mscale_all_dim: 1.0
+
+# MLA uses its own internal positional encoding
+rotary_base: 500000
+position_embedding_type: none
+add_position_embedding: true
+use_rotary_position_embeddings: false
+original_max_position_embeddings: 131072
diff --git a/primus/configs/models/megatron_bridge/mamba_370M.yaml b/primus/configs/models/megatron_bridge/mamba_370M.yaml
new file mode 100644
index 000000000..55d45966c
--- /dev/null
+++ b/primus/configs/models/megatron_bridge/mamba_370M.yaml
@@ -0,0 +1,5 @@
+recipe: mamba.mamba2
+flavor: mamba2_370m_finetune_config
+
+dataset:
+  dataset_name: "rajpurkar/squad"
diff --git a/primus/configs/models/megatron_bridge/zebra_llama_1B.yaml b/primus/configs/models/megatron_bridge/zebra_llama_1B.yaml
new file mode 100644
index 000000000..e53c0bb2d
--- /dev/null
+++ b/primus/configs/models/megatron_bridge/zebra_llama_1B.yaml
@@ -0,0 +1,5 @@
+recipe: zebra_llama.zebra_llama
+flavor: zebra_llama_1b_finetune_config
+
+dataset:
+  mock: true
diff --git a/primus/configs/modules/megatron/trainer_base.yaml b/primus/configs/modules/megatron/trainer_base.yaml
index 76be0c7ba..ba89bd7d1 100755
--- a/primus/configs/modules/megatron/trainer_base.yaml
+++ b/primus/configs/modules/megatron/trainer_base.yaml
@@ -305,17 +305,17 @@ rerun_mode: disabled # str: 'disabled', 'validate_results', 'report_stats'
 # Experimental features
 enable_experimental: false

-# Hybrid model configuration
-hybrid_attention_ratio: 0.0 # float range [0,0, 1.0]
-hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0]
-hybrid_override_pattern: null # str
-
-# Mamba layer configuration
-mamba_state_dim: 128
-mamba_head_dim: 64
-mamba_num_groups: 8
-mamba_num_heads: null
-disable_mamba_mem_eff_path: false
+# # Hybrid model configuration
+# hybrid_attention_ratio: 0.0 # float range [0,0, 1.0]
+# hybrid_mlp_ratio: 0.0 # float range [0,0, 1.0]
+# hybrid_override_pattern: null # str
+
+# # Mamba layer configuration
+# mamba_state_dim: 128
+# mamba_head_dim: 64
+# mamba_num_groups: 8
+# mamba_num_heads: null
+# disable_mamba_mem_eff_path: false

 # Args of precision-aware optimizer
 use_precision_aware_optimizer: false
@@ -405,7 +405,6 @@ indexer_log_interval: 1000
 enable_ft_package: false
 calc_ft_timeouts: false
 run_workload_inspector_server: false
-is_hybrid_model: false
 heterogeneous_layers_config_path: null
 heterogeneous_layers_config_encoded_json: null
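A note on the hybrid_attention_ratio: 0.25 setting shared by the zebra_llama configs above. Assuming the ratio denotes the fraction of decoder layers realized as attention blocks (as the [0.0, 1.0] range comment in hybrid_model_base.yaml suggests), the resulting layer split can be estimated as below; this is an illustrative sketch only, since the actual allocation is decided by Megatron's hybrid layer allocator.

```python
# Illustrative sketch only (not part of the patch). Assumes hybrid_attention_ratio
# is the fraction of layers realized as attention blocks, with the remainder built
# as Mamba blocks.

def split_hybrid_layers(num_layers: int, attention_ratio: float) -> tuple[int, int]:
    """Return (attention_layers, mamba_layers) for a given ratio."""
    num_attention = round(num_layers * attention_ratio)
    return num_attention, num_layers - num_attention

# Layer counts taken from the zebra_llama_*B.yaml files above, all with ratio 0.25.
for name, layers in [("zebra_llama_1B", 32), ("zebra_llama_3B", 56), ("zebra_llama_8B", 64)]:
    attn, mamba = split_hybrid_layers(layers, 0.25)
    print(f"{name}: {attn} attention / {mamba} mamba layers")
```

Under that assumption, roughly one in four layers stays attention: about 8 of 32 for the 1B config, 14 of 56 for 3B, and 16 of 64 for 8B.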
diff --git a/primus/core/utils/import_utils.py b/primus/core/utils/import_utils.py
index 2ccd8ebed..34c2de16d 100644
--- a/primus/core/utils/import_utils.py
+++ b/primus/core/utils/import_utils.py
@@ -34,25 +34,42 @@ def lazy_import(paths, symbol, log_prefix="[Primus]"):
     raise ImportError(f"{log_prefix} {symbol} not found in any of: {paths}")


-def get_model_provider():
+def get_model_provider(model_type="gpt"):
     """
-    Resolve model_provider across Megatron versions.
+    Resolve model_provider across Megatron versions and model types.

-    - New: model_provider + gpt_builder
+    Args:
+        model_type (str): Type of model - 'gpt' or 'mamba'. Defaults to 'gpt'.
+
+    - New: model_provider + gpt_builder/mamba_builder
     - Mid: model_provider only
-    - Old: pretrain_gpt.model_provider
+    - Old: pretrain_gpt.model_provider / pretrain_mamba.model_provider
     """
     # Try to import model_provider
-    model_provider = lazy_import(
-        ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]"
-    )
+    if model_type == "mamba":
+        model_provider = lazy_import(
+            ["model_provider", "pretrain_mamba"], "model_provider", log_prefix="[Primus][MegatronCompat]"
+        )
+        # Try to import mamba_builder (for Mamba models)
+        try:
+            mamba_builder = lazy_import(
+                ["mamba_builders"], "mamba_builder", log_prefix="[Primus][MegatronCompat]"
+            )
+            return partial(model_provider, mamba_builder)
+        except ImportError:
+            return model_provider
+    else:
+        # Default GPT behavior
+        model_provider = lazy_import(
+            ["model_provider", "pretrain_gpt"], "model_provider", log_prefix="[Primus][MegatronCompat]"
+        )

-    # Try to import gpt_builder (only exists in newer versions)
-    try:
-        gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]")
-        return partial(model_provider, gpt_builder)
-    except ImportError:
-        return model_provider
+        # Try to import gpt_builder (only exists in newer versions)
+        try:
+            gpt_builder = lazy_import(["gpt_builders"], "gpt_builder", log_prefix="[Primus][MegatronCompat]")
+            return partial(model_provider, gpt_builder)
+        except ImportError:
+            return model_provider


 def get_custom_fsdp():
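For readers trying out the updated resolver, here is a minimal usage sketch (not part of the patch). It assumes Primus and a compatible Megatron checkout are importable; per the change above, the provider is resolved lazily and falls back to the plain model_provider when no builder module is available.

```python
# Minimal usage sketch (not part of the patch). Assumes Primus and a compatible
# Megatron are on PYTHONPATH; a missing gpt_builder/mamba_builder module simply
# falls back to the plain model_provider instead of raising at import time.
from primus.core.utils.import_utils import get_model_provider

gpt_provider = get_model_provider()                      # default "gpt" path
mamba_provider = get_model_provider(model_type="mamba")  # new "mamba" path

print(gpt_provider)    # model_provider, or partial(model_provider, gpt_builder) on newer Megatron
print(mamba_provider)  # model_provider, or partial(model_provider, mamba_builder) on newer Megatron
```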
diff --git a/primus/modules/trainer/megatron/pre_trainer.py b/primus/modules/trainer/megatron/pre_trainer.py
index 1c8539dd5..d34788b04 100644
--- a/primus/modules/trainer/megatron/pre_trainer.py
+++ b/primus/modules/trainer/megatron/pre_trainer.py
@@ -242,6 +242,20 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals
             assert (
                 args.overlap_moe_expert_parallel_comm
             ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
+
+            # Schedule plan building is only supported for GPT models
+            # Check if this is a Mamba model
+            unwrapped_model = model
+            while hasattr(unwrapped_model, "module"):
+                unwrapped_model = unwrapped_model.module
+            model_class_name = unwrapped_model.__class__.__name__
+
+            if "Mamba" in model_class_name:
+                raise NotImplementedError(
+                    "Schedule plan building is not supported for Mamba models. "
+                    "Please disable overlap_moe_expert_parallel_comm for Mamba."
+                )
+
             if args.patch_moe_overlap:
                 assert (
                     not args.delay_wgrad_compute
@@ -267,8 +281,21 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals
             )
             return schedule_plan, partial(self.loss_func, loss_mask)
         else:
-            output_tensor = model(
-                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
-            )
+            # Check if model supports loss_mask parameter
+            # MambaModel doesn't accept loss_mask, but GPTModel does
+            # Unwrap the model to get the actual model class
+            unwrapped_model = model
+            while hasattr(unwrapped_model, "module"):
+                unwrapped_model = unwrapped_model.module
+            model_class_name = unwrapped_model.__class__.__name__
+
+            if "Mamba" in model_class_name:
+                # MambaModel doesn't accept loss_mask parameter
+                output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+            else:
+                # GPTModel and other models accept loss_mask parameter
+                output_tensor = model(
+                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+                )

         return output_tensor, partial(self.loss_func, loss_mask)
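The forward_step changes above apply the same unwrap-and-dispatch pattern twice: strip DDP-style wrappers via .module, then key off the class name because MambaModel's forward does not accept loss_mask. Below is a self-contained sketch of that pattern with stand-in classes, so it runs without Megatron installed.

```python
# Stand-alone sketch of the unwrap-and-dispatch pattern (dummy classes, no
# Megatron dependency). The real code strips wrapper objects via .module and
# checks the class name to decide whether forward() should receive loss_mask.

class MambaModel:          # stand-in for Megatron's MambaModel
    pass

class DummyWrapper:        # stand-in for a DDP-style wrapper exposing .module
    def __init__(self, module):
        self.module = module

def unwrap(model):
    while hasattr(model, "module"):
        model = model.module
    return model

def accepts_loss_mask(model) -> bool:
    return "Mamba" not in unwrap(model).__class__.__name__

wrapped = DummyWrapper(DummyWrapper(MambaModel()))
print(accepts_loss_mask(wrapped))  # False -> call model(...) without loss_mask
```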
diff --git a/primus/modules/trainer/megatron/trainer.py b/primus/modules/trainer/megatron/trainer.py
index 5db59188b..8f0c9f70e 100644
--- a/primus/modules/trainer/megatron/trainer.py
+++ b/primus/modules/trainer/megatron/trainer.py
@@ -483,8 +483,10 @@ def update_primus_config(
         if args.iterations_to_skip is None:
             args.iterations_to_skip = []

-        # support moe_freq_type
-        if isinstance(args.moe_layer_freq, str):
+        # support moe_freq_type - ensure moe_layer_freq has a default value
+        if not hasattr(args, "moe_layer_freq"):
+            args.moe_layer_freq = 1
+        elif isinstance(args.moe_layer_freq, str):
             try:
                 args.moe_layer_freq = eval(args.moe_layer_freq)
             except Exception:
@@ -496,11 +498,35 @@ def update_primus_config(
             args.valid_data_path = None
             args.test_data_path = None

+        # Determine model type (gpt or mamba)
+        model_type = getattr(args, "model_type", "gpt")
+        log_rank_0(f"-detected model_type: {model_type}")
+
+        # Ensure required attributes have safe defaults if missing from config
+        if not hasattr(args, "final_logit_softcapping"):
+            args.final_logit_softcapping = None
+        if not hasattr(args, "router_logit_softcapping"):
+            args.router_logit_softcapping = None
+
+        # Only pass model_type parameter when it's "mamba" to maintain backward compatibility
+        # with main branch behavior for "gpt" (default) case
         if args.final_logit_softcapping is not None and args.final_logit_softcapping > 0.0:
             log_rank_0(f"-enable final_logit_softcapping: {args.final_logit_softcapping}")
-            self.model_provider = functools.partial(primus_model_provider, get_model_provider())
+            if model_type == "mamba":
+                self.model_provider = functools.partial(
+                    primus_model_provider, get_model_provider(model_type=model_type)
+                )
+            else:
+                self.model_provider = functools.partial(primus_model_provider, get_model_provider())
         else:
-            self.model_provider = get_model_provider()
+            if model_type == "mamba":
+                log_rank_0(f"-getting model provider for model_type={model_type}")
+                model_provider = get_model_provider(model_type=model_type)
+                log_rank_0(f"-model_provider: {model_provider}")
+                self.model_provider = model_provider
+            else:
+                # For "gpt" (default), call without arguments to match main branch behavior
+                self.model_provider = get_model_provider()

         if args.router_logit_softcapping is not None and args.router_logit_softcapping > 0.0:
             log_rank_0(f"-enable router_logit_softcapping: {args.router_logit_softcapping}")
@@ -867,6 +893,8 @@ def setup_model_and_optimizer(
         log_rank_0(f"use te backend...")

         log_rank_0(f"-run get_model")
+        log_rank_0(f"-model_provider_func: {model_provider_func}")
+        log_rank_0(f"-model_type: {model_type}")
         model = get_model(model_provider_func, model_type)
         log_rank_0(model)
         # get_megatron_optimizer will use the ddp_config
diff --git a/third_party/Megatron-Bridge b/third_party/Megatron-Bridge
index 9577b1280..7949b3f90 160000
--- a/third_party/Megatron-Bridge
+++ b/third_party/Megatron-Bridge
@@ -1 +1 @@
-Subproject commit 9577b1280eaadd60b9d7b0ce6df09ac80e87e323
+Subproject commit 7949b3f90a3e8d425a13e47f43b4c67bb57196f7
diff --git a/third_party/maxtext b/third_party/maxtext
index 022dc02eb..8def32a8a 160000
--- a/third_party/maxtext
+++ b/third_party/maxtext
@@ -1 +1 @@
-Subproject commit 022dc02eb89057350d2e365f23c8f1f0edb4732d
+Subproject commit 8def32a8a5b96fc6267636a8e58abfe4c178e161
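Finally, the update_primus_config changes above guard against config fields that older configs may omit (moe_layer_freq, final_logit_softcapping, router_logit_softcapping). The sketch below reproduces that defensive-defaulting pattern on a bare argparse.Namespace instead of the real Megatron args object; the moe_layer_freq string value and the behavior of the except branch are illustrative assumptions, since the patch's own except body falls outside the hunk shown above.

```python
# Sketch of the defensive defaulting done in update_primus_config (not part of
# the patch). The moe_layer_freq string below is an arbitrary example; its real
# format is version-dependent.
from argparse import Namespace

args = Namespace(model_type="mamba", moe_layer_freq="[1]*3+[0]")

if not hasattr(args, "moe_layer_freq"):
    args.moe_layer_freq = 1                                  # default when the config omits it
elif isinstance(args.moe_layer_freq, str):
    try:
        args.moe_layer_freq = eval(args.moe_layer_freq)      # e.g. "[1]*3+[0]" -> [1, 1, 1, 0]
    except Exception:
        pass                                                 # assumption: keep the raw string

for name in ("final_logit_softcapping", "router_logit_softcapping"):
    if not hasattr(args, name):
        setattr(args, name, None)

model_type = getattr(args, "model_type", "gpt")
print(model_type, args.moe_layer_freq, args.final_logit_softcapping)
```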