Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "specforge"
dynamic = ["version", "description"]
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"pre-commit",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision==0.23.0",
"transformers==4.57.1",
"qwen-vl-utils==0.0.11",
"datasets",
"setuptools",
"tqdm",
"wandb",
"psutil",
"numpy",
"accelerate",
"pydantic",
"sglang[all]==0.5.4",
"openai-harmony",
"flash-attn>=2.6.3",
]

[tool.setuptools]
packages = ["specforge"]

[tool.setuptools.dynamic]
version = {file = "version.txt"}
description = {file = "README.md"}

[tool.uv.extra-build-dependencies]
flash-attn = ["torch==2.8.0"]
16 changes: 0 additions & 16 deletions requirements.txt

This file was deleted.

7 changes: 4 additions & 3 deletions scripts/regenerate_train_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ def main():
print("-" * 50)

# Create progress bar
with open(args.input_file_path, "r") as input_file, open(
args.output_file_path, "w"
) as output_file_handle:
with (
open(args.input_file_path, "r") as input_file,
open(args.output_file_path, "w") as output_file_handle,
):

executor = ThreadPoolExecutor(
max_workers=args.concurrency * len(valid_server_addresses)
Expand Down
18 changes: 11 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from setuptools import find_packages, setup

import tomllib
from pathlib import Path

def read_requirements():
with open(f"requirements.txt", "r") as f:
lines = (line.strip() for line in f)
return [line for line in lines if line and not line.startswith(("#", "--"))]
from setuptools import find_packages, setup


def read_readme():
Expand All @@ -17,11 +14,18 @@ def read_version():
return f.read().strip()


def read_dependencies():
pyproject_path = Path(__file__).parent / "pyproject.toml"
with open(pyproject_path, "rb") as f:
pyproject = tomllib.load(f)
return pyproject.get("project", {}).get("dependencies", [])


setup(
name="specforge",
packages=find_packages(exclude=["configs", "scripts", "tests"]),
version=read_version(),
install_requires=read_requirements(),
install_requires=read_dependencies(),
long_description=read_readme(),
long_description_content_type="text/markdown",
author="SGLang Team",
Expand Down
8 changes: 6 additions & 2 deletions specforge/core/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ def forward(
plosses = []
vlosses = []
acces = []
if self.attention_backend == "sdpa":
if self.attention_backend in ["sdpa", "fa"]:
cache_hidden = [[], []]
past_key_values = None
elif self.attention_backend == "flex_attention":
cache_hidden = None
past_key_values = DynamicCache()
else:
raise ValueError(f"Unknown attention backend: {self.attention_backend}")

for idx in range(self.length):
target_p = target_p_padded[:, idx : idx + seq_length, :]
Expand Down Expand Up @@ -513,12 +515,14 @@ def forward(
plosses = []
vlosses = []
acces = []
if self.attention_backend == "sdpa":
if self.attention_backend in ["sdpa", "fa"]:
cache_hidden = [[], []]
past_key_values = None
elif self.attention_backend == "flex_attention":
cache_hidden = None
past_key_values = DynamicCache()
else:
raise ValueError(f"Unknown attention backend: {self.attention_backend}")

for idx in range(self.length):
target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous()
Expand Down
4 changes: 2 additions & 2 deletions specforge/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def prepare_dp_dataloaders(
shuffle: Optional[bool] = False,
is_vlm: Optional[bool] = False,
prefetch_factor: Optional[int] = 2,
**dataloader_kwargs
**dataloader_kwargs,
) -> DataLoader:
"""
Prepare dataloader for distributed data parallel training.
Expand Down Expand Up @@ -264,6 +264,6 @@ def prepare_dp_dataloaders(
prefetch_factor=prefetch_factor,
collate_fn=datacollator_cls(),
drop_last=True,
**dataloader_kwargs
**dataloader_kwargs,
)
return dataloader
120 changes: 117 additions & 3 deletions specforge/modeling/draft/llama3_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_func
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from transformers import GenerationMixin, LlamaConfig, PreTrainedModel
from transformers.activations import ACT2FN
Expand Down Expand Up @@ -90,12 +91,12 @@ def rotate_half(x):


@torch.compile(dynamic=True)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
Expand Down Expand Up @@ -840,6 +841,117 @@ def forward(
return attn_output


class LlamaFlashAttention(LlamaAttention):
"""
Attention layer implemented with flash attention. We keep the parameters consistent with LlamaAttention.
The used parameters are:
- hidden_states: input hidden states
- position_ids: position ids
- cache_hidden: manual cache used for storing past key and value states
"""

def forward(
self,
hidden_states: torch.Tensor,
cache_hidden: Optional[List[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
)

lck = 0 if cache_hidden is None else len(cache_hidden[0])
if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
cos, sin = self.rotary_emb(query_states, position_ids + lck)
cos, sin = cos.to(query_states.device), sin.to(query_states.device)
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states,
key_states,
cos,
sin,
self.config.rope_scaling["mrope_section"],
unsqueeze_dim=2,
)
else:
cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
cos, sin = cos.to(query_states.device), sin.to(query_states.device)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, unsqueeze_dim=2
)

if cache_hidden is not None:
cache_hidden[0] = cache_hidden[0] + [key_states]
cache_hidden[1] = cache_hidden[1] + [value_states]

cache_k = cache_hidden[0]
cache_v = cache_hidden[1]
else:
cache_k = [key_states]
cache_v = [value_states]

k0 = cache_k[0]
v0 = cache_v[0]

attn_output, lse, _ = flash_attn_func(
query_states,
k0,
v0,
dropout_p=0.0,
softmax_scale=1.0 / math.sqrt(self.head_dim),
causal=True,
return_attn_probs=True,
)
lse = lse.transpose(1, 2)

lck = len(cache_k)
if lck > 1:
q_shape_expanded = (
bsz,
q_len,
self.num_key_value_heads,
self.num_key_value_groups,
self.head_dim,
)
attn_outputs = [attn_output.view(q_shape_expanded)]
lses = [lse.view(q_shape_expanded[:-1])]

for i in range(1, lck):
ki = cache_k[i].unsqueeze(-2)
qi = query_states.view(q_shape_expanded)
vi = cache_v[i].unsqueeze(-2)

attn_outputs.append(vi)
lses.append((qi * ki).sum(-1) / math.sqrt(self.head_dim))

lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1)
attn_output = sum(
attn_outputi * torch.exp(lsei - lse).unsqueeze(-1)
for attn_outputi, lsei in zip(attn_outputs, lses)
)
# lse is fp32, downcast attn_output back
attn_output = attn_output.to(self.o_proj.weight.dtype)

attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads)

attn_output = self.o_proj(attn_output)

return attn_output


class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -913,6 +1025,8 @@ def __init__(self, config, attention_backend: str = "sdpa"):
elif attention_backend == "flex_attention":
print_with_rank("Using flex attention on draft model training!")
self.self_attn = LlamaFlexAttention(config=config)
elif attention_backend == "fa":
self.self_attn = LlamaFlashAttention(config=config)
else:
raise ValueError(f"Unknown attention backend {attention_backend}")

Expand Down
Loading