
Commit 9639a52

fixed kv head replication in qwen3 moe (#357)
* fixed kv head replication in qwen3 moe
* polish
* polish
1 parent dc44caf commit 9639a52

File tree (3 files changed: +56 -18 lines)

specforge/layers/linear.py
specforge/modeling/target/custom_backend/qwen3_moe.py
tests/test_modeling/test_target/test_custom_backend/test_qwen3_moe_tp.py

specforge/layers/linear.py
Lines changed: 33 additions & 9 deletions

@@ -79,8 +79,10 @@ def __init__(
         bias=True,
         device=None,
         dtype=None,
-        kv_head_replicas=False,
         layout_type: str = "normal",
+        kv_head_replicas=False,
+        kv_head_idx=None,
+        total_num_kv_heads=None,
     ):
         super().__init__()
         factory_kwargs = {"device": device, "dtype": dtype}
@@ -91,7 +93,10 @@ def __init__(

         self.in_features = in_features
         self.out_features = out_features
-        if kv_head_replicas:
+        self.kv_head_replicas = kv_head_replicas
+        self.kv_head_idx = kv_head_idx
+        self.total_num_kv_heads = total_num_kv_heads
+        if self.kv_head_replicas:
             self.out_features_per_shard = out_features
         else:
             self.out_features_per_shard = out_features // self.tp_size
@@ -113,14 +118,33 @@ def shard_state_dict(self, state_dict, *args):
         """
         This is a state dict hook to be triggered before loading the state dict. This will shard the weights and biases according to the layout type.
         """
-        if self.layout_type == "normal":
-            self.handle_normal_layout(state_dict, *args)
-        elif self.layout_type == "merged_qkv":
-            self.handle_merged_qkv(state_dict, *args)
-        elif self.layout_type == "gate_up":
-            self.handle_gate_up_layout(state_dict, *args)
+        if self.kv_head_replicas:
+            assert self.kv_head_idx is not None
+            assert self.layout_type == "normal"
+            self.handle_kv_head_replicas(state_dict, *args)
         else:
-            raise ValueError(f"Invalid layout type: {self.layout_type}")
+            if self.layout_type == "normal":
+                self.handle_normal_layout(state_dict, *args)
+            elif self.layout_type == "merged_qkv":
+                self.handle_merged_qkv(state_dict, *args)
+            elif self.layout_type == "gate_up":
+                self.handle_gate_up_layout(state_dict, *args)
+            else:
+                raise ValueError(f"Invalid layout type: {self.layout_type}")
+
+    def handle_kv_head_replicas(self, state_dict, *args):
+        """
+        This is a special case for GQA where the key/value are split according to the number of kv heads and the head which belongs to this rank.
+        As the TP size is larger than the number of kv heads, we only keep one kv head per rank.
+        """
+        if "weight" in state_dict:
+            state_dict["weight"] = state_dict["weight"].chunk(
+                self.total_num_kv_heads, dim=0
+            )[self.kv_head_idx]
+        if "bias" in state_dict and state_dict["bias"] is not None:
+            state_dict["bias"] = state_dict["bias"].chunk(
+                self.total_num_kv_heads, dim=0
+            )[self.kv_head_idx]

     def handle_normal_layout(self, state_dict, *args):
         """

specforge/modeling/target/custom_backend/qwen3_moe.py
Lines changed: 14 additions & 4 deletions

@@ -78,14 +78,18 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
         # Calculate head distribution for TP
         self.total_num_heads = config.num_attention_heads
         self.total_num_kv_heads = config.num_key_value_heads
-        self.num_heads = self.total_num_heads // self.tp_size
+        self.num_heads = (
+            self.total_num_heads // self.tp_size
+        )  # this is the number heads per rank

         # Handle KV head replication when tp_size > total_num_kv_heads
         if self.tp_size > self.total_num_kv_heads:
             # In replication mode, each rank gets 1 KV head (replicated across groups)
             self.num_kv_heads = 1
             self.num_kv_head_replicas = self.tp_size // self.total_num_kv_heads
-            self.num_key_value_groups = self.num_heads // self.num_kv_heads
+            self.num_key_value_groups = (
+                self.num_heads // self.num_kv_heads
+            )  # this is size for expanding kv for gqa
             self.kv_head_replicas = True
         else:
             self.num_kv_heads = self.total_num_kv_heads
@@ -103,18 +107,23 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
             self.num_kv_heads * self.head_dim,
             bias=config.attention_bias,
             kv_head_replicas=self.kv_head_replicas,
+            kv_head_idx=self.tp_rank // self.num_kv_head_replicas,
+            total_num_kv_heads=config.num_key_value_heads,
         )
         self.v_proj = ColumnParallelLinear(
             config.hidden_size,
             self.num_kv_heads * self.head_dim,
             bias=config.attention_bias,
             kv_head_replicas=self.kv_head_replicas,
+            kv_head_idx=self.tp_rank // self.num_kv_head_replicas,
+            total_num_kv_heads=config.num_key_value_heads,
         )
         self.o_proj = RowParallelLinear(
             config.num_attention_heads * self.head_dim,
             config.hidden_size,
             bias=config.attention_bias,
         )
+
         self.q_norm = Qwen3MoeRMSNorm(
             self.head_dim, eps=config.rms_norm_eps
         )  # unlike olmo, only on the head dim!
@@ -193,9 +202,10 @@ def __init__(self, config, intermediate_size=None):

         # Add TP support
         self.tp_group = get_tp_group()
-
         self.gate_proj = ColumnParallelLinear(
-            self.hidden_size, self.intermediate_size, bias=False
+            self.hidden_size,
+            self.intermediate_size,
+            bias=False,
         )
         self.up_proj = ColumnParallelLinear(
             self.hidden_size, self.intermediate_size, bias=False
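The kv_head_idx = self.tp_rank // self.num_kv_head_replicas argument passed to k_proj and v_proj encodes the replication mapping: consecutive groups of ranks share one KV head whenever tp_size exceeds the number of KV heads. A small sketch of that mapping, with head counts made up purely for illustration:

# Illustrative numbers: 8 TP ranks sharing 2 KV heads from the checkpoint.
tp_size, total_num_kv_heads = 8, 2
num_kv_head_replicas = tp_size // total_num_kv_heads  # 4 ranks per KV head

for tp_rank in range(tp_size):
    kv_head_idx = tp_rank // num_kv_head_replicas
    print(f"rank {tp_rank} loads kv head {kv_head_idx}")
# ranks 0-3 -> kv head 0, ranks 4-7 -> kv head 1
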

tests/test_modeling/test_target/test_custom_backend/test_qwen3_moe_tp.py
Lines changed: 9 additions & 5 deletions

@@ -16,7 +16,7 @@
 from tests.utils import get_available_port


-def test_qwen3_moe_tp(rank, world_size, temp_dir, port):
+def test_qwen3_moe_tp(rank, world_size, temp_dir, port, num_heads, num_kv_heads):
     os.environ["RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "localhost"
@@ -33,8 +33,8 @@ def test_qwen3_moe_tp(rank, world_size, temp_dir, port):
         moe_intermediate_size=512,
         num_hidden_layers=2,
         max_position_embeddings=1024,
-        num_attention_heads=8,
-        num_key_value_heads=4,
+        num_attention_heads=num_heads,
+        num_key_value_heads=num_kv_heads,
         num_experts=64,
         num_experts_per_tok=8,
         hidden_act="silu",
@@ -93,10 +93,14 @@ def setUp(self):
     def tearDown(self):
         self.temp_dir.cleanup()

-    def test_qwen3_moe_tp(self):
+    def test_qwen3_moe_tp_no_kv_head_replicas(self):
         # Set to 2 as only 2 GPU avaialble in CI
         port = get_available_port()
-        mp.spawn(test_qwen3_moe_tp, nprocs=2, args=(2, self.temp_dir.name, port))
+        mp.spawn(test_qwen3_moe_tp, nprocs=2, args=(2, self.temp_dir.name, port, 8, 4))
+
+    def test_qwen3_moe_tp_kv_head_replicas(self):
+        port = get_available_port()
+        mp.spawn(test_qwen3_moe_tp, nprocs=2, args=(2, self.temp_dir.name, port, 8, 1))


 if __name__ == "__main__":
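The added test case runs with num_heads=8 and num_kv_heads=1 on 2 ranks, which is exactly the tp_size > total_num_kv_heads condition that triggers the replication path. A quick sketch of the values that setup produces, derived from the formulas in the diffs above rather than from the test file itself:

# CI setup from the new test: 2 spawned ranks, a single KV head in the config.
tp_size, total_num_kv_heads = 2, 1
assert tp_size > total_num_kv_heads                    # replication branch is taken
num_kv_head_replicas = tp_size // total_num_kv_heads   # 2 replicas of the single head
for tp_rank in range(tp_size):
    assert tp_rank // num_kv_head_replicas == 0        # both ranks load KV head 0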
