
Commit d660579
fix pre commit
1 parent c9dbc1f

2 files changed (+11, -11 lines)

specforge/core/eagle3.py

Lines changed: 2 additions & 6 deletions

@@ -147,9 +147,7 @@ def forward(
             cache_hidden = None
             past_key_values = DynamicCache()
         else:
-            raise ValueError(
-                f"Unknown attention backend: {self.attention_backend}"
-            )
+            raise ValueError(f"Unknown attention backend: {self.attention_backend}")
 
         for idx in range(self.length):
             target_p = target_p_padded[:, idx : idx + seq_length, :]

@@ -524,9 +522,7 @@ def forward(
             cache_hidden = None
             past_key_values = DynamicCache()
         else:
-            raise ValueError(
-                f"Unknown attention backend: {self.attention_backend}"
-            )
+            raise ValueError(f"Unknown attention backend: {self.attention_backend}")
 
         for idx in range(self.length):
             target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()
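
For context, the collapsed single-line raise closes an attention-backend dispatch. A minimal sketch of that pattern, assuming backend names such as "flex_attention" and "flash_attention" (the actual branch conditions are not shown in this hunk):

# Hypothetical sketch of the dispatch this `else` belongs to; the backend
# names and cache setup are assumptions, not code lifted from eagle3.py.
from transformers.cache_utils import DynamicCache

def init_attention_state(attention_backend: str):
    if attention_backend == "flex_attention":
        cache_hidden = []            # backend keeps its own hidden-state cache
        past_key_values = None
    elif attention_backend == "flash_attention":
        cache_hidden = None
        past_key_values = DynamicCache()
    else:
        raise ValueError(f"Unknown attention backend: {attention_backend}")
    return cache_hidden, past_key_values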

specforge/modeling/draft/llama3_eagle.py

Lines changed: 9 additions & 5 deletions

@@ -4,12 +4,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from transformers import GenerationMixin, LlamaConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from flash_attn import flash_attn_func
 
 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
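
The import is only moved up into the third-party block. For readers new to the dependency, a minimal usage sketch of flash_attn_func with the (bsz, seq_len, num_heads, head_dim) layout the library expects; the causal flag here is an assumption, not taken from this file:

# Standalone flash_attn_func sketch (requires a CUDA device and fp16/bf16 tensors).
import torch
from flash_attn import flash_attn_func

bsz, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(bsz, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # -> (bsz, seq_len, num_heads, head_dim)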

@@ -866,9 +866,7 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        )
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(
             bsz, q_len, self.num_key_value_heads, self.head_dim
         )
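
The one-line view keeps the (bsz, q_len, num_heads, head_dim) layout that flash_attn_func consumes directly, whereas torch's F.scaled_dot_product_attention wants the head dimension ahead of the sequence dimension. A generic illustration of the two layouts, not code from llama3_eagle.py:

# Stand-in tensors; shows the layout difference only.
import torch

bsz, q_len, num_heads, head_dim = 2, 16, 8, 64
hidden = torch.randn(bsz, q_len, num_heads * head_dim)

q_flash = hidden.view(bsz, q_len, num_heads, head_dim)  # layout used by flash_attn_func
q_sdpa = q_flash.transpose(1, 2)                        # (bsz, num_heads, q_len, head_dim) for SDPA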

@@ -921,7 +919,13 @@ def forward(
 
         lck = len(cache_k)
         if lck > 1:
-            q_shape_expanded = (bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
+            q_shape_expanded = (
+                bsz,
+                q_len,
+                self.num_key_value_heads,
+                self.num_key_value_groups,
+                self.head_dim,
+            )
             attn_outputs = [attn_output.view(q_shape_expanded)]
             lses = [lse.view(q_shape_expanded[:-1])]
 
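
The reformatted tuple spells out that num_heads factors into num_key_value_heads × num_key_value_groups (the grouped-query layout), and the lse view drops the trailing head_dim. A hedged sketch of how per-chunk outputs are commonly merged via their log-sum-exp values, assuming that is what the surrounding code (not shown in this hunk) does with attn_outputs and lses:

# Generic log-sum-exp merge of chunked attention outputs; an assumption about
# the downstream use of attn_outputs / lses, not code from llama3_eagle.py.
import torch

def merge_attention_chunks(attn_outputs, lses):
    # attn_outputs[i]: (..., head_dim); lses[i]: matching shape without head_dim
    lse = torch.stack(lses, dim=0)                      # (num_chunks, ...)
    weights = torch.softmax(lse, dim=0).unsqueeze(-1)   # exp(lse_i - logsumexp(lse))
    outputs = torch.stack(attn_outputs, dim=0)          # (num_chunks, ..., head_dim)
    return (weights * outputs).sum(dim=0)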
