Commit 47ac5c9

support qwen3_vl and qwen3_vl_moe position_ids and rope_deltas
1 parent 501c16e commit 47ac5c9

1 file changed: +87 -16 lines changed


specforge/core/eagle3.py

Lines changed: 87 additions & 16 deletions
@@ -608,26 +608,26 @@ def forward(
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length

-        if position_ids is None:
-            attention_mask_tensor = (
-                attention_mask
-                if not isinstance(attention_mask, dict)
-                else attention_mask["full_attention"]
+        base_attention_mask = (
+            attention_mask
+            if not isinstance(attention_mask, dict)
+            else attention_mask["full_attention"]
+        )
+        # Cache the raw mask so that SDPA and RoPE refresh both see the same window-aligned view.
+        if base_attention_mask is not None and base_attention_mask.ndim == 4:
+            base_attention_mask = torch.diagonal(
+                base_attention_mask[:, 0], dim1=1, dim2=2
             )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(
-                    attention_mask_tensor[:, 0], dim1=1, dim2=2
-                )
-                attention_mask_tensor = (
-                    attention_mask_tensor
-                    / torch.finfo(attention_mask_tensor.dtype).min
-                )
-                attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+            base_attention_mask = (
+                base_attention_mask / torch.finfo(base_attention_mask.dtype).min
+            )
+            base_attention_mask = (1.0 - base_attention_mask).int()

+        if position_ids is None:
             get_rope_kwargs = {
                 "input_ids": input_ids,
                 "image_grid_thw": image_grid_thw,
-                "attention_mask": attention_mask_tensor,
+                "attention_mask": base_attention_mask,
             }
             if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}:
                 get_rope_kwargs["video_grid_thw"] = video_grid_thw
@@ -639,8 +639,18 @@ def forward(
             )
             if rope_deltas is not None:
                 self.rope_deltas = rope_deltas
+            full_attention_mask = (
+                base_attention_mask.clone()
+                if base_attention_mask is not None
+                else None
+            )
         else:
             position_ids = position_ids
+            full_attention_mask = (
+                base_attention_mask.clone()
+                if base_attention_mask is not None
+                else None
+            )

         # Step 4: handle attention mask
         if attention_mask is None:
@@ -651,7 +661,7 @@ def forward(
             )
         if self.attention_backend == "sdpa":
             attention_mask = self.draft_model.prepare_decoder_attention_mask(
-                attention_mask=attention_mask,
+                attention_mask=full_attention_mask,
                 hidden_states=hidden_states,
                 batch_size=batch_size,
                 seq_length=seq_length,
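Hunks two and three keep an un-expanded 2D copy (full_attention_mask) alive and hand that copy to prepare_decoder_attention_mask instead of whatever attention_mask has become. The sketch below shows the kind of 2D-to-4D expansion such an SDPA helper typically performs; the draft model's actual prepare_decoder_attention_mask may differ in detail, and this is only meant to show why the 2D view has to be cached separately.

```python
import torch

def expand_to_additive_causal(padding_mask_2d: torch.Tensor,
                              dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: (batch, seq) 0/1 padding mask -> (batch, 1, seq, seq)
    additive causal mask with 0.0 at allowed positions and finfo.min elsewhere."""
    seq = padding_mask_2d.shape[-1]
    min_val = torch.finfo(dtype).min
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool,
                                   device=padding_mask_2d.device))
    allowed = causal[None, None, :, :] & padding_mask_2d[:, None, None, :].bool()
    return (~allowed).to(dtype) * min_val

mask_2d = torch.tensor([[1, 1, 1, 0]])
mask_4d = expand_to_additive_causal(mask_2d)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])
```

Once the mask is in this additive 4D form, the padding pattern is only recoverable through the diagonal trick shown above, which is why the diff clones base_attention_mask into full_attention_mask before SDPA preparation rather than trying to reuse the expanded attention_mask later.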
@@ -715,6 +725,67 @@ def forward(
                 input_ids = padding(input_ids, left=False)
                 position_mask = padding(position_mask, left=False)
                 loss_mask = padding(loss_mask, left=False)
+                # Shrink the cached mask so SDPA keeps the same view after padding.
+                # Roll the cached SDPA mask so it matches the new left-shifted window.
+                if full_attention_mask is not None:
+                    full_attention_mask = padding(full_attention_mask, left=False)
+
+                if self.attention_backend == "sdpa":
+                    attention_mask = self.draft_model.prepare_decoder_attention_mask(
+                        attention_mask=full_attention_mask,
+                        hidden_states=hidden_states,
+                        batch_size=batch_size,
+                        seq_length=seq_length,
+                        past_key_values_length=past_key_values_length,
+                    )
+                elif (
+                    attention_mask is not None
+                    and self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}
+                ):
+                    # qwen3 path carries the un-expanded 2D causal mask directly.
+                    attention_mask = padding(attention_mask, left=False)
+
+                next_attention_tensor = (
+                    full_attention_mask
+                    if full_attention_mask is not None
+                    else (
+                        attention_mask
+                        if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}
+                        else None
+                    )
+                )
+                if (
+                    next_attention_tensor is not None
+                    and self.target_model_type not in {"qwen3_vl", "qwen3_vl_moe"}
+                    and next_attention_tensor.ndim == 4
+                ):
+                    # qwen2.5 still produces inverted 4D masks; collapse and flip them before RoPE.
+                    next_attention_tensor = torch.diagonal(
+                        next_attention_tensor[:, 0], dim1=1, dim2=2
+                    )
+                    if next_attention_tensor.dtype.is_floating_point:
+                        next_attention_tensor = (
+                            next_attention_tensor
+                            / torch.finfo(next_attention_tensor.dtype).min
+                        )
+                        next_attention_tensor = (1.0 - next_attention_tensor).int()
+
+                # qwen3_vl expects video grid kwargs rather than second_per_grid_ts, qwen2.5_vl still needs both.
+                rope_kwargs = {
+                    "input_ids": input_ids,
+                    "image_grid_thw": image_grid_thw,
+                    "attention_mask": next_attention_tensor,
+                }
+                if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}:
+                    rope_kwargs["video_grid_thw"] = video_grid_thw
+                else:
+                    rope_kwargs["video_grid_thw"] = video_grid_thw
+                    rope_kwargs["second_per_grid_ts"] = second_per_grid_ts
+                position_ids, rope_deltas = self.target_model.model.get_rope_index(
+                    **rope_kwargs
+                )
+                if rope_deltas is not None:
+                    self.rope_deltas = rope_deltas
             # Flex attention mask shirnking is handled inside attention module
         return plosses, vlosses, acces

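The last hunk re-derives position_ids and rope_deltas after the padding step and routes different keyword arguments to get_rope_index depending on the target: qwen3_vl / qwen3_vl_moe take only the video grid, while the remaining (qwen2.5_vl-style) path also passes second_per_grid_ts. A hedged sketch of that routing with a hypothetical build_rope_kwargs helper (not part of the repository):

```python
from typing import Any, Dict, Optional
import torch

QWEN3_VL_TYPES = {"qwen3_vl", "qwen3_vl_moe"}

def build_rope_kwargs(target_model_type: str,
                      input_ids: torch.Tensor,
                      attention_mask: Optional[torch.Tensor],
                      image_grid_thw: Optional[torch.Tensor] = None,
                      video_grid_thw: Optional[torch.Tensor] = None,
                      second_per_grid_ts: Optional[torch.Tensor] = None) -> Dict[str, Any]:
    """Hypothetical helper mirroring the branching in the diff."""
    kwargs: Dict[str, Any] = {
        "input_ids": input_ids,
        "image_grid_thw": image_grid_thw,
        "attention_mask": attention_mask,
        "video_grid_thw": video_grid_thw,
    }
    if target_model_type not in QWEN3_VL_TYPES:
        # The qwen2.5_vl-style path additionally carries per-grid timestamps.
        kwargs["second_per_grid_ts"] = second_per_grid_ts
    return kwargs

# Text-only toy batch for a qwen3_vl target.
ids = torch.zeros((1, 8), dtype=torch.long)
mask = torch.ones((1, 8), dtype=torch.int32)
print(sorted(build_rope_kwargs("qwen3_vl", ids, mask).keys()))
```

The resulting dictionary is unpacked into self.target_model.model.get_rope_index(**rope_kwargs), and any returned rope_deltas are cached on self.rope_deltas, matching the final lines of the hunk.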