KempnerForge runs all training and inference attention through PyTorch SDPA
(F.scaled_dot_product_attention, backends flash / efficient / cudnn / math). Call sites: kempnerforge/model/attention.py:189, 193 (main
self-attention), kempnerforge/model/cross_attention.py:107 (VLM CA blocks),
and kempnerforge/model/mot.py:178 (per-modality MoT blocks).
FlexAttention (torch.nn.attention.flex_attention, available in the pinned
PyTorch 2.11) is a separate PyTorch API and is not exposed as an SDPA
backend; the two APIs are architecturally distinct (SDPA selects fused
kernels; FlexAttention generates them from a score_mod plus BlockMask
via torch.compile).
Why
Things SDPA cannot do or does expensively:
A dense attn_mask takes O(T^2) memory in sequence length T; FlexAttention's BlockMask is sparse-aware.
Document-causal packing: currently a manual block-diagonal attn_mask tensor at kempnerforge/model/attention.py:182-189; a BlockMask is sparse and JIT-fused.
Non-additive score transforms (Gemma 2's soft-cap soft_cap * tanh(score / soft_cap), per-head logit clipping): SDPA's attn_mask is added to scores before softmax, which is fine for additive biases like ALiBi but cannot express non-additive transforms. FlexAttention's score_mod can.
Orthogonal to #41 (FlashAttention-3 in SDPA). Both can coexist.
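To make the non-additive point concrete, here is a minimal pure-Python sketch of tanh soft-capping on one row of attention scores. The helper names (soft_cap_score, softmax), the sample scores, and the cap value are illustrative assumptions, not KempnerForge code. It shows that the offset between raw and capped logits depends on the score itself, so no mask fixed before the scores are computed can reproduce it.

```python
import math


def soft_cap_score(score: float, soft_cap: float) -> float:
    """Tanh soft-cap: squashes a logit into the open range (-soft_cap, soft_cap)."""
    return soft_cap * math.tanh(score / soft_cap)


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


raw_scores = [1.0, 20.0, 40.0]  # illustrative pre-softmax logits
SOFT_CAP = 30.0                 # illustrative cap, not a value from the issue

capped = [soft_cap_score(s, SOFT_CAP) for s in raw_scores]

# The bias an additive mask would have to carry to mimic the cap.
# It is different for every score, and it is only known after q @ k^T.
implied_bias = [c - s for c, s in zip(capped, raw_scores)]

probs_raw = softmax(raw_scores)
probs_capped = softmax(capped)
```

With FlexAttention the same transform would be passed as a score_mod callable taking (score, b, h, q_idx, kv_idx) and returning the capped score, for example soft_cap * torch.tanh(score / soft_cap).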
Minimum slice (v1)
Scope is the main self-attention path only. Cross-attention and MoT keep SDPA
in v1.
attention_impl: "sdpa" | "flex" on ModelConfig (default "sdpa", no
behavior change).
New Attention.forward branch in kempnerforge/model/attention.py: build a BlockMask from is_causal plus optional doc_ids, call flex_attention.
Parametrize the existing attention unit tests over both backends.
Verify FSDP2 + DTensor + bf16 under torch.compile.
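The first two items could look roughly like the sketch below. ModelConfigSketch, make_mask_mod, and flex_forward are hypothetical names, and doc_ids is assumed to be a 1-D sequence (or tensor) of per-token document ids shared across the batch; the real ModelConfig and Attention.forward will differ. The predicate is plain Python so it can be checked without a GPU, and the wiring function only imports PyTorch when called.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class ModelConfigSketch:
    # Hypothetical stand-in for ModelConfig; only the proposed flag is shown.
    attention_impl: Literal["sdpa", "flex"] = "sdpa"

    def __post_init__(self):
        if self.attention_impl not in ("sdpa", "flex"):
            raise ValueError(f"unknown attention_impl: {self.attention_impl!r}")


def make_mask_mod(is_causal: bool, doc_ids=None):
    """Build a predicate with FlexAttention's mask_mod signature.

    Returns True where query position q_idx may attend to key position kv_idx.
    Assumption: doc_ids is 1-D and indexable by position (list or tensor).
    """
    def mask_mod(b, h, q_idx, kv_idx):
        # `q_idx >= 0` is an always-true value of the same type as the indices.
        keep = (q_idx >= kv_idx) if is_causal else (q_idx >= 0)
        if doc_ids is not None:
            keep = keep & (doc_ids[q_idx] == doc_ids[kv_idx])
        return keep
    return mask_mod


def flex_forward(q, k, v, is_causal: bool, doc_ids=None):
    """Wiring sketch for the "flex" branch; q, k, v are [B, H, T, D] tensors.

    Not exercised here. In practice flex_attention would run under
    torch.compile, and the BlockMask would be cached rather than rebuilt.
    """
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    seq_len = q.shape[-2]
    block_mask = create_block_mask(
        make_mask_mod(is_causal, doc_ids),
        B=None, H=None,  # broadcast the mask over batch and heads
        Q_LEN=seq_len, KV_LEN=seq_len,
        device=q.device,
    )
    return flex_attention(q, k, v, block_mask=block_mask)


# Two packed documents of two tokens each.
packed = make_mask_mod(is_causal=True, doc_ids=[0, 0, 1, 1])
same_doc_past = packed(0, 0, 1, 0)     # same document, earlier key
cross_doc_past = packed(0, 0, 2, 1)    # earlier key, but a different document
```

Keeping "sdpa" as the default means existing configs take the current code path unchanged.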
Sliding window (#42) then lands as a follow-up BlockMask generator without
touching infrastructure. Extending the flag to CrossAttentionBlock and MoTBlock is a separate follow-up.
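As an illustration of why a sliding window needs no new infrastructure, the sketch below writes it as one more mask predicate and ANDs it with the causal one. All names here (causal, sliding_window, and_masks_sketch, dense_rows) and the window size are assumptions for illustration; the actual design for #42 may differ.

```python
def causal(b, h, q_idx, kv_idx):
    """Query may attend to itself and to earlier positions."""
    return q_idx >= kv_idx


def sliding_window(window: int):
    """Mask generator: keep keys fewer than `window` positions behind the query."""
    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx - kv_idx) < window
    return mask_mod


def and_masks_sketch(*mods):
    """Combine predicates; a position is kept only if every predicate keeps it."""
    def combined(b, h, q_idx, kv_idx):
        keep = mods[0](b, h, q_idx, kv_idx)
        for mod in mods[1:]:
            keep = keep & mod(b, h, q_idx, kv_idx)
        return keep
    return combined


def dense_rows(mask_mod, seq_len: int):
    """Render the mask as text for inspection: 'x' = attend, '.' = masked."""
    return [
        "".join("x" if mask_mod(0, 0, q, kv) else "." for kv in range(seq_len))
        for q in range(seq_len)
    ]


rows = dense_rows(and_masks_sketch(causal, sliding_window(3)), seq_len=5)
# rows is:
# x....
# xx...
# xxx..
# .xxx.
# ..xxx
```

PyTorch ships a comparable and_masks helper in torch.nn.attention.flex_attention, so the hand-written combiner above would likely not be needed in the real implementation.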