Skip to content

Commit 81966ba

Browse files
committed
fix draft rotary embedding when given 2d position ids
1 parent b44dc48 commit 81966ba

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

specforge/modeling/draft/llama3_eagle.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ def _apply_interleaved_mrope(self, freqs: torch.Tensor) -> torch.Tensor:
317317

318318
def forward(self, x, position_ids):
319319
# In contrast to other models, Qwen-VL variants have different position ids for the grids
320-
# So we expand the inv_freq to shape (3, ...)
320+
# So we expand the position ids/inv_freq to shape (3, ...)
321+
if position_ids.ndim == 2:
322+
position_ids = position_ids[None, ...].expand(3, -1, -1)
321323
inv_freq_expanded = (
322324
self.inv_freq[None, None, :, None]
323325
.float()

0 commit comments

Comments
 (0)