@@ -608,26 +608,26 @@ def forward(
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
-        if position_ids is None:
-            attention_mask_tensor = (
-                attention_mask
-                if not isinstance(attention_mask, dict)
-                else attention_mask["full_attention"]
+        base_attention_mask = (
+            attention_mask
+            if not isinstance(attention_mask, dict)
+            else attention_mask["full_attention"]
+        )
+        # Cache the raw mask so that SDPA and the RoPE refresh both see the same window-aligned view.
+        if base_attention_mask is not None and base_attention_mask.ndim == 4:
+            base_attention_mask = torch.diagonal(
+                base_attention_mask[:, 0], dim1=1, dim2=2
             )
-            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
-                attention_mask_tensor = torch.diagonal(
-                    attention_mask_tensor[:, 0], dim1=1, dim2=2
-                )
-                attention_mask_tensor = (
-                    attention_mask_tensor
-                    / torch.finfo(attention_mask_tensor.dtype).min
-                )
-                attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+            base_attention_mask = (
+                base_attention_mask / torch.finfo(base_attention_mask.dtype).min
+            )
+            base_attention_mask = (1.0 - base_attention_mask).int()
 
+        if position_ids is None:
             get_rope_kwargs = {
                 "input_ids": input_ids,
                 "image_grid_thw": image_grid_thw,
-                "attention_mask": attention_mask_tensor,
+                "attention_mask": base_attention_mask,
             }
             if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}:
                 get_rope_kwargs["video_grid_thw"] = video_grid_thw
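The hunk above hoists the 4D-to-2D mask collapse out of the `position_ids is None` branch so later steps can reuse the cached `base_attention_mask`. A minimal standalone sketch of that collapse, assuming the usual HF-style additive mask convention; the function name and toy data below are illustrative, not from this PR:

```python
import torch

def collapse_4d_mask_to_2d(mask_4d: torch.Tensor) -> torch.Tensor:
    """Recover a 0/1 keep mask from an additive 4D causal mask.

    Assumes the HF convention: shape (batch, 1, tgt_len, src_len) with 0.0
    where attention is allowed and dtype-min where it is blocked. The diagonal
    of each (tgt, src) slice says whether token i may attend to itself, which
    is exactly the padding information an mRoPE index helper needs, and it
    sidesteps the causal upper triangle entirely.
    """
    diag = torch.diagonal(mask_4d[:, 0], dim1=1, dim2=2)  # (batch, seq)
    blocked = diag / torch.finfo(diag.dtype).min          # 0.0 -> 0, dtype-min -> 1
    return (1.0 - blocked).int()                          # 1 = real token, 0 = padding


# Toy check: batch of 1, length 4, last position padded out.
seq = 4
mask = torch.zeros(1, 1, seq, seq)
mask[:, :, :, -1] = torch.finfo(mask.dtype).min  # block the padded key column
mask[:, :, -1, :] = torch.finfo(mask.dtype).min  # and the padded query row
print(collapse_4d_mask_to_2d(mask))              # tensor([[1, 1, 1, 0]], dtype=torch.int32)
```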
@@ -639,8 +639,18 @@ def forward(
             )
             if rope_deltas is not None:
                 self.rope_deltas = rope_deltas
+            full_attention_mask = (
+                base_attention_mask.clone()
+                if base_attention_mask is not None
+                else None
+            )
         else:
             position_ids = position_ids
+            full_attention_mask = (
+                base_attention_mask.clone()
+                if base_attention_mask is not None
+                else None
+            )
 
         # Step 4: handle attention mask
         if attention_mask is None:
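The `.clone()` in both branches presumably keeps the cached 2D mask and the working copy from sharing storage, since the working copy is re-padded every step further down. A quick illustration of the aliasing it avoids, in plain PyTorch with nothing repo-specific:

```python
import torch

base = torch.ones(1, 6, dtype=torch.int32)

alias = base                 # same storage: in-place edits leak back into the cache
alias[:, -1] = 0
print(base)                  # tensor([[1, 1, 1, 1, 1, 0]], dtype=torch.int32)

base = torch.ones(1, 6, dtype=torch.int32)
copy = base.clone()          # independent storage: the cached mask stays pristine
copy[:, -1] = 0
print(base)                  # tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int32)
```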
@@ -651,7 +661,7 @@ def forward(
             )
         if self.attention_backend == "sdpa":
             attention_mask = self.draft_model.prepare_decoder_attention_mask(
-                attention_mask=attention_mask,
+                attention_mask=full_attention_mask,
                 hidden_states=hidden_states,
                 batch_size=batch_size,
                 seq_length=seq_length,
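The SDPA branch now feeds `prepare_decoder_attention_mask` the cached 2D keep mask rather than a possibly already-expanded tensor. The repo's helper is not shown in this diff; the sketch below assumes a typical HF-style expansion (0 where attention is allowed, dtype-min where blocked) purely to show the shape contract the call relies on, so every name here is illustrative:

```python
import torch

def expand_to_causal_4d(keep_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Illustrative stand-in for a prepare_decoder_attention_mask-style helper.

    Turns a (batch, seq) 0/1 keep mask into the additive (batch, 1, seq, seq)
    causal mask SDPA consumes. The real helper in the repo may differ; this
    only shows why handing it an already-expanded 4D mask would be wrong.
    """
    bsz, seq = keep_mask.shape
    min_val = torch.finfo(dtype).min
    causal = torch.triu(torch.full((seq, seq), min_val, dtype=dtype), diagonal=1)
    pad = (1.0 - keep_mask.to(dtype))[:, None, None, :] * min_val  # block padded keys
    return (causal[None, None] + pad).clamp(min=min_val)           # (bsz, 1, seq, seq)


keep = torch.tensor([[1, 1, 1, 0]])
print(expand_to_causal_4d(keep).shape)  # torch.Size([1, 1, 4, 4])
```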
@@ -715,6 +725,67 @@ def forward(
             input_ids = padding(input_ids, left=False)
             position_mask = padding(position_mask, left=False)
             loss_mask = padding(loss_mask, left=False)
+            # Shrink the cached mask so SDPA keeps the same view after padding.
+            # The RoPE refresh below reuses this window-aligned copy as well.
+            if full_attention_mask is not None:
+                full_attention_mask = padding(full_attention_mask, left=False)
+
+            if self.attention_backend == "sdpa":
+                attention_mask = self.draft_model.prepare_decoder_attention_mask(
+                    attention_mask=full_attention_mask,
+                    hidden_states=hidden_states,
+                    batch_size=batch_size,
+                    seq_length=seq_length,
+                    past_key_values_length=past_key_values_length,
+                )
+            elif (
+                attention_mask is not None
+                and self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}
+            ):
+                # qwen3 path carries the un-expanded 2D causal mask directly.
+                attention_mask = padding(attention_mask, left=False)
+
+            next_attention_tensor = (
+                full_attention_mask
+                if full_attention_mask is not None
+                else (
+                    attention_mask
+                    if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}
+                    else None
+                )
+            )
+            if (
+                next_attention_tensor is not None
+                and self.target_model_type not in {"qwen3_vl", "qwen3_vl_moe"}
+                and next_attention_tensor.ndim == 4
+            ):
+                # qwen2.5 still produces inverted 4D masks; collapse and flip them before RoPE.
+                next_attention_tensor = torch.diagonal(
+                    next_attention_tensor[:, 0], dim1=1, dim2=2
+                )
+                if next_attention_tensor.dtype.is_floating_point:
+                    next_attention_tensor = (
+                        next_attention_tensor
+                        / torch.finfo(next_attention_tensor.dtype).min
+                    )
+                    next_attention_tensor = (1.0 - next_attention_tensor).int()
+
+            # qwen3_vl expects video grid kwargs rather than second_per_grid_ts; qwen2.5_vl still needs both.
+            rope_kwargs = {
+                "input_ids": input_ids,
+                "image_grid_thw": image_grid_thw,
+                "attention_mask": next_attention_tensor,
+            }
+            if self.target_model_type in {"qwen3_vl", "qwen3_vl_moe"}:
+                rope_kwargs["video_grid_thw"] = video_grid_thw
+            else:
+                rope_kwargs["video_grid_thw"] = video_grid_thw
+                rope_kwargs["second_per_grid_ts"] = second_per_grid_ts
+            position_ids, rope_deltas = self.target_model.model.get_rope_index(
+                **rope_kwargs
+            )
+            if rope_deltas is not None:
+                self.rope_deltas = rope_deltas
         # Flex attention mask shrinking is handled inside attention module
         return plosses, vlosses, acces
 
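Read end to end, the last hunk re-pads the cached mask, rebuilds the SDPA mask, and then recomputes the mRoPE position ids after every padding step. A condensed sketch of that refresh, assuming only the `get_rope_index` call signatures the diff itself relies on; every other name below is illustrative:

```python
import torch

QWEN3_TYPES = {"qwen3_vl", "qwen3_vl_moe"}


def refresh_mrope_position_ids(model, model_type, input_ids, mask,
                               image_grid_thw, video_grid_thw,
                               second_per_grid_ts=None):
    """Condensed sketch of the per-step RoPE refresh added in the last hunk.

    `mask` is whatever the loop is carrying: a 0/1 2D keep mask, or an
    inverted additive 4D mask on the qwen2.5 path. Only the get_rope_index
    call mirrors the diff; this helper and its arguments are illustrative.
    """
    if mask is not None and mask.ndim == 4:
        # Collapse the additive 4D mask back to the 0/1 form get_rope_index expects.
        mask = torch.diagonal(mask[:, 0], dim1=1, dim2=2)
        if mask.dtype.is_floating_point:
            mask = (1.0 - mask / torch.finfo(mask.dtype).min).int()

    rope_kwargs = {
        "input_ids": input_ids,
        "image_grid_thw": image_grid_thw,
        "video_grid_thw": video_grid_thw,
        "attention_mask": mask,
    }
    if model_type not in QWEN3_TYPES:
        # Per the diff, only the qwen2.5-VL signature also takes per-grid timestamps.
        rope_kwargs["second_per_grid_ts"] = second_per_grid_ts

    return model.get_rope_index(**rope_kwargs)
```

In the diff, the returned pair is written back into `position_ids` and `self.rope_deltas` before the next padded step.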