
Commit c9dbc1f: added flash_attn backend
1 parent 2de6996

File tree

4 files changed: +362 -5 lines

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@ accelerate
 pydantic
 sglang[all]==0.5.4
 openai-harmony
+flash-attn>=2.6.3
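flash-attn is a compiled CUDA extension, so it can fail to import on machines without a compatible GPU or toolchain even when the rest of requirements.txt installs cleanly. Below is a minimal, hypothetical sketch of guarding the optional dependency before requesting the new "fa" backend; the helper name flash_attn_available and the fallback policy are assumptions, not part of this commit.

    # Hypothetical guard for the optional flash-attn dependency (not part of this commit).
    def flash_attn_available() -> bool:
        try:
            from flash_attn import flash_attn_func  # noqa: F401
            return True
        except ImportError:
            return False

    # Example: fall back to the existing "sdpa" backend when flash-attn cannot be imported.
    attention_backend = "fa" if flash_attn_available() else "sdpa"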

specforge/core/eagle3.py

Lines changed: 10 additions & 2 deletions
@@ -140,12 +140,16 @@ def forward(
         plosses = []
         vlosses = []
         acces = []
-        if self.attention_backend == "sdpa":
+        if self.attention_backend in ["sdpa", "fa"]:
             cache_hidden = [[], []]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
             past_key_values = DynamicCache()
+        else:
+            raise ValueError(
+                f"Unknown attention backend: {self.attention_backend}"
+            )

         for idx in range(self.length):
             target_p = target_p_padded[:, idx : idx + seq_length, :]
@@ -513,12 +517,16 @@ def forward(
         plosses = []
         vlosses = []
         acces = []
-        if self.attention_backend == "sdpa":
+        if self.attention_backend in ["sdpa", "fa"]:
             cache_hidden = [[], []]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
             cache_hidden = None
             past_key_values = DynamicCache()
+        else:
+            raise ValueError(
+                f"Unknown attention backend: {self.attention_backend}"
+            )

         for idx in range(self.length):
             target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()
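Both forward hunks above share the same cache setup: the "sdpa" and new "fa" backends use a manual two-list cache (cache_hidden[0] for key states, cache_hidden[1] for value states) with no past_key_values, while "flex_attention" uses a transformers DynamicCache instead; any other backend now raises. A hedged sketch of that dispatch factored into a standalone helper, purely for illustration (the helper name init_draft_cache is not part of the commit):

    from typing import List, Optional, Tuple
    from transformers.cache_utils import DynamicCache

    # Illustrative helper; the commit inlines this logic at the top of each forward().
    def init_draft_cache(attention_backend: str) -> Tuple[Optional[List[list]], Optional[DynamicCache]]:
        if attention_backend in ["sdpa", "fa"]:
            # Manual cache: index 0 collects per-step key states, index 1 value states.
            return [[], []], None
        elif attention_backend == "flex_attention":
            return None, DynamicCache()
        raise ValueError(f"Unknown attention backend: {attention_backend}")

    cache_hidden, past_key_values = init_draft_cache("fa")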

specforge/modeling/draft/llama3_eagle.py

Lines changed: 113 additions & 3 deletions
@@ -9,6 +9,7 @@
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
+from flash_attn import flash_attn_func

 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
@@ -90,12 +91,12 @@ def rotate_half(x):


 @torch.compile(dynamic=True)
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
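The new unsqueeze_dim argument exists because the flash-attn path keeps tensors in [batch, seq_len, num_heads, head_dim] layout, while the SDPA path uses [batch, num_heads, seq_len, head_dim], so the cos/sin tables need their broadcast axis in the matching position. A shape-only sketch with made-up sizes:

    import torch

    bsz, n_heads, seq_len, head_dim = 2, 8, 16, 64
    cos = torch.randn(seq_len, head_dim)              # cos after the squeeze above
    position_ids = torch.arange(seq_len).expand(bsz, -1)

    # SDPA layout: q is [bsz, n_heads, seq_len, head_dim] -> broadcast over dim 1 (the default).
    cos_sdpa = cos[position_ids].unsqueeze(1)         # [bsz, 1, seq_len, head_dim]

    # flash-attn layout: q is [bsz, seq_len, n_heads, head_dim] -> broadcast over dim 2,
    # which is why LlamaFlashAttention below calls apply_rotary_pos_emb(..., unsqueeze_dim=2).
    cos_fa = cos[position_ids].unsqueeze(2)           # [bsz, seq_len, 1, head_dim]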
@@ -840,6 +841,113 @@ def forward(
         return attn_output


+class LlamaFlashAttention(LlamaAttention):
+    """
+    Attention layer implemented with flash attention. We keep the parameters consistent with LlamaAttention.
+    The used parameters are:
+    - hidden_states: input hidden states
+    - position_ids: position ids
+    - cache_hidden: manual cache used for storing past key and value states
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_hidden: Optional[List[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        )
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        )
+
+        lck = 0 if cache_hidden is None else len(cache_hidden[0])
+        if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
+            cos, sin = self.rotary_emb(query_states, position_ids + lck)
+            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+            query_states, key_states = apply_multimodal_rotary_pos_emb(
+                query_states,
+                key_states,
+                cos,
+                sin,
+                self.config.rope_scaling["mrope_section"],
+                unsqueeze_dim=2,
+            )
+        else:
+            cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
+            cos, sin = cos.to(query_states.device), sin.to(query_states.device)
+            query_states, key_states = apply_rotary_pos_emb(
+                query_states, key_states, cos, sin, position_ids, unsqueeze_dim=2
+            )
+
+        if cache_hidden is not None:
+            cache_hidden[0] = cache_hidden[0] + [key_states]
+            cache_hidden[1] = cache_hidden[1] + [value_states]
+
+            cache_k = cache_hidden[0]
+            cache_v = cache_hidden[1]
+        else:
+            cache_k = [key_states]
+            cache_v = [value_states]
+
+        k0 = cache_k[0]
+        v0 = cache_v[0]
+
+        attn_output, lse, _ = flash_attn_func(
+            query_states,
+            k0,
+            v0,
+            dropout_p=0.0,
+            softmax_scale=1.0 / math.sqrt(self.head_dim),
+            causal=True,
+            return_attn_probs=True,
+        )
+        lse = lse.transpose(1, 2)
+
+        lck = len(cache_k)
+        if lck > 1:
+            q_shape_expanded = (bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
+            attn_outputs = [attn_output.view(q_shape_expanded)]
+            lses = [lse.view(q_shape_expanded[:-1])]
+
+            for i in range(1, lck):
+                ki = cache_k[i].unsqueeze(-2)
+                qi = query_states.view(q_shape_expanded)
+                vi = cache_v[i].unsqueeze(-2)
+
+                attn_outputs.append(vi)
+                lses.append((qi * ki).sum(-1) / math.sqrt(self.head_dim))
+
+            lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1)
+            attn_output = sum(
+                attn_outputi * torch.exp(lsei - lse).unsqueeze(-1)
+                for attn_outputi, lsei in zip(attn_outputs, lses)
+            )
+            # lse is fp32, downcast attn_output back
+            attn_output = attn_output.to(self.o_proj.weight.dtype)
+
+        attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output
+
+
 class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
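When more than one draft step has been cached (lck > 1), the branch above combines the flash-attention output over the first key/value block with the later per-step keys/values through the returned log-sum-exp: each partial output is reweighted by exp(lse_i - lse_total), which reproduces softmax attention over the union of keys. A self-contained numerical sketch of that merge on plain tensors, checked against full softmax attention (illustrative only, not the commit's code path):

    import torch

    torch.manual_seed(0)
    d, scale = 16, 16 ** -0.5
    q = torch.randn(1, d)                              # a single query vector
    k1, v1 = torch.randn(5, d), torch.randn(5, d)      # first key/value block
    k2, v2 = torch.randn(3, d), torch.randn(3, d)      # second key/value block

    def block_attn(q, k, v):
        # Partial softmax attention over one block, plus its log-sum-exp.
        scores = (q @ k.T) * scale                     # [1, block_len]
        lse = torch.logsumexp(scores, dim=-1)          # [1]
        return torch.softmax(scores, dim=-1) @ v, lse  # [1, d], [1]

    o1, lse1 = block_attn(q, k1, v1)
    o2, lse2 = block_attn(q, k2, v2)

    # Merge: reweight each block's partial output by exp(lse_i - lse_total).
    lse = torch.logsumexp(torch.stack([lse1, lse2], dim=-1), dim=-1)
    merged = o1 * torch.exp(lse1 - lse).unsqueeze(-1) + o2 * torch.exp(lse2 - lse).unsqueeze(-1)

    # Reference: ordinary softmax attention over the concatenated keys/values.
    full, _ = block_attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
    assert torch.allclose(merged, full, atol=1e-5)

In the commit, every cached step after the first contributes exactly one key/value pair per query position, so its partial output is simply vi and its log-sum-exp collapses to the raw scaled dot product (qi * ki).sum(-1) / sqrt(head_dim).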
@@ -913,6 +1021,8 @@ def __init__(self, config, attention_backend: str = "sdpa"):
         elif attention_backend == "flex_attention":
             print_with_rank("Using flex attention on draft model training!")
             self.self_attn = LlamaFlexAttention(config=config)
+        elif attention_backend == "fa":
+            self.self_attn = LlamaFlashAttention(config=config)
         else:
             raise ValueError(f"Unknown attention backend {attention_backend}")
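Selecting the "fa" backend routes each draft layer's attention through flash_attn_func. A hedged standalone sketch of that call's shape conventions, mirroring the arguments used in LlamaFlashAttention (assumes a CUDA device and fp16/bf16 tensors; sizes are made up):

    import math
    import torch
    from flash_attn import flash_attn_func

    bsz, seq_len, n_heads, head_dim = 2, 16, 8, 64
    q = torch.randn(bsz, seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    out, lse, _ = flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=1.0 / math.sqrt(head_dim),
        causal=True,
        return_attn_probs=True,
    )
    print(out.shape)                  # [bsz, seq_len, n_heads, head_dim]
    print(lse.transpose(1, 2).shape)  # [bsz, seq_len, n_heads], matching the transpose in the new class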
