Add FlexAttention as an alternative attention backend #97

@mmshad

Description

KempnerForge runs all training and inference attention through PyTorch SDPA
(F.scaled_dot_product_attention, backends flash / efficient / cudnn /
math). Call sites: kempnerforge/model/attention.py:189, 193 (main
self-attention), kempnerforge/model/cross_attention.py:107 (VLM CA blocks),
and kempnerforge/model/mot.py:178 (per-modality MoT blocks).

FlexAttention (torch.nn.attention.flex_attention, available in the pinned
PyTorch 2.11) is a separate PyTorch API and is not exposed as an SDPA
backend; the two APIs are architecturally distinct (SDPA selects fused
kernels; FlexAttention generates them from a score_mod plus BlockMask
via torch.compile).
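
To make the score_mod / mask contract concrete, here is a minimal pure-Python reference for a single head. It is only a sketch of the semantics (scale, then score_mod, then mask, then softmax), not how FlexAttention computes anything; the real API compiles these callbacks into a fused kernel. `reference_attention` and `causal` are illustrative names; the `(b, h, q_idx, kv_idx)` callback arguments follow the FlexAttention convention.

```python
import math


def reference_attention(q, k, v, score_mod=None, mask_mod=None):
    """Single-head attention over lists of float vectors (one batch, one head).

    Illustrative only: applies the callbacks in the order scale, score_mod,
    mask, softmax, with none of the kernel fusion.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, q_vec in enumerate(q):
        scores = []
        for kv_idx, k_vec in enumerate(k):
            s = scale * sum(a * b for a, b in zip(q_vec, k_vec))
            if score_mod is not None:
                s = score_mod(s, 0, 0, q_idx, kv_idx)  # b = h = 0 here
            if mask_mod is not None and not mask_mod(0, 0, q_idx, kv_idx):
                s = -math.inf  # masked positions get zero softmax weight
            scores.append(s)
        # Numerically stable softmax; assumes each query keeps at least one key.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_vec[d] for w, v_vec in zip(weights, v)) / total
            for d in range(len(v[0]))
        ])
    return out


def causal(b, h, q_idx, kv_idx):
    """mask_mod: a query may attend to itself and to earlier positions."""
    return q_idx >= kv_idx
```

With SDPA the equivalent of `causal` is the `is_causal=True` flag, and anything beyond an additive bias has no slot to go in; here both hooks are ordinary Python callables.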

Why

Things SDPA cannot do, or can do only expensively:

  • Sliding window (#42, "Sliding window attention for long-context
    training"): a dense attn_mask takes O(T^2) memory; FlexAttention's
    BlockMask stores sparsity per block and skips fully masked blocks.
  • Document-causal packing: currently a manual block-diagonal attn_mask
    tensor at kempnerforge/model/attention.py:182-189; a BlockMask is sparse
    and JIT-fused.
  • Multi-image VLM masks: per-image attention scopes.
  • Non-additive score transforms (Gemma 2's soft-cap
    soft_cap * tanh(score / soft_cap), per-head logit clipping): SDPA's
    attn_mask is added to scores before softmax (fine for additive biases
    like ALiBi) but cannot express non-additive transforms. FlexAttention's
    score_mod can.
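
Two of the items above, sketched as FlexAttention-style callbacks. This is a sketch under assumptions, not the proposed implementation: the factory names are hypothetical, `doc_ids` is assumed to be a 1-D per-token document index (a packed batch would index `doc_ids[b, q_idx]` instead), and 50.0 appears in the usage note only as an example cap value. The callbacks are written so they accept plain ints for checking as well as the index tensors FlexAttention passes in; for tensors, `torch.tanh` would be passed as `tanh`.

```python
import math


def make_document_causal_mask(doc_ids):
    """mask_mod factory: causal attention restricted to the query's own document.

    doc_ids[i] is the document index of token i in the packed sequence.
    """
    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_ids[q_idx] == doc_ids[kv_idx]
        return (q_idx >= kv_idx) & same_doc
    return mask_mod


def make_soft_cap(soft_cap, tanh=math.tanh):
    """score_mod factory: soft_cap * tanh(score / soft_cap), a non-additive transform.

    Scores stay inside [-soft_cap, soft_cap]; small scores pass nearly unchanged.
    """
    def score_mod(score, b, h, q_idx, kv_idx):
        return soft_cap * tanh(score / soft_cap)
    return score_mod
```

For example, with `doc_ids = [0, 0, 0, 1, 1]` token 3 cannot attend to token 2 even though it comes earlier, because the two belong to different documents. `make_soft_cap(50.0)` squashes a raw score of 1000.0 to the cap, which an additive attn_mask has no way to express.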

Orthogonal to #41 (FlashAttention-3 in SDPA). Both can coexist.

Minimum slice (v1)

Scope is the main self-attention path only. Cross-attention and MoT keep SDPA
in v1.

  • attention_impl: "sdpa" | "flex" on ModelConfig (default "sdpa", no
    behavior change).
  • New Attention.forward branch in kempnerforge/model/attention.py: build a
    BlockMask from is_causal plus optional doc_ids, call flex_attention.
  • Parametrize the existing attention unit tests over both backends.
  • Verify FSDP2 + DTensor + bf16 under torch.compile.
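
A rough sketch of how the flag and the new branch could fit together. Everything here is an assumption made for illustration: `ModelConfig` is a one-field stand-in for the real class, `attend` is a hypothetical free function rather than the actual `Attention.forward`, and `doc_ids` is assumed to be a `[B, T]` integer tensor. The torch imports sit inside the function only so the sketch can be imported without torch.

```python
from dataclasses import dataclass
from typing import Literal, get_args

AttentionImpl = Literal["sdpa", "flex"]


@dataclass
class ModelConfig:
    """Hypothetical stand-in; the real ModelConfig carries many more fields."""

    attention_impl: AttentionImpl = "sdpa"  # default keeps current behavior

    def __post_init__(self):
        if self.attention_impl not in get_args(AttentionImpl):
            raise ValueError(f"unknown attention_impl: {self.attention_impl!r}")


def attend(q, k, v, config, doc_ids=None):
    """Causal self-attention over [B, H, T, D] tensors.

    doc_ids is an optional [B, T] integer tensor for document-causal packing.
    """
    import torch
    import torch.nn.functional as F

    seq_len = q.shape[-2]
    if config.attention_impl == "flex":
        from torch.nn.attention.flex_attention import create_block_mask, flex_attention

        def mask_mod(b, h, q_idx, kv_idx):
            keep = q_idx >= kv_idx
            if doc_ids is not None:
                keep = keep & (doc_ids[b, q_idx] == doc_ids[b, kv_idx])
            return keep

        # The mask only varies across the batch when doc_ids is given.
        batch = q.shape[0] if doc_ids is not None else None
        block_mask = create_block_mask(
            mask_mod, B=batch, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=q.device
        )
        # In real use flex_attention would be wrapped in torch.compile.
        return flex_attention(q, k, v, block_mask=block_mask)

    if doc_ids is None:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # SDPA path: dense [B, 1, T, T] boolean mask, O(T^2) memory.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()
    same_doc = doc_ids[:, :, None] == doc_ids[:, None, :]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=(causal & same_doc)[:, None])
```

Keeping both branches behind one function is what lets the existing unit tests be parametrized over `"sdpa"` and `"flex"` and compared on the same inputs.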

Sliding window (#42) then lands as a follow-up BlockMask generator without
touching infrastructure. Extending the flag to CrossAttentionBlock and
MoTBlock is a separate follow-up.
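
As an illustration of what such a generator might look like, and of why block sparsity matters for it, the sketch below defines a sliding-window causal mask_mod and counts how many tiles contain at least one kept position. The names are hypothetical, the window convention (each query sees itself plus the previous `window - 1` tokens) is one possible choice, and the 128 x 128 tile size is an assumption about the BlockMask block size.

```python
def make_sliding_window_causal(window):
    """mask_mod factory: each query sees itself and the previous window - 1 tokens."""
    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx < window)
    return mask_mod


def count_active_blocks(seq_len, window, block=128):
    """Count (query block, key block) tiles holding at least one kept position.

    Returns (active_tiles, total_tiles).
    """
    n_blocks = (seq_len + block - 1) // block
    active = 0
    for qb in range(n_blocks):
        q_lo, q_hi = qb * block, min(seq_len, (qb + 1) * block) - 1
        for kb in range(n_blocks):
            k_lo, k_hi = kb * block, min(seq_len, (kb + 1) * block) - 1
            # q_idx - kv_idx spans [q_lo - k_hi, q_hi - k_lo] within the tile;
            # the tile is active if that span overlaps [0, window - 1].
            if q_hi - k_lo >= 0 and q_lo - k_hi < window:
                active += 1
    return active, n_blocks * n_blocks
```

For T = 8192 and a 1024-token window, `count_active_blocks(8192, 1024)` returns `(540, 4096)`: only about 13% of tiles need any work, whereas a dense attn_mask materializes all T^2 entries regardless.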
