 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from transformers import GenerationMixin, LlamaConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
-from flash_attn import flash_attn_func
 
 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
@@ -866,9 +866,7 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        )
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(
             bsz, q_len, self.num_key_value_heads, self.head_dim
         )
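The reshape in this hunk puts the projections into the (batch, seq_len, n_heads, head_dim) layout that `flash_attn_func` consumes, with the query keeping all `num_heads` attention heads while key/value keep only `num_key_value_heads` (GQA). A minimal sketch of that call, with made-up sizes rather than anything taken from this diff:

```python
# Illustrative only: demonstrates the tensor layout flash_attn_func expects after the
# .view() calls above. Head counts here are invented, not read from the model config.
import torch
from flash_attn import flash_attn_func

bsz, q_len, head_dim = 2, 16, 64
num_heads, num_key_value_heads = 32, 8  # GQA: 4 query heads share each KV head

q = torch.randn(bsz, q_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, q_len, num_key_value_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, q_len, num_key_value_heads, head_dim, device="cuda", dtype=torch.float16)

# Output comes back in the same (batch, seq, heads, head_dim) layout as the query.
out = flash_attn_func(q, k, v, causal=True)
```

flash-attn 2.x handles the GQA head mismatch internally as long as `num_heads` is a multiple of `num_key_value_heads`, so no explicit `repeat_kv` is needed before the call.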
@@ -921,7 +919,13 @@ def forward(
 
         lck = len(cache_k)
         if lck > 1:
-            q_shape_expanded = (bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
+            q_shape_expanded = (
+                bsz,
+                q_len,
+                self.num_key_value_heads,
+                self.num_key_value_groups,
+                self.head_dim,
+            )
             attn_outputs = [attn_output.view(q_shape_expanded)]
             lses = [lse.view(q_shape_expanded[:-1])]
 
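The `attn_outputs` / `lses` lists in this hunk collect, for each cached KV segment, the segment-local attention output together with its log-sum-exp. The combine step itself is outside the hunk; as an assumption about what the surrounding code does (not a copy of it), the standard way to reduce such lists is to reweight each segment output by the softmax of the LSEs:

```python
# Hedged sketch: merging attention outputs computed over disjoint KV segments.
# Each out_i is already normalized within its own segment, so weighting by the
# softmax over the per-segment log-sum-exps restores the global normalizer:
#   out = sum_i exp(lse_i - logsumexp(lse)) * out_i
import torch


def merge_segments(attn_outputs: list[torch.Tensor], lses: list[torch.Tensor]) -> torch.Tensor:
    # attn_outputs[i]: (..., head_dim); lses[i]: (...) with the head_dim axis dropped
    out = torch.stack(attn_outputs, dim=0)             # (n_seg, ..., head_dim)
    lse = torch.stack(lses, dim=0)                     # (n_seg, ...)
    weights = torch.softmax(lse, dim=0).unsqueeze(-1)  # per-segment reweighting
    return (weights * out).sum(dim=0)
```

This reproduces attention over the concatenated segments exactly, which is why the code keeps the LSE alongside each partial output instead of averaging the outputs directly.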