|
9 | 9 | from transformers.activations import ACT2FN |
10 | 10 | from transformers.cache_utils import Cache |
11 | 11 | from transformers.models.llama.configuration_llama import LlamaConfig |
| 12 | +from flash_attn import flash_attn_func |
12 | 13 |
|
13 | 14 | from specforge.modeling.draft.flex_attention import ( |
14 | 15 | compile_friendly_create_block_mask, |
@@ -90,12 +91,12 @@ def rotate_half(x): |
90 | 91 |
|
91 | 92 |
|
92 | 93 | @torch.compile(dynamic=True) |
93 | | -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): |
| 94 | +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): |
94 | 95 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. |
95 | 96 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] |
96 | 97 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] |
97 | | - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] |
98 | | - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] |
|    | 98 | +    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # e.g. [bs, 1, seq_len, dim] for unsqueeze_dim=1
|    | 99 | +    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # e.g. [bs, 1, seq_len, dim] for unsqueeze_dim=1
99 | 100 | q_embed = (q * cos) + (rotate_half(q) * sin) |
100 | 101 | k_embed = (k * cos) + (rotate_half(k) * sin) |
101 | 102 | return q_embed, k_embed |
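
The change above threads an `unsqueeze_dim` argument through so the same helper serves both the [bsz, num_heads, seq_len, head_dim] layout (the default `unsqueeze_dim=1`) and the [bsz, seq_len, num_heads, head_dim] layout that flash attention expects (`unsqueeze_dim=2`, as passed by the new attention class below). A minimal standalone sketch of the broadcasting, not taken from this commit; the toy sizes and frequency setup are illustrative only:

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, seq, heads, dim = 2, 5, 4, 8
q = torch.randn(bsz, seq, heads, dim)          # [bsz, seq, heads, head_dim] layout
k = torch.randn(bsz, seq, heads, dim)

inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)   # [seq, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                    # [seq, dim]
cos, sin = emb.cos(), emb.sin()

position_ids = torch.arange(seq).unsqueeze(0).expand(bsz, -1)
# unsqueeze_dim=2 inserts the head axis so cos/sin broadcast over heads
cos_b = cos[position_ids].unsqueeze(2)                     # [bsz, seq, 1, dim]
sin_b = sin[position_ids].unsqueeze(2)

q_embed = q * cos_b + rotate_half(q) * sin_b
k_embed = k * cos_b + rotate_half(k) * sin_b
print(q_embed.shape, k_embed.shape)                        # both [2, 5, 4, 8]
```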
@@ -840,6 +841,113 @@ def forward( |
840 | 841 | return attn_output |
841 | 842 |
|
842 | 843 |
|
| 844 | +class LlamaFlashAttention(LlamaAttention): |
| 845 | + """ |
|    | 846 | +    Attention layer implemented with flash attention. The signature is kept consistent with LlamaAttention;
|    | 847 | +    the parameters actually used are:
| 848 | + - hidden_states: input hidden states |
| 849 | + - position_ids: position ids |
| 850 | + - cache_hidden: manual cache used for storing past key and value states |
| 851 | + """ |
| 852 | + |
| 853 | + def forward( |
| 854 | + self, |
| 855 | + hidden_states: torch.Tensor, |
| 856 | + cache_hidden: Optional[List[torch.Tensor]] = None, |
| 857 | + attention_mask: Optional[torch.Tensor] = None, |
| 858 | + position_ids: Optional[torch.LongTensor] = None, |
| 859 | + past_key_values: Optional[Cache] = None, |
| 860 | + output_attentions: bool = False, |
| 861 | + use_cache: bool = False, |
| 862 | + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 863 | + bsz, q_len, _ = hidden_states.size() |
| 864 | + |
| 865 | + query_states = self.q_proj(hidden_states) |
| 866 | + key_states = self.k_proj(hidden_states) |
| 867 | + value_states = self.v_proj(hidden_states) |
| 868 | + |
| 869 | + query_states = query_states.view( |
| 870 | + bsz, q_len, self.num_heads, self.head_dim |
| 871 | + ) |
| 872 | + key_states = key_states.view( |
| 873 | + bsz, q_len, self.num_key_value_heads, self.head_dim |
| 874 | + ) |
| 875 | + value_states = value_states.view( |
| 876 | + bsz, q_len, self.num_key_value_heads, self.head_dim |
| 877 | + ) |
| 878 | + |
| 879 | + lck = 0 if cache_hidden is None else len(cache_hidden[0]) |
| 880 | + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): |
| 881 | + cos, sin = self.rotary_emb(query_states, position_ids + lck) |
| 882 | + cos, sin = cos.to(query_states.device), sin.to(query_states.device) |
| 883 | + query_states, key_states = apply_multimodal_rotary_pos_emb( |
| 884 | + query_states, |
| 885 | + key_states, |
| 886 | + cos, |
| 887 | + sin, |
| 888 | + self.config.rope_scaling["mrope_section"], |
| 889 | + unsqueeze_dim=2, |
| 890 | + ) |
| 891 | + else: |
| 892 | + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) |
| 893 | + cos, sin = cos.to(query_states.device), sin.to(query_states.device) |
| 894 | + query_states, key_states = apply_rotary_pos_emb( |
| 895 | + query_states, key_states, cos, sin, position_ids, unsqueeze_dim=2 |
| 896 | + ) |
| 897 | + |
| 898 | + if cache_hidden is not None: |
| 899 | + cache_hidden[0] = cache_hidden[0] + [key_states] |
| 900 | + cache_hidden[1] = cache_hidden[1] + [value_states] |
| 901 | + |
| 902 | + cache_k = cache_hidden[0] |
| 903 | + cache_v = cache_hidden[1] |
| 904 | + else: |
| 905 | + cache_k = [key_states] |
| 906 | + cache_v = [value_states] |
| 907 | + |
| 908 | + k0 = cache_k[0] |
| 909 | + v0 = cache_v[0] |
| 910 | + |
| 911 | + attn_output, lse, _ = flash_attn_func( |
| 912 | + query_states, |
| 913 | + k0, |
| 914 | + v0, |
| 915 | + dropout_p=0.0, |
| 916 | + softmax_scale=1.0 / math.sqrt(self.head_dim), |
| 917 | + causal=True, |
| 918 | + return_attn_probs=True, |
| 919 | + ) |
|    | 920 | +        lse = lse.transpose(1, 2)  # [bsz, num_heads, q_len] -> [bsz, q_len, num_heads]
| 921 | + |
| 922 | + lck = len(cache_k) |
| 923 | + if lck > 1: |
| 924 | + q_shape_expanded = (bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim) |
| 925 | + attn_outputs = [attn_output.view(q_shape_expanded)] |
| 926 | + lses = [lse.view(q_shape_expanded[:-1])] |
| 927 | + |
| 928 | + for i in range(1, lck): |
|    | 929 | +                ki = cache_k[i].unsqueeze(-2)  # [bsz, q_len, kv_heads, 1, head_dim]
|    | 930 | +                qi = query_states.view(q_shape_expanded)
|    | 931 | +                vi = cache_v[i].unsqueeze(-2)
|    | 932 | +
|    | 933 | +                attn_outputs.append(vi)  # single-key chunk: its softmax output is just v_i
|    | 934 | +                lses.append((qi * ki).sum(-1) / math.sqrt(self.head_dim))  # lse of one score is the score itself
| 935 | + |
| 936 | + lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1) |
| 937 | + attn_output = sum( |
| 938 | + attn_outputi * torch.exp(lsei - lse).unsqueeze(-1) |
| 939 | + for attn_outputi, lsei in zip(attn_outputs, lses) |
| 940 | + ) |
| 941 | + # lse is fp32, downcast attn_output back |
| 942 | + attn_output = attn_output.to(self.o_proj.weight.dtype) |
| 943 | + |
| 944 | + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) |
| 945 | + |
| 946 | + attn_output = self.o_proj(attn_output) |
| 947 | + |
| 948 | + return attn_output |
| 949 | + |
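
The merge loop in the class above leans on grouped-query attention: `num_heads` query heads share `num_key_value_heads` KV heads, so the query is viewed as (bsz, q_len, kv_heads, kv_groups, head_dim) and each cached key, unsqueezed to a singleton group axis, broadcasts across the query heads in its group. A minimal standalone sketch of that broadcast with toy sizes, not taken from this commit:

```python
import math
import torch

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 4, 8, 2, 16
groups = num_heads // num_kv_heads                       # 4 query heads per KV head

q = torch.randn(bsz, q_len, num_heads, head_dim)
k_cached = torch.randn(bsz, q_len, num_kv_heads, head_dim)  # one extra key per position

qi = q.view(bsz, q_len, num_kv_heads, groups, head_dim)
ki = k_cached.unsqueeze(-2)                              # [bsz, q_len, kv_heads, 1, head_dim]
scores = (qi * ki).sum(-1) / math.sqrt(head_dim)         # [bsz, q_len, kv_heads, groups]
print(scores.shape)                                      # torch.Size([2, 4, 2, 4])
```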
| 950 | + |
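
The merge itself is the standard log-sum-exp trick: `flash_attn_func(..., return_attn_probs=True)` returns the per-(head, query) softmax log-sum-exp alongside the output (laid out as [bsz, num_heads, q_len], which is why the code transposes it), and attention outputs over disjoint key chunks can be recombined exactly by reweighting each chunk with `exp(lse_i - lse_total)`. A self-contained sketch of that identity on a single head, not taken from this commit; for the single-key chunks in the loop above, the chunk output degenerates to `v_i` and its lse to the raw scaled score:

```python
import math
import torch

torch.manual_seed(0)
q = torch.randn(1, 6, 16)                   # [batch, q_len, head_dim], single head
k = torch.randn(1, 9, 16)
v = torch.randn(1, 9, 16)
scale = 1.0 / math.sqrt(q.shape[-1])

# Reference: softmax attention over all keys at once.
scores = torch.einsum("bqd,bkd->bqk", q, k) * scale
ref = torch.softmax(scores, dim=-1) @ v

# Chunked: attend to each key/value chunk separately, keeping (output, lse) pairs.
outs, lses = [], []
for ks, vs in zip(k.split(3, dim=1), v.split(3, dim=1)):
    s = torch.einsum("bqd,bkd->bqk", q, ks) * scale
    lses.append(torch.logsumexp(s, dim=-1))                # [batch, q_len]
    outs.append(torch.softmax(s, dim=-1) @ vs)             # [batch, q_len, head_dim]

# Merge: the global normalizer is the logsumexp of the per-chunk lse values.
lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1)
merged = sum(o * torch.exp(l - lse).unsqueeze(-1) for o, l in zip(outs, lses))

print(torch.allclose(merged, ref, atol=1e-6))              # True
```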
843 | 951 | class LlamaMLP(nn.Module): |
844 | 952 | def __init__(self, config): |
845 | 953 | super().__init__() |
@@ -913,6 +1021,8 @@ def __init__(self, config, attention_backend: str = "sdpa"): |
913 | 1021 | elif attention_backend == "flex_attention": |
914 | 1022 | print_with_rank("Using flex attention on draft model training!") |
915 | 1023 | self.self_attn = LlamaFlexAttention(config=config) |
| 1024 | + elif attention_backend == "fa": |
| 1025 | + self.self_attn = LlamaFlashAttention(config=config) |
916 | 1026 | else: |
917 | 1027 | raise ValueError(f"Unknown attention backend {attention_backend}") |
918 | 1028 |
|
|