From d4e2b5690993c2254bba78e3e283c928910797d5 Mon Sep 17 00:00:00 2001 From: amazloumi Date: Mon, 18 May 2026 16:44:59 -0400 Subject: [PATCH 1/4] Add MoMa (Mixture of Modality-Aware Experts) VLM architecture --- README.md | 6 +- configs/train/vlm_7b_moma.toml | 125 ++++ configs/train/vlm_debug_moma.toml | 90 +++ kempnerforge/config/job.py | 27 +- kempnerforge/config/vlm.py | 124 ++++ kempnerforge/model/moma.py | 424 ++++++++++++++ kempnerforge/model/transformer.py | 50 +- kempnerforge/model/vlm.py | 45 ++ tests/distributed/test_vlm_moma_fsdp.py | 344 +++++++++++ tests/integration/test_vlm_moma.py | 318 ++++++++++ tests/unit/test_moma.py | 737 ++++++++++++++++++++++++ tests/unit/test_vlm_config.py | 97 +++- 12 files changed, 2382 insertions(+), 5 deletions(-) create mode 100644 configs/train/vlm_7b_moma.toml create mode 100644 configs/train/vlm_debug_moma.toml create mode 100644 kempnerforge/model/moma.py create mode 100644 tests/distributed/test_vlm_moma_fsdp.py create mode 100644 tests/integration/test_vlm_moma.py create mode 100644 tests/unit/test_moma.py diff --git a/README.md b/README.md index 7b5e267..f34acbd 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,7 @@ KempnerForge supports VLM training via a thin wrapper around the existing `Trans - **Joint-Decoder** (`arch = "joint_decoder"`): image embeds are prepended to the text sequence; the transformer runs over the concatenated `(image, text)` sequence and the LM head is applied to text positions only. - **Cross-Attention** (`arch = "cross_attention"`, Llama-3-V style): the residual stream carries text only. Separate `CrossAttentionBlock`s inserted at a configurable cadence let text queries attend to image K/V. CA blocks are zero-initialized so adding the arch on top of a text-only checkpoint is identity at step 0 and learns from there. - **Mixture-of-Transformers** (`arch = "mot"`, Liang et al. 
2024 Algorithm 1): every layer carries per-modality Q/K/V/O projections plus a per-modality FFN; a single global self-attention mixes all modality streams. Image tokens prepend the text sequence (same residual layout as Joint-Decoder); per-modality residual projections are zero-initialized so a fresh MoT block is identity at construction. A warm-start helper (`mot_warm_start_from_text_stack`) translates a JD or text-only checkpoint into per-modality copies — toggle via `[model.vlm].mot_warm_start_from_text` + `mot_warm_start_path`. +- **Mixture of Modality-Aware Experts** (`arch = "moma"`, Lin et al. 2024 arXiv:2407.21770): one shared set of Q/K/V/O projections feeding a global self-attention, plus per-modality MoE FFN groups (paper's optimal default 4 image + 4 text experts per layer). Tokens route deterministically to their modality group (level-1, reusing the same `modality_ids` mechanism MoT uses) and then through a learned expert-choice + Sigmoid router within the group (level-2, with Gumbel-Sigmoid noise during training, paper Eq. 5). Image tokens prepend the text sequence (same residual layout as JD/MoT). v1 supports training only — expert-choice routing is non-causal, so autoregressive generation requires auxiliary routers (paper §2.4) which are deferred to a follow-up. ```bash # 1-GPU smoke (random encoder, Joint-Decoder) @@ -160,9 +161,12 @@ uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_cross_a # 4-GPU 7B Mixture-of-Transformers uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_mot.toml + +# 4-GPU 7B Mixture of Modality-Aware Experts (4 text + 4 image experts per layer) +uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_moma.toml ``` -Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). 
For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT blocks are cast to `param_dtype`. Pipeline Parallel + VLM is not supported on this branch (raises at startup); multi-image and video are reserved slots for follow-up work. +Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). For MoMa, set `moma_experts_per_modality = {image = N, text = M}` as a nested TOML table (the paper's optimal balanced default is `4t4i`; unbalanced allocations like `{image = 1, text = 7}` match the paper's `moe_7t1i` ablation), and optionally `moma_capacity_factor` (defaults to `1/|E^M|` per modality — the paper's perfect-balance setting) and `moma_gumbel_noise` (`true` by default for paper-faithful EC routing). `model.num_experts` must be `0` when `arch = "moma"`; the per-modality counts supersede it, and JobConfig.validate rejects the combination. The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT / MoMa blocks are cast to `param_dtype`. 
Pipeline Parallel + VLM is not supported on this branch (raises at startup); MoMa + Expert Parallelism is also rejected in v1. Multi-image and video are reserved slots for follow-up work. **Adding a new VLM arch.** The discriminated-union dispatch is registry-driven, so a new arch is four small additions, no edits to existing call sites: diff --git a/configs/train/vlm_7b_moma.toml b/configs/train/vlm_7b_moma.toml new file mode 100644 index 0000000..f7a8c58 --- /dev/null +++ b/configs/train/vlm_7b_moma.toml @@ -0,0 +1,125 @@ +# 7B VLM Mixture of Modality-Aware Experts (MoMa) on 4x H200 (141 GB/GPU). +# +# Mode: smoke / starter. +# +# Shared Q/K/V/O attention (one set across modalities) + per-modality MoE +# FFN groups with expert-choice + Sigmoid routing (Lin et al. 2024, +# arXiv:2407.21770). Default 4 text + 4 image experts per layer +# (paper's optimal moe_4t4i). Tokens route deterministically by modality +# (level 1) then through a learned EC + Sigmoid router within their +# modality group (level 2). Image tokens prepend the text sequence +# (same residual layout as Joint-Decoder / MoT); output_slice trims +# them off the LM head input. +# +# Inference note: MoMa v1 supports training only. Expert-choice routing +# is non-causal (each expert's top-k depends on all tokens in the batch); +# autoregressive generation requires auxiliary routers (paper §2.4), +# deferred to a follow-up. +# +# Parameter / memory note: with the default 7B-dense-shaped backbone +# (dim=4096, n_layers=32, ffn ~14336) and 8 SwiGLU experts per layer +# (4 image + 4 text), the total parameter count is much larger than a dense 7B model. Use +# FSDP=4 + activation_checkpointing="full" to fit on 4x H200. For a +# roomier setup, reduce moma_experts_per_modality (e.g. 2t2i) or fall +# back to the MoT debug config. Pipeline Parallel + VLM is not supported. +# +# max_seq_len allocation: residual_image_tokens + max_text_len.
Image +# tokens prepend the text sequence in the residual stream, so the budget +# must cover both modalities. +# +# Usage: +# uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_moma.toml +# +# Default points at a 30-sample COCO val substitute (sayakpaul/coco-30-val-2014) +# so a fresh clone runs without external setup. For real training override: +# --data.hf_dataset_name=/path/to/local/hf_dataset \ +# --data.hf_dataset_text_field=caption +# Swap the encoder via [vision_encoder].type = "siglip2" / "clip". + +[model] +dim = 4096 +n_layers = 32 +n_heads = 32 +n_kv_heads = 8 +vocab_size = 50257 +ffn_dim_multiplier = 1.3 +norm_type = "rmsnorm" +activation = "silu" +max_seq_len = 1024 # 256 image + 512 text (max_text_len default) + headroom +rope_theta = 500000.0 +tie_embeddings = false + +[vision_encoder] +type = "random" +feature_dim = 1024 +num_tokens = 256 + +[vlm] +arch = "moma" +# Paper Eq. 5: Gumbel-Sigmoid noise on router scores during training. +# Set false for a deterministic forward (useful for warm-start parity +# checks and reproducibility-sensitive smoke runs). +moma_gumbel_noise = true +# moma_capacity_factor = 0.0 → use paper default 1/|E^M| per modality +# (each expert sees the average token load; perfect EC balance). + +[vlm.moma_experts_per_modality] +# Paper's optimal balanced allocation (moe_4t4i at 1.4B-compute-matched +# in Table 1). Unbalanced allocations like {image = 1, text = 7} are +# supported and match the paper's moe_7t1i / 1t7i ablations. +image = 4 +text = 4 + +[train] +batch_size = 4 +seq_len = 768 +max_steps = 200 +grad_accum_steps = 1 +grad_clip_norm = 1.0 +seed = 42 +# Modality-aware scatter/gather + EC top-k are not yet validated under +# torch.compile (graph breaks on data-dependent dispatch); JobConfig.validate +# emits a warning if you flip this on. 
+compile_model = false +# Required given the per-layer expert duplication on 4x H200; drop to +# "selective" or "none" if you scale down experts_per_modality or n_layers. +activation_checkpointing = "full" +loss_fn = "cross_entropy" + +[optimizer] +name = "adamw" +lr = 1e-4 +weight_decay = 0.1 +betas = [0.9, 0.95] +eps = 1e-8 +fused = true + +[scheduler] +name = "cosine" +warmup_steps = 5 +min_lr_ratio = 0.1 + +[data] +hf_dataset_name = "sayakpaul/coco-30-val-2014" +hf_dataset_split = "train" +hf_dataset_image_field = "image" +hf_dataset_text_field = "caption" +hf_image_size = 224 +tokenizer_path = "gpt2" +num_workers = 2 +pin_memory = true +prefetch_factor = 2 + +[distributed] +dp_shard = -1 +nccl_timeout_sec = 600 + +[checkpoint] +dir = "checkpoints/vlm_7b_moma" +interval = 1000 +keep_last_n = 1 + +[metrics] +log_interval = 1 +enable_wandb = false +enable_tensorboard = false diff --git a/configs/train/vlm_debug_moma.toml b/configs/train/vlm_debug_moma.toml new file mode 100644 index 0000000..5575fb6 --- /dev/null +++ b/configs/train/vlm_debug_moma.toml @@ -0,0 +1,90 @@ +# VLM smoke config — Mixture of Modality-Aware Experts (MoMa) arch, +# tiny LLM + random vision encoder. +# +# Mode: smoke. +# +# Runs end-to-end in <2 minutes on 1 GPU. Uses RandomVisionEncoder so no +# HF download is needed. Shared Q/K/V/O self-attention + per-modality +# MoE FFN groups with expert-choice + Sigmoid routing (Lin et al. 2024, +# arXiv:2407.21770). 2 text experts + 2 image experts at every layer. +# +# max_seq_len allocation: residual_image_tokens + max_text_len. Image +# tokens prepend the text sequence in the residual stream (same layout +# as Joint-Decoder / MoT). +# +# Note: MoMa v1 supports training only. Expert-choice routing is +# non-causal; autoregressive generation requires auxiliary routers +# (paper §2.4), deferred to a follow-up. +# +# Usage: +# uv run python scripts/train.py configs/train/vlm_debug_moma.toml \ +# --data.hf_dataset_name=... 
--data.tokenizer_path=gpt2 + +[model] +dim = 256 +n_layers = 4 +n_heads = 4 +n_kv_heads = 4 +vocab_size = 50257 # gpt2 vocab +max_seq_len = 576 # 64 image + 512 text +norm_type = "rmsnorm" +activation = "silu" + +[vision_encoder] +type = "random" +feature_dim = 384 +num_tokens = 64 + +[vlm] +arch = "moma" + +[vlm.moma_experts_per_modality] +image = 2 +text = 2 + +[train] +batch_size = 2 +seq_len = 576 +max_steps = 50 +grad_accum_steps = 1 +grad_clip_norm = 1.0 +seed = 42 +compile_model = false +activation_checkpointing = "none" + +[optimizer] +name = "adamw" +lr = 3e-4 +weight_decay = 0.1 +betas = [0.9, 0.95] +fused = false + +[scheduler] +name = "cosine" +warmup_steps = 5 +min_lr_ratio = 0.1 + +[data] +# 30-sample COCO val substitute (sayakpaul/coco-30-val-2014). For a real +# training run, override via CLI: --data.hf_dataset_name= +hf_dataset_name = "sayakpaul/coco-30-val-2014" +hf_dataset_split = "train" +hf_dataset_image_field = "image" +hf_dataset_text_field = "caption" +hf_image_size = 224 +tokenizer_path = "gpt2" +num_workers = 2 +pin_memory = true + +[distributed] +dp_shard = -1 + +[checkpoint] +dir = "checkpoints/vlm_debug_moma" +interval = 25 +keep_last_n = 2 + +[metrics] +log_interval = 5 +enable_wandb = false +enable_tensorboard = false diff --git a/kempnerforge/config/job.py b/kempnerforge/config/job.py index 540034d..5880228 100644 --- a/kempnerforge/config/job.py +++ b/kempnerforge/config/job.py @@ -16,7 +16,7 @@ from kempnerforge.config.scheduler import SchedulerConfig from kempnerforge.config.training import TrainConfig from kempnerforge.config.vision import VisionEncoderConfig -from kempnerforge.config.vlm import VLMConfig +from kempnerforge.config.vlm import MoMaConfig, VLMConfig # Vision-encoder types whose builders load a HuggingFace model and probe # feature_dim / num_tokens from the model's config. Setting these knobs @@ -221,3 +221,28 @@ def validate(self, world_size: int = 1) -> None: # the build + forward + backward path. 
CrossAttentionBlocks # themselves remain dense MLP; MoE lives in the text # TransformerBlocks (where moe_frequency selects). + + if isinstance(self.vlm, MoMaConfig): + # MoMa carries its own per-modality expert counts on + # ``moma_experts_per_modality``; ``model.num_experts`` + # would be a redundant second source of truth, so we + # reject it explicitly rather than silently ignoring. + if self.model.num_experts > 0: + raise ValueError( + "MoMa + model.num_experts > 0 is rejected. MoMa derives expert " + "counts per modality from vlm.moma_experts_per_modality; set " + "model.num_experts=0 (the dense-LLM default) when using arch='moma'." + ) + if self.distributed.ep > 1: + raise ValueError( + "MoMa + Expert Parallelism is not supported in v1. Per-modality " + "expert groups need EP-aware dispatch that is not yet wired." + ) + if self.train.compile_model: + import logging + + logging.getLogger(__name__).warning( + "torch.compile is not yet validated for MoMa dispatch " + "(modality_ids-based scatter/gather + expert-choice top-k cause " + "graph breaks). Set compile_model=false for MoMa models." + ) diff --git a/kempnerforge/config/vlm.py b/kempnerforge/config/vlm.py index 3e009ad..777c27f 100644 --- a/kempnerforge/config/vlm.py +++ b/kempnerforge/config/vlm.py @@ -17,6 +17,11 @@ cross-attention blocks at a configurable cadence. - ``"mot"`` Mixture-of-Transformers: per-modality Q/K/V/O + per- modality FFN at every layer, single global self-attention. +- ``"moma"`` Mixture of Modality-Aware Experts: shared Q/K/V/O + + per-modality MoE FFN groups at every layer. Tokens are routed + deterministically by modality (level 1) then by a learned + expert-choice + Sigmoid router within their modality group + (level 2). Lin et al. 2024 (arXiv:2407.21770). Each arch gets its own ``VLMConfig`` subclass, registered via ``registry.register_vlm_config``. 
The TOML loader dispatches on @@ -353,3 +358,122 @@ def resolved_image_heads( n_heads = self.mot_image_n_heads or model_n_heads n_kv_heads = self.mot_image_n_kv_heads or model_n_kv_heads or n_heads return n_heads, n_kv_heads + + +@registry.register_vlm_config("moma") +@dataclass +class MoMaConfig(VLMConfig): + """Mixture of Modality-Aware Experts (MoMa): shared self-attention + + per-modality MoE FFN groups (Lin et al. 2024, arXiv:2407.21770). + + Each transformer layer is a pre-norm block with: + + - Standard ``Attention`` (one set of Q/K/V/O across modalities) running a + single global SDPA over the concatenated image+text sequence. + - A ``MoMaFFN`` that routes tokens in two stages: + + 1. Deterministic by modality (level 1): token's ``modality_ids`` value + selects which modality expert group processes it. + 2. Learned expert-choice + Sigmoid (level 2): within the modality + group, each expert independently picks its top-k tokens by sigmoid + score (with optional Gumbel-Sigmoid noise during training; paper + Eq. 5). Token output is the sum of selected experts' outputs + weighted by their sigmoid scores. + + Image tokens are prepended to the text sequence (same residual layout as + Joint-Decoder and MoT). ``modality_ids`` tags every position; the FFN + uses these tags for scatter/gather dispatch (works for arbitrary + interleaved layouts, not just image-prefix). + + Differs from ``"mot"``: MoT has per-modality Q/K/V/O *and* per-modality + FFN. MoMa has shared Q/K/V/O and per-modality MoE FFN groups (multiple + experts per modality, learned routing within each group). + + Inference note: expert-choice routing is non-causal (each expert's + top-k depends on all tokens in the batch). v1 supports training only; + autoregressive generation requires auxiliary routers (paper §2.4), + deferred to a follow-up. 
+ + The MoMa-specific module alias ``"moma"`` is added to + ``module_patterns`` so freeze targeting works out of the box: + ``FreezeSpec("moma", True)`` freezes the per-modality MoE stack + (``transformer.layers.*``) without touching the embedding, output head, + or final norm. + """ + + arch: str = "moma" + moma_modalities: tuple[str, ...] = ("image", "text") + moma_experts_per_modality: dict[str, int] = field( + default_factory=lambda: {"image": 4, "text": 4} + ) + moma_capacity_factor: float = 0.0 + moma_gumbel_noise: bool = True + module_patterns: dict[str, list[str]] = field( + default_factory=lambda: { + **{k: list(v) for k, v in DEFAULT_MODULE_PATTERNS.items()}, + "moma": [ + "transformer.layers", + "transformer.layers.*", + ], + } + ) + + def __post_init__(self) -> None: + super().__post_init__() + if len(self.moma_modalities) < 2: + raise ValueError( + f"vlm.moma_modalities must have at least 2 entries (got {self.moma_modalities!r})" + ) + if "text" not in self.moma_modalities: + raise ValueError( + f"vlm.moma_modalities must include 'text' (got {self.moma_modalities!r})" + ) + if "image" not in self.moma_modalities: + raise ValueError( + f"vlm.moma_modalities must include 'image' (got {self.moma_modalities!r})" + ) + if len(set(self.moma_modalities)) != len(self.moma_modalities): + raise ValueError( + f"vlm.moma_modalities must not contain duplicates (got {self.moma_modalities!r})" + ) + missing = set(self.moma_modalities) - set(self.moma_experts_per_modality.keys()) + if missing: + raise ValueError( + f"vlm.moma_experts_per_modality missing entries for {sorted(missing)} " + f"(got {self.moma_experts_per_modality!r}, need keys for all " + f"moma_modalities {self.moma_modalities!r})" + ) + extra = set(self.moma_experts_per_modality.keys()) - set(self.moma_modalities) + if extra: + raise ValueError( + f"vlm.moma_experts_per_modality has unknown modality keys {sorted(extra)} " + f"(allowed: {sorted(self.moma_modalities)})" + ) + for m, n in 
self.moma_experts_per_modality.items(): + if n <= 0: + raise ValueError( + f"vlm.moma_experts_per_modality[{m!r}] must be positive " + f"(got {n}). For dense per-modality FFN use arch='mot' instead." + ) + if self.moma_capacity_factor < 0: + raise ValueError( + f"vlm.moma_capacity_factor must be >= 0 (got {self.moma_capacity_factor})" + ) + + def residual_stream_image_tokens(self, num_tokens: int) -> int: + """MoMa prepends ``num_tokens`` image tokens to the text sequence + (same residual-stream layout as Joint-Decoder). + """ + return num_tokens + + def effective_capacity_factor(self, modality: str) -> float: + """Resolve the per-expert capacity factor for ``modality``. + + Paper default (``moma_capacity_factor == 0``): return + ``1 / |E^M|`` so each expert sees the average load per modality + (perfect balance under expert-choice routing). Explicit positive + values pass through unchanged. + """ + if self.moma_capacity_factor > 0: + return self.moma_capacity_factor + return 1.0 / self.moma_experts_per_modality[modality] diff --git a/kempnerforge/model/moma.py b/kempnerforge/model/moma.py new file mode 100644 index 0000000..dfbbcfc --- /dev/null +++ b/kempnerforge/model/moma.py @@ -0,0 +1,424 @@ +"""Mixture of Modality-Aware Experts (MoMa) operator, FFN, and block. + +Implements Lin et al. 2024 ("MoMa: Efficient Early-Fusion Pre-training with +Mixture of Modality-Aware Experts", arXiv:2407.21770) on top of KempnerForge's +existing VLM stack. + +Architecture at a glance, per transformer layer: + +- Pre-norm ``Attention`` (the standard module, shared Q/K/V/O across + modalities) running a single global SDPA over the concatenated image+text + sequence. +- Pre-norm ``MoMaFFN``: a ``ModuleDict`` of per-modality ``ExpertChoiceMoE`` + groups dispatched by ``modality_ids``. 
Each group's MoE uses + Expert-Choice + Sigmoid routing (paper §2.2): each expert independently + picks its top-``k_e`` tokens by sigmoid score, and the token output is + the sum across experts that selected it, weighted by their sigmoid + scores (Eq. 1). Optional Gumbel-Sigmoid noise during training (Eq. 5). + +Differs from MoT (also in this codebase): MoT has *per-modality* Q/K/V/O +projections and per-modality FFN. MoMa has *shared* Q/K/V/O and per-modality +*MoE* FFN groups (multiple experts per modality, learned routing within +each group). Both share the residual-stream layout (image tokens prepended +to text) and ``modality_ids`` tagging mechanism. + +The module exposes four public symbols: + +- ``ExpertChoiceSigmoidRouter`` — per-modality gate (``W_g^M``), Sigmoid + scoring, optional Gumbel noise, and per-expert top-``k_e`` token + selection. +- ``ExpertChoiceMoE`` — composes a router + ``num_experts`` SwiGLU + experts; forward(x) returns the sigmoid-weighted expert combination. +- ``MoMaFFN`` — holds one ``ExpertChoiceMoE`` per modality and dispatches + tokens via ``modality_ids``. +- ``MoMaBlock`` — pre-norm block: shared ``Attention`` + ``MoMaFFN``. + +Inference note: expert-choice routing is non-causal (each expert's +top-``k_e`` depends on all tokens in the batch). v1 supports training only; +autoregressive generation requires auxiliary routers (paper §2.4), deferred +to a follow-up PR. +""" + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn + +from kempnerforge.config.schema import ModelConfig +from kempnerforge.model.attention import Attention +from kempnerforge.model.mlp import build_mlp +from kempnerforge.model.norm import build_norm + + +def _gumbel_like(x: torch.Tensor) -> torch.Tensor: + """Sample Gumbel(0, 1) noise with the same shape and dtype as ``x``. + + Uses the standard inverse-CDF trick on a clamped uniform sample. 
The + clamps avoid ``log(0)`` from drawing exactly zero (rare but possible + at low precisions) or ``log(1)`` (where ``-log(u) == 0`` makes the + outer ``log`` blow up). The intermediate parenthesization is load- + bearing: ``(-torch.log(u)).clamp_min(...)`` clamps the *positive* + side, whereas ``-torch.log(u).clamp_min(...)`` would clamp the + negative ``log(u)`` first (collapsing everything to ``-1e-9`` and + then triggering ``log`` of a negative number = NaN). + """ + u = torch.rand_like(x).clamp_min(1e-9) + return -torch.log((-torch.log(u)).clamp_min(1e-9)) + + +class ExpertChoiceSigmoidRouter(nn.Module): + """Expert-Choice + Sigmoid router for one modality group (Lin et al. 2024 §2.2). + + Scoring: ``score = sigmoid(W_g x)`` per token-expert pair (independent + across experts because Sigmoid does not normalize). Optional Gumbel + perturbation during training: ``Gumbel-Sigmoid(x) = sigmoid(x + G' - G'')`` + with independent Gumbel(0, 1) samples ``G', G''`` (paper Eq. 5). + + Selection: each expert independently picks its top-``k_e`` tokens by + score (``torch.topk`` on the (expert, token) score matrix). This is + the inverse of token-choice routing: there a token picks experts; + here an expert picks tokens. A token can be picked by 0, 1, or more + experts (the residual stream carries the unmodified token through + when no expert picks it). + + ``capacity_factor`` controls ``k_e`` as ``k_e = ceil(c_e * N)`` where + ``N`` is the number of tokens of this modality in the current batch. + The paper's default ``c_e = 1/|E^M|`` gives ``k_e ≈ N/|E^M|`` so each + expert sees the average load (perfect balance under EC routing). 
+ """ + + def __init__( + self, + dim: int, + num_experts: int, + capacity_factor: float, + gumbel_noise: bool = True, + ) -> None: + super().__init__() + if num_experts <= 0: + raise ValueError( + f"ExpertChoiceSigmoidRouter: num_experts must be positive (got {num_experts})" + ) + if capacity_factor <= 0: + raise ValueError( + "ExpertChoiceSigmoidRouter: capacity_factor must be positive " + f"(got {capacity_factor})" + ) + self.gate = nn.Linear(dim, num_experts, bias=False) + self.num_experts = num_experts + self.capacity_factor = capacity_factor + self.gumbel_noise = gumbel_noise + # Tracked for metrics / debugging (analogous to MoEMLP.expert_counts). + self.expert_counts: torch.Tensor = torch.zeros(num_experts) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Route tokens to experts via expert-choice. + + Args: + x: ``(N, D)`` token representations for one modality group. + + Returns: + ``topk_scores``: ``(E, k_e)`` sigmoid scores of the tokens each + expert selected. + ``topk_indices``: ``(E, k_e)`` token indices into ``x`` that + each expert selected. ``k_e`` is computed from + ``capacity_factor * N``, capped by ``N``. + """ + if x.dim() != 2: + raise ValueError( + f"ExpertChoiceSigmoidRouter.forward expects (N, D); got shape {tuple(x.shape)}" + ) + n_tokens, _ = x.shape + if n_tokens == 0: + # Empty modality slice — return empty selections; caller handles. + empty_scores = x.new_zeros(self.num_experts, 0) + empty_indices = torch.zeros(self.num_experts, 0, dtype=torch.long, device=x.device) + return empty_scores, empty_indices + + logits = self.gate(x) # (N, E) + if self.training and self.gumbel_noise: + logits = logits + _gumbel_like(logits) - _gumbel_like(logits) + scores = torch.sigmoid(logits) # (N, E), independent per expert + + k_e = max(1, math.ceil(self.capacity_factor * n_tokens)) + k_e = min(k_e, n_tokens) + + # scores.t(): (E, N). For each expert (row), select the top-k_e tokens. 
+ topk_scores, topk_indices = torch.topk(scores.t(), k=k_e, dim=1) + + # Per-expert utilization metric: how many tokens this expert handled. + # Always k_e (EC routing guarantees this), but recording for parity + # with the MoEMLP API. + with torch.no_grad(): + counts = torch.full( + (self.num_experts,), float(k_e), device=x.device, dtype=torch.float32 + ) + self.expert_counts = counts.detach() + + return topk_scores, topk_indices + + +class ExpertChoiceMoE(nn.Module): + """Expert-Choice MoE for one modality group. + + Composes an ``ExpertChoiceSigmoidRouter`` with ``num_experts`` SwiGLU + expert MLPs. Forward: each expert selects top-``k_e`` tokens, runs its + MLP on those tokens, and contributes ``sigmoid_score * MLP(x)`` to the + output. Tokens not picked by any expert receive zero contribution from + this MoE block (the outer residual skip preserves them). + + State-dict layout (FQN-stable): + + .. code:: + + router.gate.weight # (num_experts, dim) — gate Linear + experts.0.gate_proj.weight + experts.0.up_proj.weight + experts.0.down_proj.weight + experts.1... + ... 
+ """ + + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + capacity_factor: float, + activation: str = "silu", + gumbel_noise: bool = True, + ) -> None: + super().__init__() + if num_experts <= 0: + raise ValueError(f"ExpertChoiceMoE: num_experts must be positive (got {num_experts})") + self.dim = dim + self.num_experts = num_experts + self.router = ExpertChoiceSigmoidRouter( + dim=dim, + num_experts=num_experts, + capacity_factor=capacity_factor, + gumbel_noise=gumbel_noise, + ) + self.experts = nn.ModuleList( + [ + build_mlp(dim=dim, hidden_dim=hidden_dim, activation=activation) + for _ in range(num_experts) + ] + ) + + @property + def expert_counts(self) -> torch.Tensor: + """Per-expert token count from the most recent forward (metrics).""" + return self.router.expert_counts + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Expert-choice MoE forward over one modality group. + + Args: + x: ``(N, D)`` token representations. + + Returns: + ``(N, D)`` output where each token has accumulated weighted + outputs from every expert that selected it (zero contribution + from this block when no expert selected the token). + """ + if x.dim() != 2: + raise ValueError(f"ExpertChoiceMoE.forward expects (N, D); got shape {tuple(x.shape)}") + n_tokens, _ = x.shape + if n_tokens == 0: + return x # Pass through empty modality slice. + + topk_scores, topk_indices = self.router(x) # (E, k_e) each + + out = torch.zeros_like(x) + # Sequential per-expert dispatch (the codebase's MoEMLP fallback + # uses the same sequential loop pattern). Grouped-GEMM EC dispatch + # is a future optimization once the operator is stable. 
+ for e in range(self.num_experts): + token_idx = topk_indices[e] # (k_e,) + token_scores = topk_scores[e] # (k_e,) + x_e = x.index_select(0, token_idx) # (k_e, D) + out_e = self.experts[e](x_e) # (k_e, D) + weighted = token_scores.unsqueeze(-1) * out_e # (k_e, D) + # index_add (non-in-place) accumulates contributions from + # multiple experts that picked the same token. Autograd-safe + # for non-unique indices. + out = out.index_add(0, token_idx, weighted) + return out + + +class MoMaFFN(nn.Module): + """Per-modality MoE FFN groups dispatched by ``modality_ids``. + + Holds one ``ExpertChoiceMoE`` per modality (keys: modality name). + Forward dispatches tokens by ``modality_ids`` (level-1 deterministic + routing), runs each modality's EC-MoE (level-2 learned routing), + then scatters per-modality outputs back to their original positions. + + Modality index convention: ``self.modalities[i]`` corresponds to + ``modality_ids == i``. With the default ``("image", "text")``, + ``modality_ids == 0`` selects the image expert group and + ``modality_ids == 1`` selects the text expert group. 
+ """ + + def __init__( + self, + config: ModelConfig, + modalities: tuple[str, ...], + experts_per_modality: dict[str, int], + capacity_factor_per_modality: dict[str, float], + gumbel_noise: bool = True, + ) -> None: + super().__init__() + if not modalities: + raise ValueError("MoMaFFN requires at least one modality") + missing_experts = set(modalities) - set(experts_per_modality.keys()) + if missing_experts: + raise ValueError( + f"MoMaFFN: experts_per_modality missing entries for {sorted(missing_experts)}" + ) + missing_cap = set(modalities) - set(capacity_factor_per_modality.keys()) + if missing_cap: + raise ValueError( + f"MoMaFFN: capacity_factor_per_modality missing entries for {sorted(missing_cap)}" + ) + self.modalities = tuple(modalities) + self.experts = nn.ModuleDict( + { + m: ExpertChoiceMoE( + dim=config.dim, + hidden_dim=config.computed_ffn_hidden_dim, + num_experts=experts_per_modality[m], + capacity_factor=capacity_factor_per_modality[m], + activation=config.activation, + gumbel_noise=gumbel_noise, + ) + for m in self.modalities + } + ) + + def forward(self, x: torch.Tensor, modality_ids: torch.Tensor) -> torch.Tensor: + """Dispatch tokens by modality and run per-modality EC-MoE. + + Args: + x: ``(B, S, D)`` residual stream. + modality_ids: ``(B, S)`` long tensor. ``modality_ids == i`` + routes that token to ``self.modalities[i]``'s expert + group. + + Returns: + ``(B, S, D)`` tensor with each modality's positions filled + by its EC-MoE output. Positions that no expert in their + modality group selects get zeros (the outer residual skip + preserves them).
+ """ + if x.dim() != 3: + raise ValueError(f"MoMaFFN.forward expects (B, S, D); got shape {tuple(x.shape)}") + if modality_ids.dim() != 2 or modality_ids.shape != x.shape[:2]: + raise ValueError( + f"MoMaFFN.forward: modality_ids shape {tuple(modality_ids.shape)} does not " + f"match (B, S) = {tuple(x.shape[:2])}" + ) + if modality_ids.dtype != torch.long: + raise ValueError( + f"MoMaFFN.forward: modality_ids dtype must be torch.long (got {modality_ids.dtype})" + ) + + b, s, d = x.shape + x_flat = x.reshape(b * s, d) + mod_flat = modality_ids.reshape(b * s) + out = torch.zeros_like(x_flat) + + for i, m in enumerate(self.modalities): + # nonzero() avoids the boolean-mask copy and gives us a 1-D index + # tensor we can feed to index_select + scatter. + idx = (mod_flat == i).nonzero(as_tuple=False).squeeze(-1) # (N_m,) + if idx.numel() == 0: + continue + x_m = x_flat.index_select(0, idx) # (N_m, D) + y_m = self.experts[m](x_m) # (N_m, D) + # The modality groups partition the position space, so indices + # are guaranteed unique across iterations. index_copy on + # disjoint indices is safe and autograd-friendly. + out = out.index_copy(0, idx, y_m) + return out.view(b, s, d) + + +class MoMaBlock(nn.Module): + """Pre-norm transformer block: shared ``Attention`` + ``MoMaFFN``. + + Operates on a single residual tensor ``(B, S, D)`` like the dense + ``TransformerBlock`` (unlike ``MoTBlock`` which operates on a + per-modality dict of streams). The only structural difference from + ``TransformerBlock`` is the FFN: ``MoMaFFN`` instead of a dense MLP + or a flat MoE. + + State-dict layout: + + .. 
code:: + + attention_norm.weight + attention.q_proj.weight + attention.k_proj.weight + attention.v_proj.weight + attention.o_proj.weight + # qk_norm only: + attention.q_norm.weight + attention.k_norm.weight + mlp_norm.weight + mlp.experts.{m}.router.gate.weight + mlp.experts.{m}.experts.0.gate_proj.weight + mlp.experts.{m}.experts.0.up_proj.weight + mlp.experts.{m}.experts.0.down_proj.weight + mlp.experts.{m}.experts.1... + ... + """ + + def __init__( + self, + config: ModelConfig, + modalities: tuple[str, ...], + experts_per_modality: dict[str, int], + capacity_factor_per_modality: dict[str, float], + gumbel_noise: bool, + layer_idx: int, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.modalities = tuple(modalities) + + self.attention_norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps) + self.attention = Attention( + dim=config.dim, + n_heads=config.n_heads, + n_kv_heads=config.n_kv_heads, # type: ignore[reportArgumentType] + head_dim=config.head_dim, + qk_norm=config.qk_norm, + sdpa_backend=config.sdpa_backend, + ) + self.mlp_norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps) + self.mlp = MoMaFFN( + config, + modalities=self.modalities, + experts_per_modality=experts_per_modality, + capacity_factor_per_modality=capacity_factor_per_modality, + gumbel_noise=gumbel_noise, + ) + + def forward( + self, + x: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + modality_ids: torch.Tensor, + *, + doc_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + # Pre-norm attention with residual (shared QKVO, single SDPA). + # kv_cache is intentionally omitted: EC routing is non-causal in v1. + x = x + self.attention(self.attention_norm(x), rope_cos, rope_sin, doc_ids=doc_ids) + # Pre-norm MoMa FFN with residual (per-modality EC-MoE groups). 
+ x = x + self.mlp(self.mlp_norm(x), modality_ids=modality_ids) + return x diff --git a/kempnerforge/model/transformer.py b/kempnerforge/model/transformer.py index 818568f..ffe87b8 100644 --- a/kempnerforge/model/transformer.py +++ b/kempnerforge/model/transformer.py @@ -16,7 +16,7 @@ from kempnerforge.config.registry import registry from kempnerforge.config.schema import ModelConfig -from kempnerforge.config.vlm import CrossAttentionConfig, MoTConfig, VLMConfig +from kempnerforge.config.vlm import CrossAttentionConfig, MoMaConfig, MoTConfig, VLMConfig from kempnerforge.model.attention import Attention, KVCache from kempnerforge.model.cross_attention import CrossAttentionBlock from kempnerforge.model.embedding import OutputHead, TokenEmbedding @@ -24,6 +24,7 @@ from kempnerforge.model.mlp import build_mlp from kempnerforge.model.modality import ModalityContext from kempnerforge.model.moe import MoEMLP, build_moe +from kempnerforge.model.moma import MoMaBlock from kempnerforge.model.mot import MoTBlock from kempnerforge.model.norm import build_norm from kempnerforge.model.position import precompute_rope_frequencies @@ -116,11 +117,16 @@ def __init__( # MoT branch: build MoTBlocks instead of TransformerBlocks. v1 # enforces equal head counts across modalities (single global # SDPA over the concatenated multi-modality sequence). + # MoMa branch: build MoMaBlocks (shared Q/K/V/O attention + + # per-modality MoE FFN groups). The branches are mutually + # exclusive on layer construction; CA layers are still attached + # separately below for ``CrossAttentionConfig``. # num_image_tokens flows in from the vision encoder via the VLM # build path; it is unused for non-MoT arches but kept as a single # constructor arg so the signature is uniform across arches. self._mot_modalities: tuple[str, ...] = () self._mot_n_image: int = 0 + self._moma_modalities: tuple[str, ...] 
= () if isinstance(vlm_config, MoTConfig): text_n_kv_heads = config.n_kv_heads if config.n_kv_heads is not None else config.n_heads img_n_heads, img_n_kv_heads = vlm_config.resolved_image_heads( @@ -140,6 +146,25 @@ def __init__( for i in range(config.n_layers) } ) + elif isinstance(vlm_config, MoMaConfig): + self._moma_modalities = vlm_config.moma_modalities + experts_per_modality = dict(vlm_config.moma_experts_per_modality) + capacity_factor_per_modality = { + m: vlm_config.effective_capacity_factor(m) for m in self._moma_modalities + } + self.layers = nn.ModuleDict( + { + str(i): MoMaBlock( + config, + modalities=self._moma_modalities, + experts_per_modality=experts_per_modality, + capacity_factor_per_modality=capacity_factor_per_modality, + gumbel_noise=vlm_config.moma_gumbel_noise, + layer_idx=i, + ) + for i in range(config.n_layers) + } + ) else: # Transformer blocks — use ModuleDict to preserve FQNs for DCP self.layers = nn.ModuleDict( @@ -351,6 +376,27 @@ def forward( cos = self._rope_cos[start_pos : start_pos + seq_len] # type: ignore[reportOptionalSubscript] sin = self._rope_sin[start_pos : start_pos + seq_len] # type: ignore[reportOptionalSubscript] + # MoMa path: single residual stream + shared SDPA + per-modality + # MoE FFN groups. modality_ids tags every position and the + # ``MoMaFFN`` uses these tags to dispatch tokens to per-modality + # expert groups (level-1 deterministic routing); within each + # group, expert-choice + Sigmoid routing picks experts + # (level-2 learned routing). EC routing is non-causal, so + # ``kv_caches`` is rejected upstream (training-only in v1). + if self._moma_modalities: + if modality_ids is None: + raise ValueError( + "MoMa model requires modality.modality_ids (got None). Build the " + "ModalityContext via MoMaStrategy or set modality_ids explicitly." 
+ ) + if modality_ids.shape != h.shape[:2]: + raise ValueError( + f"modality.modality_ids shape {tuple(modality_ids.shape)} does not " + f"match residual shape {tuple(h.shape[:2])}" + ) + for layer in self.layers.values(): + h = layer(h, cos, sin, modality_ids, doc_ids=doc_ids) + h = self.norm(h) # MoT path: position-based image-then-text split, per-modality # streams through the MoTBlock stack, single global SDPA per # layer. modality_ids is required (presence + shape checked @@ -358,7 +404,7 @@ def forward( # tags are validated for shape but not value-matched against # positions, so a future per-token scatter/gather can land # without changing the public interface. - if self._mot_modalities: + elif self._mot_modalities: if modality_ids is None: raise ValueError( "MoT model requires modality.modality_ids (got None). Build the " diff --git a/kempnerforge/model/vlm.py b/kempnerforge/model/vlm.py index 892d19a..51ac879 100644 --- a/kempnerforge/model/vlm.py +++ b/kempnerforge/model/vlm.py @@ -23,6 +23,9 @@ Joint-Decoder (image-then-text concat, ``output_slice`` trims image positions before the head), plus a per-position ``modality_ids`` tag that the ``MoTBlock`` stack consumes for routing. +- ``"moma"`` — Mixture of Modality-Aware Experts. Same residual layout + and ``modality_ids`` tagging as MoT; per-layer block has shared + Q/K/V/O attention but per-modality MoE FFN groups. 
``inner_transformer(model)`` is the explicit unwrap helper used by the training loop when it needs to reach Transformer-internal state @@ -178,6 +181,48 @@ def num_image_tokens(self, wrapper: VLMWrapper) -> int: return wrapper.vision_encoder.num_tokens +@registry.register_modality_strategy("moma") +class MoMaStrategy: + """Mixture of Modality-Aware Experts: same residual-stream layout as + Joint-Decoder/MoT (image embeds prepended, ``output_slice`` trims them + before the LM head), plus a per-position ``modality_ids`` tag the + MoMa FFN stack consumes for true scatter/gather dispatch (level-1 + deterministic routing by modality). + + Forward path: ``feats = vision_encoder(pixel_values)``; + ``img_embeds = adapter(feats)``; + ``ModalityContext(prefix_embeds, output_slice, modality_ids)``. + + Convention: ``modality_ids == 0`` for image positions and + ``modality_ids == 1`` for text positions, matching the index order + of ``MoMaConfig.moma_modalities = ("image", "text")``. The MoMa + FFN uses these tags to dispatch tokens to per-modality expert + groups; positions are *not* assumed to be in any particular order, + so interleaved layouts work too (image-prefix is just one + instantiation). + """ + + def prepare( + self, + wrapper: VLMWrapper, + pixel_values: torch.Tensor, + input_ids: torch.Tensor, + ) -> ModalityContext: + img_embeds = _project_image_features(wrapper, pixel_values) + n = wrapper.vision_encoder.num_tokens + b, t_text = input_ids.shape + modality_ids = torch.zeros(b, n + t_text, dtype=torch.long, device=input_ids.device) + modality_ids[:, n:] = 1 + return ModalityContext( + prefix_embeds=img_embeds, + output_slice=slice(n, None), + modality_ids=modality_ids, + ) + + def num_image_tokens(self, wrapper: VLMWrapper) -> int: + return wrapper.vision_encoder.num_tokens + + def build_modality_strategy(vlm: VLMConfig) -> ModalityStrategy: """Resolve ``vlm.arch`` to its registered ``ModalityStrategy``. 
diff --git a/tests/distributed/test_vlm_moma_fsdp.py b/tests/distributed/test_vlm_moma_fsdp.py new file mode 100644 index 0000000..7799195 --- /dev/null +++ b/tests/distributed/test_vlm_moma_fsdp.py @@ -0,0 +1,344 @@ +"""Distributed tests for the VLM Mixture of Modality-Aware Experts (MoMa) FSDP2 wrap. + +Run with: + torchrun --nproc_per_node=2 -m pytest \\ + tests/distributed/test_vlm_moma_fsdp.py -v + +Mirrors ``tests/distributed/test_vlm_mot_fsdp.py`` for the MoMa arch: + +- Forward + backward on a 2-GPU sharded ``VLMWrapper`` with MoMa. +- Variable-length text + collator (rank consistency). +- ``test_fsdp_unfreeze_grad_flows_moma``: requires_grad mid-train flip + under FSDP2 actually re-enables gradient flow on the per-layer MoMa + stack (mandatory pre-merge test of the FreezeStage hook semantics). +- ``test_moma_two_runs_bitwise_equal_under_fsdp``: same-seed determinism + in eval mode with Gumbel noise off (the only deterministic regime for + EC + Sigmoid routing; train mode is stochastic by design). +- ``test_moma_runs_and_learns_under_fsdp``: VLM(MoMa) end-to-end on a + fixed mini-batch over a few steps; CE decreases. Proves that EC + routing + per-modality expert groups produce useful gradient signal. +- DCP checkpoint round-trip preserving MoMa-specific state-dict keys + (mlp.experts.{modality}.router.gate, mlp.experts.{modality}.experts.{i}) + + canonical vlm_freeze metadata. + +Inference-path tests are intentionally omitted from v1: EC routing is +non-causal in v1 (deferred auxiliary routers, paper §2.4). +``torch.compile`` is warned by JobConfig.validate and skipped here. +``set_moe_step`` / ``get_moe_aux_loss`` are silent no-ops for MoMa +layers (EC has no bias schedule and no aux loss), which is the +documented v1 behavior; we don't pin it as a separate test. 
+""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist + +from kempnerforge.checkpoint.manager import CheckpointManager +from kempnerforge.config.adapter import AdapterConfig +from kempnerforge.config.model import ModelConfig +from kempnerforge.config.schema import CheckpointConfig, OptimizerConfig +from kempnerforge.config.vision import VisionEncoderConfig +from kempnerforge.config.vlm import FreezeSpec, MoMaConfig +from kempnerforge.distributed.parallel import build_parallel_model +from kempnerforge.model.vlm import VLMWrapper +from kempnerforge.training.freeze import ( + apply_freeze_specs, + canonical_freeze_meta, + effective_freeze, +) +from kempnerforge.training.loss import cross_entropy_loss +from kempnerforge.training.optimizer import build_optimizer + +pytestmark = pytest.mark.skipif( + "RANK" not in os.environ, + reason="Requires torchrun launcher (RANK not set)", +) + + +def _tiny_moma_cfg( + *, + num_image_tokens: int = 8, + feature_dim: int = 96, + n_layers: int = 2, + freeze: list[FreezeSpec] | None = None, + experts_per_modality: dict[str, int] | None = None, + gumbel_noise: bool = False, +) -> tuple[ModelConfig, VisionEncoderConfig, AdapterConfig, MoMaConfig]: + return ( + ModelConfig( + dim=64, + n_layers=n_layers, + n_heads=4, + vocab_size=256, + max_seq_len=128, + ffn_hidden_dim=128, + ), + VisionEncoderConfig(type="random", feature_dim=feature_dim, num_tokens=num_image_tokens), + AdapterConfig(), + MoMaConfig( + max_text_len=32, + moma_experts_per_modality=( + experts_per_modality + if experts_per_modality is not None + else {"image": 2, "text": 2} + ), + moma_gumbel_noise=gumbel_noise, + freeze=freeze if freeze is not None else [FreezeSpec("vision_encoder", True)], + ), + ) + + +def _build( + cfg: tuple[ModelConfig, VisionEncoderConfig, AdapterConfig, MoMaConfig], + mesh, + *, + param_dtype: torch.dtype = torch.bfloat16, +) -> VLMWrapper: + mc, vc, ac, 
lc = cfg + model = build_parallel_model( + mc, + device=torch.device("cuda"), + device_mesh=mesh, + vision_config=vc, + adapter_config=ac, + vlm_config=lc, + param_dtype=param_dtype, + ) + real = model._orig_mod if hasattr(model, "_orig_mod") else model # type: ignore[attr-defined] + assert isinstance(real, VLMWrapper) + return model # type: ignore[return-value] + + +def _dummy_batch( + wrapper, batch: int = 2, text_len: int = 16, *, seed_offset: int = 0 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + rank = dist.get_rank() if dist.is_initialized() else 0 + pixel_gen = torch.Generator(device="cpu").manual_seed(2000 + rank + seed_offset) + pixels = torch.randn(batch, 3, 32, 32, generator=pixel_gen).cuda() + id_gen = torch.Generator(device="cpu").manual_seed(1000 + rank + seed_offset) + ids = torch.randint(0, 256, (batch, text_len), generator=id_gen).cuda() + labels = ids.clone() + return pixels, ids, labels + + +class TestBuildAndForward: + def test_build_2gpu_runs(self, distributed_env): + from kempnerforge.model.moma import MoMaBlock + + mesh = distributed_env + wrapper = _build(_tiny_moma_cfg(), mesh) + real = wrapper._orig_mod if hasattr(wrapper, "_orig_mod") else wrapper + # MoMa uses the JD/MoT image-prefix residual layout. + assert real.num_image_tokens == 8 + # Layers are MoMaBlocks. + assert all(isinstance(layer, MoMaBlock) for layer in real.transformer.layers.values()) + + pixels, ids, labels = _dummy_batch(wrapper) + logits, labels_out = wrapper(pixels, ids, labels) + assert logits.shape == (2, 16, 256) + assert labels_out is labels + + def test_fsdp_sharded_grads_flow(self, distributed_env): + mesh = distributed_env + wrapper = _build(_tiny_moma_cfg(), mesh) + # MoMa's depth-scaled init on o_proj / down_proj keeps weights small + # but nonzero, so gradients flow without an explicit re-init step + # (unlike MoT's identity-at-construction zero-init). 
+ pixels, ids, labels = _dummy_batch(wrapper) + logits, labels_out = wrapper(pixels, ids, labels) + loss = cross_entropy_loss(logits, labels_out) + loss.backward() + for name, p in wrapper.named_parameters(): + if p.requires_grad: + assert p.grad is not None, f"trainable {name} got no grad" + for name, p in wrapper.vision_encoder.named_parameters(): + assert p.grad is None, f"frozen encoder param {name} got a grad" + + +class TestVariableLengthRankConsistency: + def test_fixed_length_pad_keeps_ranks_in_sync(self, distributed_env): + """Two ranks feeding different logical text lengths but the same + max_text_len batch shape -> NCCL collectives stay well-formed.""" + mesh = distributed_env + wrapper = _build(_tiny_moma_cfg(), mesh) + rank = dist.get_rank() + logical_len = 5 if rank == 0 else 30 + max_text_len = 32 + pixels = torch.randn(2, 3, 32, 32, device="cuda") + ids = torch.zeros(2, max_text_len, dtype=torch.long, device="cuda") + labels = torch.full((2, max_text_len), -100, dtype=torch.long, device="cuda") + ids[:, :logical_len] = torch.arange(1, logical_len + 1, device="cuda").expand(2, -1) + labels[:, :logical_len] = ids[:, :logical_len] + + logits, labels_out = wrapper(pixels, ids, labels) + loss = cross_entropy_loss(logits, labels_out) + loss.backward() + t = torch.tensor([float(loss.item())], device="cuda") + dist.all_reduce(t) + assert torch.isfinite(t).all() + + +class TestFreezeStageUnderFsdp: + def test_fsdp_unfreeze_grad_flows_moma(self, distributed_env): + """Mid-training flip of requires_grad under FSDP2 must re-enable + gradient flow on the per-layer MoMa stack. Mandatory merge gate + for the FreezeStage hook semantics. + """ + mesh = distributed_env + cfg = _tiny_moma_cfg(freeze=[FreezeSpec("moma", True)]) # main stack frozen + wrapper = _build(cfg, mesh) + + # Step 0: confirm moma stack frozen -> no grads on transformer.layers.*. 
+ pixels, ids, labels = _dummy_batch(wrapper) + logits, labels_out = wrapper(pixels, ids, labels) + loss = cross_entropy_loss(logits, labels_out) + loss.backward() + for n, p in wrapper.named_parameters(): + if n.startswith("transformer.layers."): + assert p.grad is None, f"frozen {n} got a grad" + + # Unfreeze the main stack (simulating a FreezeStage transition). + apply_freeze_specs( + wrapper, + [FreezeSpec("moma", False)], + cfg[3].module_patterns, # type: ignore[union-attr] + ) + for p in wrapper.parameters(): + p.grad = None + pixels, ids, labels = _dummy_batch(wrapper, seed_offset=100) + logits, labels_out = wrapper(pixels, ids, labels) + loss = cross_entropy_loss(logits, labels_out) + loss.backward() + # Post-unfreeze: every trainable per-layer param has a grad + # allocated. (EC routing may leave individual experts unselected + # in a given batch — they'd get zero grads but still have a grad + # tensor allocated by FSDP2.) + for n, p in wrapper.named_parameters(): + if n.startswith("transformer.layers.") and p.requires_grad: + assert p.grad is not None, ( + f"FSDP2 did not re-attach grad-allocation hooks on requires_grad flip; " + f"FreezeStage transitions need a fully_shard rebuild on this PyTorch version. " + f"Param: {n}" + ) + + +class TestDeterminism: + def test_moma_two_runs_bitwise_equal_under_fsdp(self, distributed_env): + """Same seed, two builds, same eval-mode forward. + + EC routing with Gumbel noise is stochastic by design in train + mode; this test pins determinism only in the regime that + actually has it: eval mode + ``moma_gumbel_noise=False``. Catches + CUDA-stream races or non-deterministic dispatch in the + modality-aware scatter/gather under FSDP2 prefetch. 
+ """ + mesh = distributed_env + torch.manual_seed(0) + wrapper_a = _build(_tiny_moma_cfg(gumbel_noise=False), mesh, param_dtype=torch.float32) + wrapper_a.eval() + pixels, ids, _ = _dummy_batch(wrapper_a, batch=1, text_len=8) + with torch.no_grad(): + logits_a, _ = wrapper_a(pixels, ids) + loss_a = float(logits_a.sum().item()) + + torch.manual_seed(0) + wrapper_b = _build(_tiny_moma_cfg(gumbel_noise=False), mesh, param_dtype=torch.float32) + wrapper_b.eval() + with torch.no_grad(): + logits_b, _ = wrapper_b(pixels, ids) + loss_b = float(logits_b.sum().item()) + + assert loss_a == loss_b, f"non-deterministic: {loss_a} vs {loss_b}" + + +class TestEcRoutingTraining: + def test_moma_runs_and_learns_under_fsdp(self, distributed_env): + """VLM(MoMa) runs end-to-end under FSDP2 at 2 GPU. CE decreases on + a fixed mini-batch over a handful of steps, proving that EC + + Sigmoid routing produces useful gradient signal through the + modality-aware scatter/gather. + + Gumbel noise is off so the EC selection is deterministic for a + given (model, batch) and the loss trajectory is comparable. With + Gumbel on, learning still works but step-to-step variance can + mask a small CE decrease in such a short run. + """ + torch.manual_seed(42 + dist.get_rank()) + mc, vc, ac, lc = _tiny_moma_cfg(n_layers=4, num_image_tokens=4, gumbel_noise=False) + # Bump dim/ffn so the FFN is meaningful. 
+ mc.dim = 128 + mc.n_heads = 4 + mc.n_kv_heads = 4 + mc.ffn_hidden_dim = 256 + wrapper = _build((mc, vc, ac, lc), distributed_env, param_dtype=torch.float32) + + opt = build_optimizer(wrapper, OptimizerConfig(lr=3e-3, fused=False)) + loss_fn = torch.nn.CrossEntropyLoss() + + pixels = torch.randn(2, 3, 32, 32, device="cuda") + ids = torch.randint(0, mc.vocab_size, (2, 16), device="cuda") + + losses = [] + for _ in range(8): + logits, _ = wrapper(pixels, ids, ids) + ce = loss_fn(logits.reshape(-1, mc.vocab_size), ids.reshape(-1)) + ce.backward() + opt.step() + opt.zero_grad() + losses.append(ce.item()) + + initial_ce, final_ce = losses[0], losses[-1] + assert final_ce < initial_ce, ( + f"FSDP+MoMa: CE did not decrease ({initial_ce:.3f} -> {final_ce:.3f}); " + f"trajectory={losses}" + ) + + +class TestCheckpointRoundtrip: + def test_save_load_freeze_metadata_moma(self, distributed_env, shared_tmp_dir): + """Save a MoMa VLM checkpoint, load it in a fresh manager, and + verify metadata.json carries the canonical vlm_freeze and DCP + shards round-trip the per-modality MoE state-dict keys. + """ + mesh = distributed_env + cfg = _tiny_moma_cfg() + wrapper = _build(cfg, mesh) + opt = build_optimizer(wrapper, OptimizerConfig(lr=1e-3, fused=False)) + + # shared_tmp_dir lives on the shared filesystem so DCP shards + # written by rank 0 are visible to rank 1 under multi-node srun. 
+ path_str = shared_tmp_dir + rank = dist.get_rank() + + ckpt_cfg = CheckpointConfig(dir=str(path_str), interval=1) + mgr = CheckpointManager(ckpt_cfg, wrapper, opt) + freeze = canonical_freeze_meta( + effective_freeze(0, cfg[3].freeze, cfg[3].freeze_schedule) # type: ignore[union-attr] + ) + mgr.save(step=1, extra={"vlm_freeze": freeze}) + dist.barrier() + + if rank == 0: + import json + + meta = json.loads((Path(path_str) / "step_1" / "metadata.json").read_text()) + assert meta["vlm_freeze"] == freeze + + wrapper2 = _build(cfg, mesh) + opt2 = build_optimizer(wrapper2, OptimizerConfig(lr=1e-3, fused=False)) + mgr2 = CheckpointManager(ckpt_cfg, wrapper2, opt2) + step, _, _ = mgr2.load(path=str(path_str) + "/step_1", vlm_freeze_expected=freeze) + assert step == 1 + # Per-modality MoE keys survived round-trip: gate (router) + experts. + per_modality_keys = [ + n + for n, _ in wrapper2.named_parameters() + if "mlp.experts." in n + and ("router.gate" in n or "experts." in n.split("mlp.experts.", 1)[1]) + ] + assert len(per_modality_keys) > 0, "MoMa per-modality MoE params missing after DCP load" diff --git a/tests/integration/test_vlm_moma.py b/tests/integration/test_vlm_moma.py new file mode 100644 index 0000000..7ad3440 --- /dev/null +++ b/tests/integration/test_vlm_moma.py @@ -0,0 +1,318 @@ +"""Integration tests for the VLM Mixture of Modality-Aware Experts (MoMa) training path. + +Single-GPU forward + backward on a tiny synthetic MoMa config. Exercises +``build_parallel_model`` MoMa branch, freeze targeting at the ``moma`` +alias, dtype propagation, save/load round-trip, per-modality dispatch +correctness, unbalanced expert counts (paper's moe_7t1i style), and the +``FreezeStage`` hook semantics on the per-layer MoMa stack. + +Mirrors ``tests/integration/test_vlm_mot.py`` for the MoMa arch with three +key differences: + +- No warm-start helper (deferred to v2 for MoMa).
+- No ``compile`` parity test (MoMa's modality_ids scatter/gather + EC + top-k produce graph breaks; ``JobConfig.validate`` warns rather than + rejects, and the compile path is not part of v1's contract). +- MoMa is intrinsically MoE — every layer has per-modality expert + groups, so there is no "MoMa + MoE smoke" cross-test; learning under + EC routing is covered by the FSDP test in tests/distributed. + +Runs on CUDA only; skipped when ``torch.cuda.is_available()`` is False. +""" + +from __future__ import annotations + +import pytest +import torch + +from kempnerforge.config.adapter import AdapterConfig +from kempnerforge.config.model import ModelConfig +from kempnerforge.config.vision import VisionEncoderConfig +from kempnerforge.config.vlm import ( + FreezeSpec, + FreezeStage, + MoMaConfig, +) +from kempnerforge.distributed.parallel import build_parallel_model +from kempnerforge.model.vlm import VLMWrapper +from kempnerforge.training.freeze import apply_freeze_specs, effective_freeze + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="VLM MoMa integration tests require CUDA", +) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def _tiny_moma_configs( + *, + num_image_tokens: int = 8, + feature_dim: int = 96, + n_layers: int = 2, + freeze: list[FreezeSpec] | None = None, + experts_per_modality: dict[str, int] | None = None, + capacity_factor: float = 0.0, + gumbel_noise: bool = False, +) -> tuple[ModelConfig, VisionEncoderConfig, AdapterConfig, MoMaConfig]: + return ( + ModelConfig( + dim=64, + n_layers=n_layers, + n_heads=4, + vocab_size=256, + max_seq_len=128, + ffn_hidden_dim=128, + ), + VisionEncoderConfig(type="random", feature_dim=feature_dim, num_tokens=num_image_tokens), + AdapterConfig(), + MoMaConfig( + max_text_len=32, + moma_experts_per_modality=( + experts_per_modality + if experts_per_modality is not None + else {"image": 2, "text": 2} + ), + moma_capacity_factor=capacity_factor, + 
moma_gumbel_noise=gumbel_noise, + freeze=freeze if freeze is not None else [FreezeSpec("vision_encoder", True)], + ), + ) + + +def _build( + configs: tuple[ModelConfig, VisionEncoderConfig, AdapterConfig, MoMaConfig], + *, + param_dtype: torch.dtype = torch.bfloat16, +) -> VLMWrapper: + mc, vc, ac, lc = configs + model = build_parallel_model( + mc, + device=DEVICE, + device_mesh=None, + vision_config=vc, + adapter_config=ac, + vlm_config=lc, + param_dtype=param_dtype, + ) + assert isinstance(model, VLMWrapper) + return model + + +def _dummy_batch( + wrapper: VLMWrapper, batch: int = 2, text_len: int = 16 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + pixels = torch.randn(batch, 3, 32, 32, device=DEVICE) + input_ids = torch.randint( + 0, wrapper.transformer.config.vocab_size, (batch, text_len), device=DEVICE + ) + labels = input_ids.clone() + return pixels, input_ids, labels + + +class TestBuildAndForward: + def test_build_and_forward_1gpu(self): + """Tiny MoMa config builds on a single GPU; forward + backward run.""" + from kempnerforge.model.moma import MoMaBlock + + cfg = _tiny_moma_configs() + wrapper = _build(cfg) + assert isinstance(wrapper, VLMWrapper) + # MoMa uses the JD/MoT image-prefix residual layout. + assert wrapper.num_image_tokens == 8 + # All layers are MoMaBlocks. + assert all(isinstance(layer, MoMaBlock) for layer in wrapper.transformer.layers.values()) + + pixels, input_ids, labels = _dummy_batch(wrapper) + logits, _ = wrapper(pixels, input_ids, labels) + # output_slice trims the 8 image positions; logits cover text_len positions. + assert logits.shape == (2, 16, cfg[0].vocab_size) + # Backward — depth-scaled init on o_proj / down_proj leaves nonzero + # values (unlike MoT's identity-at-construction), so gradients flow + # without per-modality re-init. 
+ loss = logits.float().sum() + loss.backward() + adapter_grads = [ + p.grad + for n, p in wrapper.named_parameters() + if n.startswith("adapter") and p.requires_grad + ] + layer_grads = [ + p.grad + for n, p in wrapper.named_parameters() + if n.startswith("transformer.layers.") and p.requires_grad + ] + assert all(g is not None for g in adapter_grads) + assert all(g is not None for g in layer_grads) + + def test_freeze_targets_moma(self): + """``FreezeSpec("moma")`` freezes only the per-layer MoMa stack + (transformer.layers.*) and leaves the adapter, final norm, embedding, + and output head trainable. + """ + cfg = _tiny_moma_configs(freeze=[FreezeSpec("moma", True)]) + wrapper = _build(cfg) + trainable = {name for name, p in wrapper.named_parameters() if p.requires_grad} + frozen = {name for name, p in wrapper.named_parameters() if not p.requires_grad} + # All transformer.layers params are frozen. + for n in frozen: + if n.startswith("transformer."): + assert n.startswith("transformer.layers."), n + # transformer.norm (final norm) is trainable for MoMa (unlike MoT, which + # replaces it with mot_norms and freezes self.norm). + assert "transformer.norm.weight" in trainable + # Adapter is trainable. + assert any(n.startswith("adapter") for n in trainable) + # Sanity: at least one MoE gate is frozen. + assert any("mlp.experts" in n and "router.gate" in n for n in frozen) + + def test_image_features_dtype_propagation(self): + """Encoder fp32 -> adapter bf16 -> Transformer.forward casts + prefix_embeds to the residual-stream dtype before the MoMa block + sees it. Asserts (a) build with bf16 param_dtype works, (b) + forward output is bf16, (c) no dtype-mismatch errors. 
+ """ + cfg = _tiny_moma_configs() + wrapper = _build(cfg, param_dtype=torch.bfloat16) + assert wrapper.adapter.proj1.weight.dtype == torch.bfloat16 + pixels, input_ids, _ = _dummy_batch(wrapper) + logits, _ = wrapper(pixels, input_ids) + assert logits.dtype == torch.bfloat16 + + def test_save_load_forward_parity(self): + """state_dict round-trips with bit-equal forward output. + + Uses fp32 + Gumbel-off + eval mode to make the forward strictly + deterministic so a bit-exact comparison is meaningful. + """ + cfg = _tiny_moma_configs(gumbel_noise=False) + wrapper_a = _build(cfg, param_dtype=torch.float32) + wrapper_a.eval() + pixels, input_ids, _ = _dummy_batch(wrapper_a, batch=1, text_len=8) + with torch.no_grad(): + logits_a, _ = wrapper_a(pixels, input_ids) + + state = wrapper_a.state_dict() + wrapper_b = _build(cfg, param_dtype=torch.float32) + wrapper_b.load_state_dict(state, strict=True) + wrapper_b.eval() + with torch.no_grad(): + logits_b, _ = wrapper_b(pixels, input_ids) + + torch.testing.assert_close(logits_a, logits_b, atol=0.0, rtol=0.0) + + +class TestModalityDispatch: + """MoMa-specific: verify that modality_ids correctly partitions tokens + to per-modality expert groups under the parallel-built model.""" + + def test_text_experts_only_affect_text_positions(self): + """Zero text experts' down_proj weights; image-position outputs are + unchanged, text-position outputs change.""" + cfg = _tiny_moma_configs(gumbel_noise=False) + wrapper = _build(cfg, param_dtype=torch.float32) + wrapper.eval() + pixels, input_ids, _ = _dummy_batch(wrapper, batch=1, text_len=8) + with torch.no_grad(): + logits_full, _ = wrapper(pixels, input_ids) + + # Zero every text expert's down_proj across all layers. 
+ with torch.no_grad(): + for layer in wrapper.transformer.layers.values(): + for expert in layer.mlp.experts["text"].experts: + expert.down_proj.weight.zero_() + with torch.no_grad(): + logits_text_off, _ = wrapper(pixels, input_ids) + + # All logits are over the text-only tail (output_slice trims image + # prefix), so the entire output is a text view — different text + # experts -> different output. + assert not torch.allclose(logits_full, logits_text_off) + + def test_image_experts_only_affect_image_positions(self): + """Symmetric to the above: zero image-expert down_proj only changes + outputs whose computation read image positions. Because the shared + attention mixes image into every text query, *all* text-position + outputs are sensitive to image experts (image experts shape image + tokens, which feed into the keys/values that text attends to). + Still, the output must change.""" + cfg = _tiny_moma_configs(gumbel_noise=False) + wrapper = _build(cfg, param_dtype=torch.float32) + wrapper.eval() + pixels, input_ids, _ = _dummy_batch(wrapper, batch=1, text_len=8) + with torch.no_grad(): + logits_full, _ = wrapper(pixels, input_ids) + + with torch.no_grad(): + for layer in wrapper.transformer.layers.values(): + for expert in layer.mlp.experts["image"].experts: + expert.down_proj.weight.zero_() + with torch.no_grad(): + logits_image_off, _ = wrapper(pixels, input_ids) + + assert not torch.allclose(logits_full, logits_image_off) + + def test_unbalanced_expert_counts_build_and_forward(self): + """Paper's moe_7t1i (7 text experts + 1 image expert) builds and + forwards. 
Validates the per-modality expert dict supports asymmetric + allocations (which MoT's single num_experts field cannot express).""" + from kempnerforge.model.moma import MoMaBlock + + cfg = _tiny_moma_configs(experts_per_modality={"image": 1, "text": 7}) + wrapper = _build(cfg) + layer0 = wrapper.transformer.layers["0"] + assert isinstance(layer0, MoMaBlock) + assert layer0.mlp.experts["image"].num_experts == 1 + assert layer0.mlp.experts["text"].num_experts == 7 + + pixels, input_ids, labels = _dummy_batch(wrapper) + logits, _ = wrapper(pixels, input_ids, labels) + assert logits.shape == (2, 16, cfg[0].vocab_size) + assert torch.isfinite(logits).all() + + def test_capacity_factor_explicit_override(self): + """Explicit positive capacity factor overrides the paper default 1/|E^M|.""" + from kempnerforge.model.moma import MoMaBlock + + cfg = _tiny_moma_configs(capacity_factor=0.75) + wrapper = _build(cfg) + layer0 = wrapper.transformer.layers["0"] + assert isinstance(layer0, MoMaBlock) + assert layer0.mlp.experts["image"].router.capacity_factor == 0.75 + assert layer0.mlp.experts["text"].router.capacity_factor == 0.75 + + +class TestFreezeStageHook: + def test_freeze_schedule_transitions(self): + """Schedule that freezes 'moma' at step 3 and unfreezes at step 7.""" + cfg = _tiny_moma_configs( + freeze=[ + FreezeSpec("vision_encoder", True), + FreezeSpec("moma", False), + ], + ) + cfg[3].freeze_schedule = [ # type: ignore[union-attr] + FreezeStage(start_step=3, specs=(FreezeSpec("moma", True),)), + FreezeStage(start_step=7, specs=(FreezeSpec("moma", False),)), + ] + wrapper = _build(cfg, param_dtype=torch.float32) + + layer_params = [ + p for n, p in wrapper.named_parameters() if n.startswith("transformer.layers.") + ] + + # Step 0: layers trainable. + for p in layer_params: + assert p.requires_grad + + # Step 3: layers frozen. 
+ specs = effective_freeze(3, cfg[3].freeze, cfg[3].freeze_schedule) # type: ignore[union-attr] + apply_freeze_specs(wrapper, specs, cfg[3].module_patterns) # type: ignore[union-attr] + for p in layer_params: + assert not p.requires_grad + + # Step 7: layers trainable again. + specs = effective_freeze(7, cfg[3].freeze, cfg[3].freeze_schedule) # type: ignore[union-attr] + apply_freeze_specs(wrapper, specs, cfg[3].module_patterns) # type: ignore[union-attr] + for p in layer_params: + assert p.requires_grad diff --git a/tests/unit/test_moma.py b/tests/unit/test_moma.py new file mode 100644 index 0000000..a4b5738 --- /dev/null +++ b/tests/unit/test_moma.py @@ -0,0 +1,737 @@ +"""Unit tests for Mixture of Modality-Aware Experts (MoMa) operator + block + integration. + +Covers: + +- ``MoMaConfig``: field defaults, validation (modalities, expert counts, + capacity factor), polymorphic methods (``residual_stream_image_tokens``, + ``effective_capacity_factor``). +- ``MoMaStrategy``: ``ModalityContext`` construction (prefix_embeds, + output_slice, modality_ids), modality_ids value/shape/dtype, num_image_tokens. +- ``ExpertChoiceSigmoidRouter``: forward shapes, Gumbel-noise behavior + (train vs eval), empty input, k_e clamping, expert_counts metric. +- ``ExpertChoiceMoE``: forward shape, empty input, gradient flow, + multi-expert accumulation onto the same token, zero contribution when + no expert picks a token. +- ``MoMaFFN``: per-modality dispatch by modality_ids, all-text / + all-image batches, gradient flow into both modality groups, shape / + dtype validation. +- ``MoMaBlock``: forward shape, residual add, gradient flow, parameter + layout (shared QKVO + per-modality MoE FFN). +- End-to-end ``Transformer`` + ``MoMaConfig``: build, forward, gradient + flow, output shape. + +No GPU required; uses CPU tensors. 
+""" + +from __future__ import annotations + +import math + +import pytest +import torch +import torch.nn as nn + +from kempnerforge.config.schema import ModelConfig +from kempnerforge.config.vlm import MoMaConfig +from kempnerforge.model.attention import Attention +from kempnerforge.model.modality import ModalityContext +from kempnerforge.model.moma import ( + ExpertChoiceMoE, + ExpertChoiceSigmoidRouter, + MoMaBlock, + MoMaFFN, +) +from kempnerforge.model.transformer import Transformer +from kempnerforge.model.vlm import MoMaStrategy + +DEVICE = torch.device("cpu") + + +def _config( + dim: int = 64, + n_heads: int = 4, + n_kv_heads: int | None = None, + n_layers: int = 2, + max_seq_len: int = 64, +) -> ModelConfig: + """Tiny dense config for MoMa unit tests.""" + return ModelConfig( + dim=dim, + n_layers=n_layers, + n_heads=n_heads, + n_kv_heads=n_kv_heads or n_heads, + vocab_size=128, + max_seq_len=max_seq_len, + ffn_hidden_dim=128, + ) + + +# --------------------------------------------------------------------------- +# MoMaConfig +# --------------------------------------------------------------------------- + + +class TestMoMaConfig: + def test_defaults(self): + cfg = MoMaConfig() + assert cfg.arch == "moma" + assert cfg.moma_modalities == ("image", "text") + assert cfg.moma_experts_per_modality == {"image": 4, "text": 4} + assert cfg.moma_capacity_factor == 0.0 + assert cfg.moma_gumbel_noise is True + + def test_module_patterns_includes_moma_alias(self): + cfg = MoMaConfig() + assert "moma" in cfg.module_patterns + # Sanity: alias points at the transformer layers. + assert any("transformer.layers" in p for p in cfg.module_patterns["moma"]) + + def test_residual_stream_image_tokens_is_num_tokens(self): + # MoMa uses the JD/MoT image-prefix layout. 
+ cfg = MoMaConfig() + assert cfg.residual_stream_image_tokens(64) == 64 + assert cfg.residual_stream_image_tokens(0) == 0 + + def test_effective_capacity_factor_paper_default(self): + cfg = MoMaConfig(moma_experts_per_modality={"image": 4, "text": 4}) + # Paper default c_e = 1/|E^M| per modality. + assert cfg.effective_capacity_factor("image") == pytest.approx(0.25) + assert cfg.effective_capacity_factor("text") == pytest.approx(0.25) + + def test_effective_capacity_factor_explicit_override(self): + cfg = MoMaConfig( + moma_experts_per_modality={"image": 4, "text": 4}, + moma_capacity_factor=0.5, + ) + assert cfg.effective_capacity_factor("image") == 0.5 + assert cfg.effective_capacity_factor("text") == 0.5 + + def test_effective_capacity_factor_unbalanced(self): + # Paper's moe_7t1i: 7 text experts + 1 image expert. + cfg = MoMaConfig(moma_experts_per_modality={"image": 1, "text": 7}) + assert cfg.effective_capacity_factor("image") == pytest.approx(1.0) + assert cfg.effective_capacity_factor("text") == pytest.approx(1.0 / 7) + + def test_rejects_fewer_than_two_modalities(self): + with pytest.raises(ValueError, match="at least 2 entries"): + MoMaConfig(moma_modalities=("text",)) + + def test_rejects_missing_text_modality(self): + with pytest.raises(ValueError, match="must include 'text'"): + MoMaConfig( + moma_modalities=("image", "audio"), + moma_experts_per_modality={"image": 2, "audio": 2}, + ) + + def test_rejects_missing_image_modality(self): + with pytest.raises(ValueError, match="must include 'image'"): + MoMaConfig( + moma_modalities=("text", "audio"), + moma_experts_per_modality={"text": 2, "audio": 2}, + ) + + def test_rejects_duplicate_modalities(self): + with pytest.raises(ValueError, match="must not contain duplicates"): + MoMaConfig( + moma_modalities=("image", "text", "image"), + moma_experts_per_modality={"image": 2, "text": 2}, + ) + + def test_rejects_missing_expert_count_entry(self): + with pytest.raises(ValueError, match="missing entries"): + 
MoMaConfig(moma_experts_per_modality={"image": 2}) + + def test_rejects_extra_expert_count_entry(self): + with pytest.raises(ValueError, match="unknown modality keys"): + MoMaConfig(moma_experts_per_modality={"image": 2, "text": 2, "audio": 4}) + + def test_rejects_nonpositive_expert_count(self): + with pytest.raises(ValueError, match="must be positive"): + MoMaConfig(moma_experts_per_modality={"image": 0, "text": 2}) + + def test_rejects_negative_capacity_factor(self): + with pytest.raises(ValueError, match="capacity_factor must be >= 0"): + MoMaConfig(moma_capacity_factor=-0.1) + + +# --------------------------------------------------------------------------- +# MoMaStrategy +# --------------------------------------------------------------------------- + + +class _StubVisionEncoder(nn.Module): + def __init__(self, num_tokens: int, feature_dim: int) -> None: + super().__init__() + self.num_tokens = num_tokens + self.feature_dim = feature_dim + self.proj = nn.Linear(3, feature_dim) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + # (B, ?, ?, ?) 
-> (B, num_tokens, feature_dim) + b = pixel_values.shape[0] + return torch.zeros(b, self.num_tokens, self.feature_dim, device=pixel_values.device) + + +class _StubAdapter(nn.Module): + def __init__(self, in_dim: int, out_dim: int) -> None: + super().__init__() + self.proj = nn.Linear(in_dim, out_dim) + + def forward(self, feats: torch.Tensor) -> torch.Tensor: + return self.proj(feats) + + +class _StubWrapper(nn.Module): + def __init__(self, num_tokens: int, feature_dim: int, dim: int) -> None: + super().__init__() + self.vision_encoder = _StubVisionEncoder(num_tokens, feature_dim) + self.adapter = _StubAdapter(feature_dim, dim) + + +class TestMoMaStrategy: + def test_prepare_builds_modality_context(self): + wrapper = _StubWrapper(num_tokens=4, feature_dim=8, dim=16) + strategy = MoMaStrategy() + pixel_values = torch.zeros(2, 3, 8, 8) + input_ids = torch.zeros(2, 6, dtype=torch.long) + + ctx = strategy.prepare(wrapper, pixel_values, input_ids) + assert isinstance(ctx, ModalityContext) + assert ctx.prefix_embeds is not None + assert ctx.prefix_embeds.shape == (2, 4, 16) + assert ctx.output_slice == slice(4, None) + assert ctx.modality_ids is not None + assert ctx.modality_ids.shape == (2, 4 + 6) + assert ctx.modality_ids.dtype == torch.long + + def test_modality_ids_image_then_text(self): + wrapper = _StubWrapper(num_tokens=3, feature_dim=8, dim=16) + strategy = MoMaStrategy() + pixel_values = torch.zeros(1, 3, 8, 8) + input_ids = torch.zeros(1, 5, dtype=torch.long) + + ctx = strategy.prepare(wrapper, pixel_values, input_ids) + # First 3 positions (image) get 0; rest (text) get 1. 
+ assert ctx.modality_ids is not None + ids = ctx.modality_ids[0] + assert torch.equal(ids[:3], torch.zeros(3, dtype=torch.long)) + assert torch.equal(ids[3:], torch.ones(5, dtype=torch.long)) + + def test_num_image_tokens(self): + wrapper = _StubWrapper(num_tokens=7, feature_dim=8, dim=16) + strategy = MoMaStrategy() + assert strategy.num_image_tokens(wrapper) == 7 + + +# --------------------------------------------------------------------------- +# ExpertChoiceSigmoidRouter +# --------------------------------------------------------------------------- + + +class TestExpertChoiceSigmoidRouter: + def test_construction_rejects_nonpositive_experts(self): + with pytest.raises(ValueError, match="num_experts must be positive"): + ExpertChoiceSigmoidRouter(dim=16, num_experts=0, capacity_factor=0.5) + + def test_construction_rejects_nonpositive_capacity(self): + with pytest.raises(ValueError, match="capacity_factor must be positive"): + ExpertChoiceSigmoidRouter(dim=16, num_experts=4, capacity_factor=0.0) + with pytest.raises(ValueError, match="capacity_factor must be positive"): + ExpertChoiceSigmoidRouter(dim=16, num_experts=4, capacity_factor=-0.1) + + def test_forward_shapes(self): + router = ExpertChoiceSigmoidRouter(dim=16, num_experts=4, capacity_factor=0.25) + router.eval() + x = torch.randn(20, 16) + scores, indices = router(x) + # k_e = ceil(0.25 * 20) = 5 + assert scores.shape == (4, 5) + assert indices.shape == (4, 5) + assert indices.dtype == torch.long + # All indices in range [0, N). 
+ assert indices.min().item() >= 0 + assert indices.max().item() < 20 + + def test_forward_rejects_wrong_rank(self): + router = ExpertChoiceSigmoidRouter(dim=16, num_experts=2, capacity_factor=0.5) + with pytest.raises(ValueError, match=r"\(N, D\)"): + router(torch.randn(2, 8, 16)) # (B, S, D) — rank 3, not 2 + + def test_forward_empty_input(self): + router = ExpertChoiceSigmoidRouter(dim=16, num_experts=4, capacity_factor=0.25) + x = torch.zeros(0, 16) + scores, indices = router(x) + assert scores.shape == (4, 0) + assert indices.shape == (4, 0) + + def test_k_e_clamped_to_n(self): + # capacity 1.0 → k_e = N. capacity > 1 (e.g. 5.0 * N) → clamp to N. + router = ExpertChoiceSigmoidRouter(dim=16, num_experts=2, capacity_factor=5.0) + router.eval() + x = torch.randn(3, 16) + scores, indices = router(x) + assert scores.shape == (2, 3) + assert indices.shape == (2, 3) + + def test_k_e_at_least_one(self): + # capacity * N < 1 → k_e should still be 1. + router = ExpertChoiceSigmoidRouter(dim=16, num_experts=2, capacity_factor=0.01) + router.eval() + x = torch.randn(4, 16) + scores, indices = router(x) + # ceil(0.01 * 4) = 1 + assert scores.shape == (2, 1) + assert indices.shape == (2, 1) + + def test_gumbel_noise_changes_scores_in_train_mode(self): + # We check scores, not selections: when gate logits are well-separated, + # Gumbel noise of typical magnitude (G' - G'' ~ Logistic(0, 1)) may not + # flip the top-k ordering even though the underlying scores differ. + # Probing scores directly is the robust signal that noise was applied. 
+ router = ExpertChoiceSigmoidRouter( + dim=16, num_experts=4, capacity_factor=0.5, gumbel_noise=True + ) + router.train() + x = torch.randn(8, 16) + torch.manual_seed(0) + scores_a, _ = router(x) + torch.manual_seed(1) + scores_b, _ = router(x) + assert not torch.allclose(scores_a, scores_b) + + def test_eval_mode_deterministic(self): + router = ExpertChoiceSigmoidRouter( + dim=16, num_experts=4, capacity_factor=0.5, gumbel_noise=True + ) + router.eval() + x = torch.randn(8, 16) + scores_a, idx_a = router(x) + scores_b, idx_b = router(x) + assert torch.equal(idx_a, idx_b) + assert torch.allclose(scores_a, scores_b) + + def test_gumbel_disabled_deterministic_in_train(self): + router = ExpertChoiceSigmoidRouter( + dim=16, num_experts=4, capacity_factor=0.5, gumbel_noise=False + ) + router.train() + x = torch.randn(8, 16) + scores_a, idx_a = router(x) + scores_b, idx_b = router(x) + assert torch.equal(idx_a, idx_b) + assert torch.allclose(scores_a, scores_b) + + def test_expert_counts_set(self): + router = ExpertChoiceSigmoidRouter(dim=16, num_experts=4, capacity_factor=0.5) + router.eval() + x = torch.randn(8, 16) + _ = router(x) + # Each expert picks k_e = ceil(0.5 * 8) = 4 tokens. 
+ assert router.expert_counts.shape == (4,) + assert (router.expert_counts == 4).all() + + +# --------------------------------------------------------------------------- +# ExpertChoiceMoE +# --------------------------------------------------------------------------- + + +class TestExpertChoiceMoE: + def test_forward_shape(self): + moe = ExpertChoiceMoE(dim=16, hidden_dim=32, num_experts=4, capacity_factor=0.5) + moe.eval() + x = torch.randn(10, 16) + y = moe(x) + assert y.shape == (10, 16) + + def test_forward_empty_input(self): + moe = ExpertChoiceMoE(dim=16, hidden_dim=32, num_experts=4, capacity_factor=0.5) + x = torch.zeros(0, 16) + y = moe(x) + assert y.shape == (0, 16) + + def test_forward_rejects_wrong_rank(self): + moe = ExpertChoiceMoE(dim=16, hidden_dim=32, num_experts=2, capacity_factor=0.5) + with pytest.raises(ValueError, match=r"\(N, D\)"): + moe(torch.randn(2, 5, 16)) + + def test_gradient_flow_to_experts_and_gate(self): + moe = ExpertChoiceMoE(dim=16, hidden_dim=32, num_experts=4, capacity_factor=1.0) + moe.train() + x = torch.randn(8, 16, requires_grad=True) + y = moe(x) + loss = y.sum() + loss.backward() + # Gate gradient should be non-None and non-zero (some experts run). + assert moe.router.gate.weight.grad is not None + assert moe.router.gate.weight.grad.abs().sum().item() > 0 + # At least one expert's gate_proj gradient should be non-zero. + any_expert_grad = any( + e.gate_proj.weight.grad is not None and e.gate_proj.weight.grad.abs().sum().item() > 0 + for e in moe.experts + ) + assert any_expert_grad + + def test_token_not_selected_gets_zero_contribution(self): + # With capacity 0.01 and N=10, k_e=1: each expert picks exactly 1 token. + # With 2 experts that's at most 2 distinct tokens selected; at least 8 + # tokens get 0 contribution from the MoE block. 
+ torch.manual_seed(0) + moe = ExpertChoiceMoE( + dim=16, hidden_dim=32, num_experts=2, capacity_factor=0.01, gumbel_noise=False + ) + moe.eval() + x = torch.randn(10, 16) + y = moe(x) + # Rows where the output is exactly zero ⇒ no expert picked that token. + # With token-choice routing this would be impossible, but with EC it + # is the expected behavior. + zero_rows = (y.abs().sum(dim=-1) == 0).sum().item() + assert zero_rows >= 8 + + def test_expert_counts_property(self): + moe = ExpertChoiceMoE(dim=16, hidden_dim=32, num_experts=4, capacity_factor=0.5) + moe.eval() + x = torch.randn(8, 16) + _ = moe(x) + # Expert counts come from the router. + assert moe.expert_counts.shape == (4,) + + +# --------------------------------------------------------------------------- +# MoMaFFN +# --------------------------------------------------------------------------- + + +class TestMoMaFFN: + def _make_ffn(self, dim: int = 32) -> MoMaFFN: + config = _config(dim=dim) + return MoMaFFN( + config, + modalities=("image", "text"), + experts_per_modality={"image": 2, "text": 2}, + capacity_factor_per_modality={"image": 0.5, "text": 0.5}, + gumbel_noise=False, + ) + + def test_construction_rejects_missing_experts_entry(self): + config = _config() + with pytest.raises(ValueError, match="experts_per_modality missing"): + MoMaFFN( + config, + modalities=("image", "text"), + experts_per_modality={"image": 2}, + capacity_factor_per_modality={"image": 0.5, "text": 0.5}, + ) + + def test_construction_rejects_missing_capacity_entry(self): + config = _config() + with pytest.raises(ValueError, match="capacity_factor_per_modality missing"): + MoMaFFN( + config, + modalities=("image", "text"), + experts_per_modality={"image": 2, "text": 2}, + capacity_factor_per_modality={"image": 0.5}, + ) + + def test_forward_shape(self): + ffn = self._make_ffn(dim=32) + ffn.eval() + x = torch.randn(2, 8, 32) + # 4 image positions, 4 text positions (image-prefix layout). 
+ modality_ids = torch.zeros(2, 8, dtype=torch.long) + modality_ids[:, 4:] = 1 + y = ffn(x, modality_ids) + assert y.shape == (2, 8, 32) + + def test_forward_rejects_wrong_input_rank(self): + ffn = self._make_ffn() + with pytest.raises(ValueError, match=r"\(B, S, D\)"): + ffn(torch.randn(8, 32), torch.zeros(8, dtype=torch.long)) + + def test_forward_rejects_mismatched_modality_ids_shape(self): + ffn = self._make_ffn() + x = torch.randn(2, 8, 32) + with pytest.raises(ValueError, match="does not match"): + ffn(x, torch.zeros(2, 7, dtype=torch.long)) + + def test_forward_rejects_non_long_modality_ids(self): + ffn = self._make_ffn() + x = torch.randn(2, 8, 32) + with pytest.raises(ValueError, match="dtype must be torch.long"): + ffn(x, torch.zeros(2, 8, dtype=torch.float32)) + + def test_all_text_batch_image_positions_zero(self): + """When no tokens are tagged image, image-position outputs are 0.""" + ffn = self._make_ffn(dim=32) + ffn.eval() + x = torch.randn(1, 6, 32) + # All tokens are text (modality_id == 1). + modality_ids = torch.ones(1, 6, dtype=torch.long) + y = ffn(x, modality_ids) + # Sanity: text positions should generally produce some non-zero output. + # (Test text path actually fires.) + assert y.abs().sum().item() > 0 + + def test_all_image_batch_text_positions_zero(self): + """When no tokens are tagged text, text-position outputs are 0.""" + ffn = self._make_ffn(dim=32) + ffn.eval() + x = torch.randn(1, 6, 32) + modality_ids = torch.zeros(1, 6, dtype=torch.long) + y = ffn(x, modality_ids) + assert y.abs().sum().item() > 0 + + def test_dispatch_isolates_modalities(self): + """Image positions are processed only by image experts, text by text experts. + + Verifies by zeroing one modality's experts and confirming output at + positions of the *other* modality is unchanged. 
+ """ + ffn = self._make_ffn(dim=32) + ffn.eval() + x = torch.randn(1, 8, 32) + modality_ids = torch.zeros(1, 8, dtype=torch.long) + modality_ids[:, 4:] = 1 + + y_full = ffn(x, modality_ids) + # Zero text experts' weights → text-position outputs should change, + # image-position outputs should be identical. + with torch.no_grad(): + for e in ffn.experts["text"].experts: + e.gate_proj.weight.zero_() + e.up_proj.weight.zero_() + e.down_proj.weight.zero_() + y_text_zeroed = ffn(x, modality_ids) + # Image positions (indices 0..3) unchanged. + assert torch.allclose(y_full[:, :4, :], y_text_zeroed[:, :4, :]) + # Text positions (indices 4..7) now zero (or different). + assert not torch.allclose(y_full[:, 4:, :], y_text_zeroed[:, 4:, :]) + + def test_gradient_flow_to_both_modality_groups(self): + ffn = self._make_ffn(dim=32) + ffn.train() + x = torch.randn(1, 8, 32, requires_grad=True) + modality_ids = torch.zeros(1, 8, dtype=torch.long) + modality_ids[:, 4:] = 1 + y = ffn(x, modality_ids) + y.sum().backward() + # Both modality groups should have gradients on their gates. 
+ image_gate_grad = ffn.experts["image"].router.gate.weight.grad + text_gate_grad = ffn.experts["text"].router.gate.weight.grad + assert image_gate_grad is not None + assert text_gate_grad is not None + assert image_gate_grad.abs().sum().item() > 0 + assert text_gate_grad.abs().sum().item() > 0 + + +# --------------------------------------------------------------------------- +# MoMaBlock +# --------------------------------------------------------------------------- + + +class TestMoMaBlock: + def _make_block(self, dim: int = 32) -> MoMaBlock: + config = _config(dim=dim, n_heads=4, n_kv_heads=4, max_seq_len=32) + return MoMaBlock( + config, + modalities=("image", "text"), + experts_per_modality={"image": 2, "text": 2}, + capacity_factor_per_modality={"image": 0.5, "text": 0.5}, + gumbel_noise=False, + layer_idx=0, + ) + + def test_construction_has_shared_attention_and_per_modality_ffn(self): + block = self._make_block() + # Shared attention: single QKVO Linear (not nn.ModuleDict). + assert isinstance(block.attention, Attention) + # Per-modality MoE FFN: ModuleDict keyed by modality. + assert isinstance(block.mlp, MoMaFFN) + assert set(block.mlp.experts.keys()) == {"image", "text"} + + def test_forward_shape(self): + from kempnerforge.model.position import precompute_rope_frequencies + + block = self._make_block(dim=32) + block.eval() + cos, sin = precompute_rope_frequencies(head_dim=8, max_seq_len=16) + x = torch.randn(2, 8, 32) + modality_ids = torch.zeros(2, 8, dtype=torch.long) + modality_ids[:, 4:] = 1 + y = block(x, cos[:8], sin[:8], modality_ids) + assert y.shape == (2, 8, 32) + + def test_residual_preserves_unselected_tokens(self): + """If no expert picks a token, the residual still carries it. + + With capacity factor 0.01 + 2 experts + 8 tokens, k_e = 1 so at + most 2 tokens get nonzero MoE contribution; the rest should + appear ~unchanged at the output (modulo the attention contribution). 
+ """ + from kempnerforge.model.position import precompute_rope_frequencies + + config = _config(dim=32, n_heads=4, n_kv_heads=4) + block = MoMaBlock( + config, + modalities=("image", "text"), + experts_per_modality={"image": 1, "text": 1}, + capacity_factor_per_modality={"image": 0.01, "text": 0.01}, + gumbel_noise=False, + layer_idx=0, + ) + # Zero the attention output so we isolate the FFN residual behavior. + with torch.no_grad(): + block.attention.o_proj.weight.zero_() + block.eval() + cos, sin = precompute_rope_frequencies(head_dim=8, max_seq_len=16) + x = torch.randn(1, 8, 32) + modality_ids = torch.zeros(1, 8, dtype=torch.long) + modality_ids[:, 4:] = 1 + y = block(x, cos[:8], sin[:8], modality_ids) + # Most positions should be ~equal to the input (residual passthrough). + # Count rows that are close to input. + close = torch.isclose(y[0], x[0], atol=1e-5).all(dim=-1).sum().item() + assert close >= 6 # at least 6 of 8 tokens passed through unmodified + + def test_gradient_flow(self): + from kempnerforge.model.position import precompute_rope_frequencies + + block = self._make_block(dim=32) + block.train() + cos, sin = precompute_rope_frequencies(head_dim=8, max_seq_len=16) + x = torch.randn(1, 8, 32, requires_grad=True) + modality_ids = torch.zeros(1, 8, dtype=torch.long) + modality_ids[:, 4:] = 1 + y = block(x, cos[:8], sin[:8], modality_ids) + y.sum().backward() + assert x.grad is not None + assert x.grad.abs().sum().item() > 0 + + +# --------------------------------------------------------------------------- +# End-to-end Transformer + MoMaConfig +# --------------------------------------------------------------------------- + + +class TestTransformerWithMoMaConfig: + def test_build_with_moma_config(self): + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + vlm = MoMaConfig( + moma_experts_per_modality={"image": 2, "text": 2}, + moma_gumbel_noise=False, + ) + transformer = Transformer(config, vlm_config=vlm, 
num_image_tokens=4) + # All blocks should be MoMaBlock instances. + for layer in transformer.layers.values(): + assert isinstance(layer, MoMaBlock) + # MoT-specific state should be empty. + assert transformer._mot_modalities == () + # MoMa-specific state should be set. + assert transformer._moma_modalities == ("image", "text") + + def test_forward_with_modality_context(self): + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + vlm = MoMaConfig( + moma_experts_per_modality={"image": 2, "text": 2}, + moma_gumbel_noise=False, + ) + transformer = Transformer(config, vlm_config=vlm, num_image_tokens=4) + transformer.eval() + b, n_img, t_text = 1, 4, 4 + tokens = torch.randint(0, config.vocab_size, (b, t_text)) + prefix_embeds = torch.randn(b, n_img, config.dim) + modality_ids = torch.zeros(b, n_img + t_text, dtype=torch.long) + modality_ids[:, n_img:] = 1 + ctx = ModalityContext( + prefix_embeds=prefix_embeds, + output_slice=slice(n_img, None), + modality_ids=modality_ids, + ) + logits = transformer(tokens=tokens, modality=ctx) + # output_slice trims image positions → output has t_text positions. 
+ assert logits.shape == (b, t_text, config.vocab_size) + assert torch.isfinite(logits).all() + + def test_forward_rejects_missing_modality_ids(self): + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + vlm = MoMaConfig( + moma_experts_per_modality={"image": 2, "text": 2}, + moma_gumbel_noise=False, + ) + transformer = Transformer(config, vlm_config=vlm, num_image_tokens=4) + b, t_text = 1, 4 + tokens = torch.randint(0, config.vocab_size, (b, t_text)) + prefix_embeds = torch.randn(b, 4, config.dim) + ctx = ModalityContext( + prefix_embeds=prefix_embeds, + output_slice=slice(4, None), + # modality_ids deliberately omitted + ) + with pytest.raises(ValueError, match="requires modality.modality_ids"): + transformer(tokens=tokens, modality=ctx) + + def test_forward_rejects_mismatched_modality_ids_shape(self): + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + vlm = MoMaConfig( + moma_experts_per_modality={"image": 2, "text": 2}, + moma_gumbel_noise=False, + ) + transformer = Transformer(config, vlm_config=vlm, num_image_tokens=4) + b, n_img, t_text = 1, 4, 4 + tokens = torch.randint(0, config.vocab_size, (b, t_text)) + prefix_embeds = torch.randn(b, n_img, config.dim) + # Wrong shape: should be (b, n_img + t_text) = (1, 8) but we pass (1, 7). 
+ modality_ids = torch.zeros(b, 7, dtype=torch.long) + ctx = ModalityContext( + prefix_embeds=prefix_embeds, + output_slice=slice(n_img, None), + modality_ids=modality_ids, + ) + with pytest.raises(ValueError, match="does not match"): + transformer(tokens=tokens, modality=ctx) + + def test_gradient_flow_end_to_end(self): + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + vlm = MoMaConfig( + moma_experts_per_modality={"image": 2, "text": 2}, + moma_gumbel_noise=False, + ) + transformer = Transformer(config, vlm_config=vlm, num_image_tokens=4) + transformer.train() + b, n_img, t_text = 1, 4, 4 + tokens = torch.randint(0, config.vocab_size, (b, t_text)) + prefix_embeds = torch.randn(b, n_img, config.dim, requires_grad=True) + modality_ids = torch.zeros(b, n_img + t_text, dtype=torch.long) + modality_ids[:, n_img:] = 1 + ctx = ModalityContext( + prefix_embeds=prefix_embeds, + output_slice=slice(n_img, None), + modality_ids=modality_ids, + ) + logits = transformer(tokens=tokens, modality=ctx) + logits.sum().backward() + # Prefix embeds and at least one expert in each modality group should + # have gradients. + assert prefix_embeds.grad is not None + assert prefix_embeds.grad.abs().sum().item() > 0 + + +# --------------------------------------------------------------------------- +# Sanity check: math.ceil semantics for k_e +# --------------------------------------------------------------------------- + + +def test_k_e_formula_matches_paper(): + """k_e = ceil(capacity_factor * N) matches the paper's b^M * c_e formula. + + Paper: k_e = b^M * c_e, where b^M is total tokens of modality M. + Implementation: k_e = ceil(capacity_factor * n_tokens) with capacity_factor + defaulting to 1/|E^M| (so k_e ~ N/|E|). + """ + for n, c in [(16, 0.25), (20, 0.5), (100, 0.1), (7, 1.0 / 3)]: + expected = max(1, math.ceil(c * n)) + # The router computes this internally; we verify the formula + # produces sensible values. 
+ assert expected >= 1 + assert expected <= n diff --git a/tests/unit/test_vlm_config.py b/tests/unit/test_vlm_config.py index a852e90..5e865eb 100644 --- a/tests/unit/test_vlm_config.py +++ b/tests/unit/test_vlm_config.py @@ -18,6 +18,7 @@ FreezeSpec, FreezeStage, JointDecoderConfig, + MoMaConfig, MoTConfig, VLMConfig, ) @@ -115,7 +116,7 @@ def test_for_arch_error_lists_registered_arches(self): def test_registry_has_all_archs(self): archs = set(registry.list_vlm_configs()) - assert {"joint_decoder", "cross_attention", "mot"} <= archs + assert {"joint_decoder", "cross_attention", "mot", "moma"} <= archs class TestCrossAttentionResolvedHeads: @@ -309,3 +310,97 @@ def test_warm_start_path_set_without_flag_is_valid(self): ) assert cfg.mot_warm_start_from_text is False assert cfg.mot_warm_start_path == "/tmp/jd_ckpt.pt" + + +class TestMoMaConfig: + """``MoMaConfig`` (registered subclass for arch='moma'). + + Mirrors ``TestMoTConfig``: construction defaults, module_patterns alias, + ``for_arch`` dispatch via the registry, residual-stream layout, and + field validation. Detailed expert-count / capacity-factor / Gumbel-noise + semantics live in ``tests/unit/test_moma.py``. + """ + + def test_construction_defaults(self): + cfg = MoMaConfig() + assert cfg.arch == "moma" + assert isinstance(cfg, VLMConfig) + assert cfg.moma_modalities == ("image", "text") + assert cfg.moma_experts_per_modality == {"image": 4, "text": 4} + assert cfg.moma_capacity_factor == 0.0 + assert cfg.moma_gumbel_noise is True + + def test_module_patterns_has_moma_alias(self): + cfg = MoMaConfig() + assert "moma" in cfg.module_patterns + assert cfg.module_patterns["moma"] == [ + "transformer.layers", + "transformer.layers.*", + ] + # Base aliases still present. 
+ assert "transformer" in cfg.module_patterns + assert "vision_encoder" in cfg.module_patterns + + def test_for_arch_moma_returns_subclass(self): + cfg = VLMConfig.for_arch("moma") + assert isinstance(cfg, MoMaConfig) + assert cfg.arch == "moma" + + def test_for_arch_moma_with_overrides(self): + cfg = VLMConfig.for_arch( + "moma", + moma_experts_per_modality={"image": 1, "text": 7}, + moma_capacity_factor=0.5, + moma_gumbel_noise=False, + ) + assert isinstance(cfg, MoMaConfig) + assert cfg.moma_experts_per_modality == {"image": 1, "text": 7} + assert cfg.moma_capacity_factor == 0.5 + assert cfg.moma_gumbel_noise is False + + def test_residual_stream_image_tokens_passes_through(self): + cfg = MoMaConfig() + assert cfg.residual_stream_image_tokens(128) == 128 + + def test_effective_capacity_factor_paper_default(self): + cfg = MoMaConfig(moma_experts_per_modality={"image": 4, "text": 4}) + assert cfg.effective_capacity_factor("image") == pytest.approx(0.25) + assert cfg.effective_capacity_factor("text") == pytest.approx(0.25) + + def test_effective_capacity_factor_explicit_override(self): + cfg = MoMaConfig( + moma_experts_per_modality={"image": 4, "text": 4}, + moma_capacity_factor=0.6, + ) + assert cfg.effective_capacity_factor("image") == 0.6 + assert cfg.effective_capacity_factor("text") == 0.6 + + def test_moma_modalities_must_include_text(self): + with pytest.raises(ValueError, match="must include 'text'"): + MoMaConfig( + moma_modalities=("image", "audio"), + moma_experts_per_modality={"image": 2, "audio": 2}, + ) + + def test_moma_modalities_must_include_image(self): + with pytest.raises(ValueError, match="must include 'image'"): + MoMaConfig( + moma_modalities=("text", "audio"), + moma_experts_per_modality={"text": 2, "audio": 2}, + ) + + def test_negative_capacity_factor_raises(self): + with pytest.raises(ValueError, match="capacity_factor must be >= 0"): + MoMaConfig(moma_capacity_factor=-0.1) + + def test_missing_expert_count_entry_raises(self): + with 
pytest.raises(ValueError, match="missing entries"): + MoMaConfig(moma_experts_per_modality={"image": 2}) + + def test_extra_expert_count_entry_raises(self): + with pytest.raises(ValueError, match="unknown modality keys"): + MoMaConfig(moma_experts_per_modality={"image": 2, "text": 2, "audio": 4}) + + def test_nonpositive_expert_count_raises(self): + with pytest.raises(ValueError, match="must be positive"): + MoMaConfig(moma_experts_per_modality={"image": 0, "text": 2}) From d6f72ee72fbd4f09217ddccda55823fcb0f54aa5 Mon Sep 17 00:00:00 2001 From: amazloumi Date: Wed, 20 May 2026 11:25:35 -0400 Subject: [PATCH 2/4] MoMa post-review: validate modality_ids, expose expert counts, note AC no-op --- configs/train/vlm_7b_moma.toml | 5 ++ kempnerforge/model/moma.py | 20 ++++++++ kempnerforge/model/transformer.py | 29 ++++++++++- tests/unit/test_moma.py | 82 +++++++++++++++++++++++++++++++ 4 files changed, 134 insertions(+), 2 deletions(-) diff --git a/configs/train/vlm_7b_moma.toml b/configs/train/vlm_7b_moma.toml index f7a8c58..b420ba8 100644 --- a/configs/train/vlm_7b_moma.toml +++ b/configs/train/vlm_7b_moma.toml @@ -81,6 +81,11 @@ seed = 42 # torch.compile (graph breaks on data-dependent dispatch); JobConfig.validate # emits a warning if you flip this on. compile_model = false +# NOTE: AC=full is currently a no-op for MoMa — kempnerforge.distributed.parallel.apply_ac +# matches `isinstance(m, TransformerBlock)` only, and MoMaBlock is a sibling nn.Module +# (same gap exists for MoTBlock and CrossAttentionBlock). The follow-up PR will refactor +# apply_ac to iterate ``transformer.layers`` directly so AC works across all VLM arches. +# Until then this line has no effect; OOM risk on tighter GPUs is real and unmitigated. # Required given the per-layer expert duplication on 4x H200; drop to # "selective" or "none" if you scale down experts_per_modality or n_layers. 
activation_checkpointing = "full" diff --git a/kempnerforge/model/moma.py b/kempnerforge/model/moma.py index dfbbcfc..8b24b1a 100644 --- a/kempnerforge/model/moma.py +++ b/kempnerforge/model/moma.py @@ -331,18 +331,38 @@ def forward(self, x: torch.Tensor, modality_ids: torch.Tensor) -> torch.Tensor: mod_flat = modality_ids.reshape(b * s) out = torch.zeros_like(x_flat) + # Tracks how many positions actually got routed to *some* modality + # group. With well-formed modality_ids (values in [0, len(modalities))) + # this equals b*s at the end. We accumulate Python ints from + # ``idx.numel()`` (tensor metadata; no sync beyond the one ``nonzero()`` + # already incurs) and compare after the loop — cheaper than an upfront + # ``.all()`` reduction, which would add a device->host sync. The error fires + # post-FFN, but the in-range work is the same either way and the + # failure mode without this check is silent zero-output on the + # affected positions (residual still carries them through, so the + # bug would only surface as quietly wrong training). + total_routed = 0 for i, m in enumerate(self.modalities): # nonzero() avoids the boolean-mask copy and gives us a 1-D index # tensor we can feed to index_select + scatter. idx = (mod_flat == i).nonzero(as_tuple=False).squeeze(-1) # (N_m,) if idx.numel() == 0: continue + total_routed += idx.numel() x_m = x_flat.index_select(0, idx) # (N_m, D) y_m = self.experts[m](x_m) # (N_m, D) # The modality groups partition the position space, so indices # are guaranteed unique across iterations. index_copy on # disjoint indices is safe and autograd-friendly. 
out = out.index_copy(0, idx, y_m) + + if total_routed != b * s: + raise ValueError( + f"MoMaFFN.forward: modality_ids contains out-of-range values; " + f"{b * s - total_routed} of {b * s} positions did not match any " + f"modality (allowed values: 0..{len(self.modalities) - 1} for " + f"modalities {self.modalities!r})" + ) return out.view(b, s, d) diff --git a/kempnerforge/model/transformer.py b/kempnerforge/model/transformer.py index ffe87b8..78d321d 100644 --- a/kempnerforge/model/transformer.py +++ b/kempnerforge/model/transformer.py @@ -24,7 +24,7 @@ from kempnerforge.model.mlp import build_mlp from kempnerforge.model.modality import ModalityContext from kempnerforge.model.moe import MoEMLP, build_moe -from kempnerforge.model.moma import MoMaBlock +from kempnerforge.model.moma import MoMaBlock, MoMaFFN from kempnerforge.model.mot import MoTBlock from kempnerforge.model.norm import build_norm from kempnerforge.model.position import precompute_rope_frequencies @@ -262,13 +262,38 @@ def get_moe_aux_loss(self) -> torch.Tensor: return total def get_expert_counts(self) -> dict[int, torch.Tensor]: - """Collect per-layer expert utilization. Returns {} if dense.""" + """Collect per-layer expert utilization for flat MoE layers. + + Returns ``{layer_idx: (num_experts,) tensor}`` for layers whose MLP + is a ``MoEMLP`` (the standard, single-pool MoE). Returns ``{}`` for + dense models and for MoMa: MoMa's per-modality groups have a + different shape (per modality, per expert) and surface through + ``get_moma_expert_counts`` instead. + """ counts = {} for name, layer in self.layers.items(): if isinstance(layer.mlp, MoEMLP): counts[int(name)] = layer.mlp.expert_counts return counts + def get_moma_expert_counts(self) -> dict[int, dict[str, torch.Tensor]]: + """Collect per-layer, per-modality expert utilization for MoMa layers. + + Returns ``{layer_idx: {modality: (num_experts_for_modality,) tensor}}`` + for every layer whose MLP is a ``MoMaFFN``; returns ``{}`` otherwise. 
+ Each modality's expert count tensor reflects the most recent forward + through that layer's expert-choice router (paper Figure 5-style + utilization). Counts on a fresh model (no forward yet) are the + router's init zeros. + """ + counts: dict[int, dict[str, torch.Tensor]] = {} + for name, layer in self.layers.items(): + if isinstance(layer.mlp, MoMaFFN): + counts[int(name)] = { + m: layer.mlp.experts[m].expert_counts for m in layer.mlp.modalities + } + return counts + def forward( self, tokens: torch.Tensor | None = None, diff --git a/tests/unit/test_moma.py b/tests/unit/test_moma.py index a4b5738..e3a116b 100644 --- a/tests/unit/test_moma.py +++ b/tests/unit/test_moma.py @@ -461,6 +461,29 @@ def test_forward_rejects_non_long_modality_ids(self): with pytest.raises(ValueError, match="dtype must be torch.long"): ffn(x, torch.zeros(2, 8, dtype=torch.float32)) + def test_forward_rejects_out_of_range_modality_id(self): + """A modality id >= len(modalities) is silently equivalent to "no group + picked this token" without the check — caller would see zero output + at those positions, which is hard to debug. We require an explicit + ValueError instead. 
+ """ + ffn = self._make_ffn() + ffn.eval() + x = torch.randn(2, 4, 32) + modality_ids = torch.zeros(2, 4, dtype=torch.long) + modality_ids[0, 0] = 2 # only 0 ("image") and 1 ("text") are valid + with pytest.raises(ValueError, match="out-of-range"): + ffn(x, modality_ids) + + def test_forward_rejects_negative_modality_id(self): + ffn = self._make_ffn() + ffn.eval() + x = torch.randn(2, 4, 32) + modality_ids = torch.zeros(2, 4, dtype=torch.long) + modality_ids[1, 2] = -1 + with pytest.raises(ValueError, match="out-of-range"): + ffn(x, modality_ids) + def test_all_text_batch_image_positions_zero(self): """When no tokens are tagged image, image-position outputs are 0.""" ffn = self._make_ffn(dim=32) @@ -716,6 +739,65 @@ def test_gradient_flow_end_to_end(self): assert prefix_embeds.grad is not None assert prefix_embeds.grad.abs().sum().item() > 0 + def test_get_expert_counts_returns_empty_for_moma(self): + """The flat-MoE helper returns {} for MoMa — MoMa layers expose + per-modality counts through get_moma_expert_counts instead. + """ + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + vlm = MoMaConfig( + moma_experts_per_modality={"image": 2, "text": 2}, + moma_gumbel_noise=False, + ) + transformer = Transformer(config, vlm_config=vlm, num_image_tokens=4) + assert transformer.get_expert_counts() == {} + + def test_get_moma_expert_counts_returns_empty_when_no_moma_layers(self): + """Dense Transformer (no MoMa, no MoT, no MoE) has no MoMa layers, + so the helper returns {}. + """ + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + transformer = Transformer(config) + assert transformer.get_moma_expert_counts() == {} + + def test_get_moma_expert_counts_after_forward(self): + """After a forward, get_moma_expert_counts surfaces per-layer + per-modality utilization tensors (paper Figure 5 shape). 
+ """ + config = _config(dim=32, n_heads=4, n_kv_heads=4, n_layers=2, max_seq_len=32) + # Unequal experts per modality so the shape check is meaningful. + vlm = MoMaConfig( + moma_experts_per_modality={"image": 2, "text": 3}, + moma_gumbel_noise=False, + ) + transformer = Transformer(config, vlm_config=vlm, num_image_tokens=4) + transformer.eval() + + b, n_img, t_text = 1, 4, 4 + tokens = torch.randint(0, config.vocab_size, (b, t_text)) + prefix_embeds = torch.randn(b, n_img, config.dim) + modality_ids = torch.zeros(b, n_img + t_text, dtype=torch.long) + modality_ids[:, n_img:] = 1 + ctx = ModalityContext( + prefix_embeds=prefix_embeds, + output_slice=slice(n_img, None), + modality_ids=modality_ids, + ) + _ = transformer(tokens=tokens, modality=ctx) + + counts = transformer.get_moma_expert_counts() + # Both layers reported. + assert set(counts.keys()) == {0, 1} + for layer_counts in counts.values(): + # Both modality groups present. + assert set(layer_counts.keys()) == {"image", "text"} + # Per-modality shape == (num_experts_for_that_modality,). + assert layer_counts["image"].shape == (2,) + assert layer_counts["text"].shape == (3,) + # Counts are non-negative; expert-choice puts >=1 token per expert + # when N_m > 0 (k_e = max(1, ceil(c*N_m)) and N_m == 4 here). 
+ assert (layer_counts["image"] >= 1).all() + assert (layer_counts["text"] >= 1).all() + # --------------------------------------------------------------------------- # Sanity check: math.ceil semantics for k_e From a75c6f8958a1403f3200d7191923869a40a71e1f Mon Sep 17 00:00:00 2001 From: amazloumi Date: Wed, 20 May 2026 11:41:12 -0400 Subject: [PATCH 3/4] fixing pyright check --- kempnerforge/model/transformer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/kempnerforge/model/transformer.py b/kempnerforge/model/transformer.py index 78d321d..d3a835a 100644 --- a/kempnerforge/model/transformer.py +++ b/kempnerforge/model/transformer.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import cast + import torch import torch.nn as nn @@ -24,7 +26,7 @@ from kempnerforge.model.mlp import build_mlp from kempnerforge.model.modality import ModalityContext from kempnerforge.model.moe import MoEMLP, build_moe -from kempnerforge.model.moma import MoMaBlock, MoMaFFN +from kempnerforge.model.moma import ExpertChoiceMoE, MoMaBlock, MoMaFFN from kempnerforge.model.mot import MoTBlock from kempnerforge.model.norm import build_norm from kempnerforge.model.position import precompute_rope_frequencies @@ -289,8 +291,14 @@ def get_moma_expert_counts(self) -> dict[int, dict[str, torch.Tensor]]: counts: dict[int, dict[str, torch.Tensor]] = {} for name, layer in self.layers.items(): if isinstance(layer.mlp, MoMaFFN): + # nn.ModuleDict.__getitem__ returns Module; cast back to the + # concrete expert-group type so pyright sees the + # ``expert_counts`` Tensor rather than ``Tensor | Module``. + # The cast is safe because ``MoMaFFN.__init__`` only ever + # writes ``ExpertChoiceMoE`` values into ``experts``. 
counts[int(name)] = { - m: layer.mlp.experts[m].expert_counts for m in layer.mlp.modalities + m: cast(ExpertChoiceMoE, layer.mlp.experts[m]).expert_counts + for m in layer.mlp.modalities } return counts From 8dce33f284358dd273ba0c4b51a360adc634c9ae Mon Sep 17 00:00:00 2001 From: amazloumi Date: Tue, 26 May 2026 11:10:55 -0400 Subject: [PATCH 4/4] resolving the above comments --- configs/train/vlm_7b_moma.toml | 15 +++--- kempnerforge/config/job.py | 18 +++++++ tests/unit/test_config.py | 92 ++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 6 deletions(-) diff --git a/configs/train/vlm_7b_moma.toml b/configs/train/vlm_7b_moma.toml index b420ba8..0a1b36f 100644 --- a/configs/train/vlm_7b_moma.toml +++ b/configs/train/vlm_7b_moma.toml @@ -18,10 +18,12 @@ # # Parameter / memory note: with the default 7B-dense-shaped backbone # (dim=4096, n_layers=32, ffn ~14336) and 8 SwiGLU experts per layer -# (4 image + 4 text), total params is much larger than dense 7B. Use -# FSDP=4 + activation_checkpointing="full" to fit on 4x H200. For a -# roomier setup, reduce moma_experts_per_modality (e.g. 2t2i) or fall -# back to the MoT debug config. Pipeline Parallel + VLM is not supported. +# (4 image + 4 text), total params is much larger than dense 7B. Fitting +# on 4x H200 today needs FSDP=4 plus reducing moma_experts_per_modality +# (e.g. 2t2i), or falling back to the MoT debug config — +# activation_checkpointing="full" is set below but is currently a no-op +# for MoMa (see the comment near the field). Pipeline Parallel + VLM is +# not supported. # # max_seq_len allocation: residual_image_tokens + max_text_len. Image # tokens prepend the text sequence in the residual stream, so the budget @@ -86,8 +88,9 @@ compile_model = false # (same gap exists for MoTBlock and CrossAttentionBlock). The follow-up PR will refactor # apply_ac to iterate ``transformer.layers`` directly so AC works across all VLM arches. 
# Until then this line has no effect; OOM risk on tighter GPUs is real and unmitigated. -# Required given the per-layer expert duplication on 4x H200; drop to -# "selective" or "none" if you scale down experts_per_modality or n_layers. +# Intended once apply_ac is refactored — at that point this will checkpoint every +# MoMaBlock per layer and recover the memory budget needed for the per-layer expert +# duplication on 4x H200. Currently inert (see NOTE above). activation_checkpointing = "full" loss_fn = "cross_entropy" diff --git a/kempnerforge/config/job.py b/kempnerforge/config/job.py index 5880228..4b3e1dd 100644 --- a/kempnerforge/config/job.py +++ b/kempnerforge/config/job.py @@ -246,3 +246,21 @@ def validate(self, world_size: int = 1) -> None: "(modality_ids-based scatter/gather + expert-choice top-k cause " "graph breaks). Set compile_model=false for MoMa models." ) + + # AC=full silently no-ops on MoMa because apply_ac matches + # isinstance(m, TransformerBlock) only and MoMaBlock is a + # sibling nn.Module. The selective branch wraps + # isinstance(m, Attention), and MoMaBlock.attention is a + # vanilla Attention, so 'selective' DOES work on MoMa today + # — only 'full' is broken. The cross-arch apply_ac refactor + # will make 'full' work; until then surface the silent + # no-op explicitly so fresh MoMa configs don't trust it. + if self.train.activation_checkpointing == "full": + import logging + + logging.getLogger(__name__).warning( + "AC=full is currently a no-op for MoMa (apply_ac matches " + "TransformerBlock only; MoMaBlock is a sibling nn.Module). " + "Use ac='selective' (wraps the Attention submodule, still works) " + "or reduce moma_experts_per_modality until the apply_ac refactor lands." 
+ ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index f6a72d2..0b5a3f0 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -24,6 +24,7 @@ VisionEncoderConfig, VLMConfig, ) +from kempnerforge.config.vlm import MoMaConfig # --------------------------------------------------------------------------- # ModelConfig @@ -484,6 +485,97 @@ def test_no_warning_for_random_with_overrides(self, caplog): assert not any("random" in r.getMessage() for r in caplog.records) +class TestMoMaAcFullWarning: + """``JobConfig.validate`` warns when MoMa is paired with + ``activation_checkpointing="full"`` because ``apply_ac`` matches + ``isinstance(m, TransformerBlock)`` only — and ``MoMaBlock`` is a + sibling ``nn.Module``, so AC=full silently no-ops on MoMa. The + ``selective`` branch wraps ``isinstance(m, Attention)`` and + ``MoMaBlock.attention`` is a standard ``Attention``, so + ``selective`` and ``none`` are correct and MUST NOT warn. Non-MoMa + arches are unaffected and must not warn either. + + Mirrors ``TestHfEncoderOverrideWarning``'s caplog plumbing: the + ``kempnerforge`` logger has propagation disabled by + ``metrics.logger._configure_root``, so we re-enable it for the + duration of each test so pytest's root-attached caplog can observe + records from ``kempnerforge.config.job``. + """ + + def setup_method(self): + import logging + + self._kf_logger = logging.getLogger("kempnerforge") + self._old_propagate = self._kf_logger.propagate + self._kf_logger.propagate = True + + def teardown_method(self): + self._kf_logger.propagate = self._old_propagate + + def _moma_job(self, ac: ActivationCheckpointing) -> JobConfig: + return JobConfig( + model=ModelConfig(max_seq_len=1024), + train=TrainConfig( + seq_len=576, + activation_checkpointing=ac, + # Silence the sibling MoMa-compile warning so it can't + # accidentally satisfy our "no AC warning" assertions. 
+ compile_model=False, + ), + vision_encoder=VisionEncoderConfig(type="random", feature_dim=256, num_tokens=64), + adapter=AdapterConfig(), + vlm=MoMaConfig(max_text_len=512), + ) + + def test_warns_on_full(self, caplog): + cfg = self._moma_job(ActivationCheckpointing.full) + with caplog.at_level("WARNING", logger="kempnerforge.config.job"): + cfg.validate(world_size=1) + assert any( + "AC=full is currently a no-op for MoMa" in r.getMessage() + and "selective" in r.getMessage() + for r in caplog.records + ) + + def test_no_warning_on_selective(self, caplog): + cfg = self._moma_job(ActivationCheckpointing.selective) + with caplog.at_level("WARNING", logger="kempnerforge.config.job"): + cfg.validate(world_size=1) + assert not any( + "AC=full is currently a no-op for MoMa" in r.getMessage() for r in caplog.records + ) + + def test_no_warning_on_none(self, caplog): + cfg = self._moma_job(ActivationCheckpointing.none) + with caplog.at_level("WARNING", logger="kempnerforge.config.job"): + cfg.validate(world_size=1) + assert not any( + "AC=full is currently a no-op for MoMa" in r.getMessage() for r in caplog.records + ) + + def test_no_warning_for_non_moma_arch_with_full(self, caplog): + """A non-MoMa arch (default joint_decoder) with AC=full uses + TransformerBlock-shaped layers, which apply_ac handles correctly. + The MoMa-only warning must not fire. 
+ """ + cfg = JobConfig( + model=ModelConfig(max_seq_len=1024), + train=TrainConfig( + seq_len=576, + activation_checkpointing=ActivationCheckpointing.full, + compile_model=False, + ), + vision_encoder=VisionEncoderConfig(type="random", feature_dim=256, num_tokens=64), + adapter=AdapterConfig(), + vlm=VLMConfig(max_text_len=512), # default arch = joint_decoder + ) + with caplog.at_level("WARNING", logger="kempnerforge.config.job"): + cfg.validate(world_size=1) + assert not any( + "AC=full is currently a no-op for MoMa" in r.getMessage() for r in caplog.records + ) + + # --------------------------------------------------------------------------- # TOML Loading # ---------------------------------------------------------------------------