
Sliding window attention for long-context training #42

@mmshad

Description

Context

Sliding window attention restricts each token to the W most recent positions instead of the full sequence. Primary benefits: attention compute drops from O(S²) to O(S·W), and the KV cache during inference shrinks from O(S) to O(W). This is useful for long-context training (32K+ sequences), where full attention is compute-prohibitive.
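To make the O(S²) vs O(S·W) claim concrete, this sketch counts the query-key score pairs one causal attention layer evaluates per sequence. It assumes the window includes the current token (query i sees keys i - W + 1 through i); the helper names are illustrative and not part of kempnerforge.

```python
def causal_pairs(seq_len: int) -> int:
    """Score pairs for full causal attention: query i sees keys 0..i."""
    return seq_len * (seq_len + 1) // 2


def windowed_pairs(seq_len: int, window: int) -> int:
    """Score pairs with a sliding window: query i sees min(i + 1, window) keys."""
    return sum(min(i + 1, window) for i in range(seq_len))


if __name__ == "__main__":
    S, W = 32_768, 4_096
    full, band = causal_pairs(S), windowed_pairs(S, W)
    print(f"full causal: {full:,} pairs")
    print(f"window {W:,}: {band:,} pairs ({full / band:.1f}x fewer)")
```

For S = 32,768 and W = 4,096 the windowed count is about 4.3x smaller; for S much larger than W the ratio approaches S / (2W), because full causal attention already skips the upper triangle.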

Used by Mistral, Mixtral, Gemma-2, and similar architectures: some apply sliding window on every layer, while others alternate sliding-window layers with full-attention layers.

Current state

Attention in kempnerforge/model/attention.py always uses full causal attention via PyTorch SDPA. ModelConfig has no sliding window field. KVCache pre-allocates max_seq_len slots, so its memory grows linearly with the maximum sequence length.

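PyTorch SDPA does not have a native window_size parameter, so sliding window requires an explicit banded-causal attn_mask. Note that a dense mask only zeroes the out-of-window weights: SDPA still computes all S×S scores, so this path gives correct sliding-window semantics (and the KV-cache savings) but not, on its own, the O(S·W) compute reduction.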

What needs to change

  1. Config: Add sliding_window: int | None = None to ModelConfig in kempnerforge/config/model.py. None keeps current behavior (full attention).

  2. Threading: Pass sliding_window through TransformerBlock into Attention.__init__(); store as self.sliding_window.

  3. Masking in Attention.forward(): When sliding_window is set, construct a banded causal mask and pass it as attn_mask. When doc_ids is also present, intersect the band with the existing block-diagonal document mask, so each token attends only to the last W tokens within its own document.

  4. KV cache: When sliding_window is set, KVCache only needs the last W positions. Convert to a circular buffer of size W instead of pre-allocating max_seq_len. Reduces inference memory from O(max_seq_len) to O(W).

Config example

[model]
sliding_window = 4096  # each token attends to last 4096 positions

MFU impact

kempnerforge/metrics/mfu.py uses 12 * L * D * S for the attention term in both _dense_flops_per_token and _moe_flops_per_token. When sliding_window is set, this term should use min(S, W); otherwise the formula counts attention FLOPs for keys outside the window, and MFU will be over-reported for sliding-window models.

Testing

  • Output matches full attention when W >= seq_len
  • Tokens beyond window distance receive zero attention weight
  • Works with GQA, packed sequences (doc_ids), and KV cache
  • MFU calculation uses min(S, W) when sliding window is set
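The first two test bullets can be sketched with plain PyTorch. Because SDPA does not return attention weights, the zero-weight property is checked behaviorally: perturbing keys and values outside the last query's window must leave that query's output unchanged. The helper and test names are hypothetical, and the mask convention (window includes the current token) is an assumption.

```python
import torch
import torch.nn.functional as F


def banded_causal(seq_len: int, window: int) -> torch.Tensor:
    # True where query i may attend key j: i - window < j <= i (assumed convention).
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (i - j < window)


def test_matches_full_attention_when_window_covers_sequence():
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=banded_causal(16, 16))
    torch.testing.assert_close(out, ref)


def test_tokens_outside_window_have_no_influence():
    torch.manual_seed(0)
    S, W = 16, 4
    q, k, v = (torch.randn(1, 2, S, 8) for _ in range(3))
    mask = banded_causal(S, W)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Perturb positions 0..S-W-1, all outside the last query's window.
    k2, v2 = k.clone(), v.clone()
    k2[:, :, : S - W] += 10.0
    v2[:, :, : S - W] += 10.0
    out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=mask)
    torch.testing.assert_close(out[:, :, -1], out2[:, :, -1])
```

These run under pytest as written; GQA, doc_ids, and KV-cache cases would add their own inputs on top of the same checks.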

Priority

Medium. Unlocks long-context training on memory-constrained hardware. Independent of FA3 upstream tracking.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
