Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ KempnerForge supports VLM training via a thin wrapper around the existing `Trans
- **Joint-Decoder** (`arch = "joint_decoder"`): image embeds are prepended to the text sequence; the transformer runs over the concatenated `(image, text)` sequence and the LM head is applied to text positions only.
- **Cross-Attention** (`arch = "cross_attention"`, Llama-3-V style): the residual stream carries text only. Separate `CrossAttentionBlock`s inserted at a configurable cadence let text queries attend to image K/V. CA blocks are zero-initialized so adding the arch on top of a text-only checkpoint is identity at step 0 and learns from there.
- **Mixture-of-Transformers** (`arch = "mot"`, Liang et al. 2024 Algorithm 1): every layer carries per-modality Q/K/V/O projections plus a per-modality FFN; a single global self-attention mixes all modality streams. Image tokens prepend the text sequence (same residual layout as Joint-Decoder); per-modality residual projections are zero-initialized so a fresh MoT block is identity at construction. A warm-start helper (`mot_warm_start_from_text_stack`) translates a JD or text-only checkpoint into per-modality copies — toggle via `[model.vlm].mot_warm_start_from_text` + `mot_warm_start_path`.
- **Mixture of Modality-Aware Experts** (`arch = "moma"`, Lin et al. 2024 arXiv:2407.21770): one shared set of Q/K/V/O projections feeding a global self-attention, plus per-modality MoE FFN groups (paper's optimal default 4 image + 4 text experts per layer). Tokens route deterministically to their modality group (level-1, reusing the same `modality_ids` mechanism MoT uses) and then through a learned expert-choice + Sigmoid router within the group (level-2, with Gumbel-Sigmoid noise during training, paper Eq. 5). Image tokens prepend the text sequence (same residual layout as JD/MoT). v1 supports training only — expert-choice routing is non-causal, so autoregressive generation requires auxiliary routers (paper §2.4) which are deferred to a follow-up.

```bash
# 1-GPU smoke (random encoder, Joint-Decoder)
Expand All @@ -160,9 +161,12 @@ uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_cross_a

# 4-GPU 7B Mixture-of-Transformers
uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_mot.toml

# 4-GPU 7B Mixture of Modality-Aware Experts (4 text + 4 image experts per layer)
uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_moma.toml
```

Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT blocks are cast to `param_dtype`. Pipeline Parallel + VLM is not supported on this branch (raises at startup); multi-image and video are reserved slots for follow-up work.
Configs set `[model.vlm]` with `arch`, the encoder registry key, the number of image tokens, and a freeze list (`FreezeSpec`). For Cross-Attention, set `cross_attention_every_n_layers` and optionally `cross_attention_n_kv_heads` (0 → MHA; positive → GQA on the cross path). For MoT, set `mot_modalities` (must include both `"image"` and `"text"`); `mot_image_n_heads` / `mot_image_n_kv_heads` are forward-looking per-modality head fields (v1 enforces equality with the text-side counts since the operator runs a single global SDPA). For MoMa, set `moma_experts_per_modality = {image = N, text = M}` as a nested TOML table (the paper's optimal balanced default is `4t4i`; unbalanced allocations like `{image = 1, text = 7}` match the paper's `moe_7t1i` ablation), and optionally `moma_capacity_factor` (defaults to `1/|E^M|` per modality — the paper's perfect-balance setting) and `moma_gumbel_noise` (`true` by default for paper-faithful EC routing). `model.num_experts` must be `0` when `arch = "moma"`; the per-modality counts supersede it, and JobConfig.validate rejects the combination. The vision encoder stays in its HF-loaded dtype; the transformer, adapter, and CA / MoT / MoMa blocks are cast to `param_dtype`. Pipeline Parallel + VLM is not supported on this branch (raises at startup); MoMa + Expert Parallelism is also rejected in v1. Multi-image and video are reserved slots for follow-up work.

**Adding a new VLM arch.** The discriminated-union dispatch is registry-driven, so a new arch is four small additions, no edits to existing call sites:

Expand Down
133 changes: 133 additions & 0 deletions configs/train/vlm_7b_moma.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# 7B VLM Mixture of Modality-Aware Experts (MoMa) on 4x H200 (141 GB/GPU).
#
# Mode: smoke / starter.
#
# Shared Q/K/V/O attention (one set across modalities) + per-modality MoE
# FFN groups with expert-choice + Sigmoid routing (Lin et al. 2024,
# arXiv:2407.21770). Default 4 text + 4 image experts per layer
# (paper's optimal moe_4t4i). Tokens route deterministically by modality
# (level 1) then through a learned EC + Sigmoid router within their
# modality group (level 2). Image tokens prepend the text sequence
# (same residual layout as Joint-Decoder / MoT); output_slice trims
# them off the LM head input.
#
# Inference note: MoMa v1 supports training only. Expert-choice routing
# is non-causal (each expert's top-k depends on all tokens in the batch);
# autoregressive generation requires auxiliary routers (paper §2.4),
# deferred to a follow-up.
#
# Parameter / memory note: with the default 7B-dense-shaped backbone
# (dim=4096, n_layers=32, ffn ~14336) and 8 SwiGLU experts per layer
# (4 image + 4 text), total params is much larger than dense 7B. Fitting
# on 4x H200 today needs FSDP=4 plus reducing moma_experts_per_modality
# (e.g. 2t2i), or falling back to the MoT debug config —
# activation_checkpointing="full" is set below but is currently a no-op
# for MoMa (see the comment near the field). Pipeline Parallel + VLM is
# not supported.
#
# max_seq_len allocation: residual_image_tokens + max_text_len. Image
# tokens prepend the text sequence in the residual stream, so the budget
# must cover both modalities.
#
# Usage:
# uv run torchrun --nproc_per_node=4 scripts/train.py configs/train/vlm_7b_moma.toml
#
# Default points at a 30-sample COCO val substitute (sayakpaul/coco-30-val-2014)
# so a fresh clone runs without external setup. For real training override:
# --data.hf_dataset_name=/path/to/local/hf_dataset \
# --data.hf_dataset_text_field=caption
# Swap the encoder via [vision_encoder].type = "siglip2" / "clip".

[model]
dim = 4096
n_layers = 32
n_heads = 32
n_kv_heads = 8
vocab_size = 50257
ffn_dim_multiplier = 1.3
norm_type = "rmsnorm"
activation = "silu"
max_seq_len = 1024 # 256 image + 512 text (max_text_len default) + headroom
rope_theta = 500000.0
tie_embeddings = false

[vision_encoder]
type = "random"
feature_dim = 1024
num_tokens = 256

[vlm]
arch = "moma"
# Paper Eq. 5: Gumbel-Sigmoid noise on router scores during training.
# Set false for a deterministic forward (useful for warm-start parity
# checks and reproducibility-sensitive smoke runs).
moma_gumbel_noise = true
# moma_capacity_factor = 0.0 → use paper default 1/|E^M| per modality
# (each expert sees the average token load; perfect EC balance).

[vlm.moma_experts_per_modality]
# Paper's optimal balanced allocation (moe_4t4i at 1.4B-compute-matched
# in Table 1). Unbalanced allocations like {image = 1, text = 7} are
# supported and match the paper's moe_7t1i / 1t7i ablations.
image = 4
text = 4

[train]
batch_size = 4
seq_len = 768
max_steps = 200
grad_accum_steps = 1
grad_clip_norm = 1.0
seed = 42
# Modality-aware scatter/gather + EC top-k are not yet validated under
# torch.compile (graph breaks on data-dependent dispatch); JobConfig.validate
# emits a warning if you flip this on.
compile_model = false
# NOTE: AC=full is currently a no-op for MoMa — kempnerforge.distributed.parallel.apply_ac
# matches `isinstance(m, TransformerBlock)` only, and MoMaBlock is a sibling nn.Module
# (same gap exists for MoTBlock and CrossAttentionBlock). The follow-up PR will refactor
# apply_ac to iterate ``transformer.layers`` directly so AC works across all VLM arches.
# Until then this line has no effect; OOM risk on tighter GPUs is real and unmitigated.
# Intended once apply_ac is refactored — at that point this will checkpoint every
# MoMaBlock per layer and recover the memory budget needed for the per-layer expert
# duplication on 4x H200. Currently inert (see NOTE above).
activation_checkpointing = "full"
loss_fn = "cross_entropy"

[optimizer]
name = "adamw"
lr = 1e-4
weight_decay = 0.1
betas = [0.9, 0.95]
eps = 1e-8
fused = true

[scheduler]
name = "cosine"
warmup_steps = 5
min_lr_ratio = 0.1

[data]
hf_dataset_name = "sayakpaul/coco-30-val-2014"
hf_dataset_split = "train"
hf_dataset_image_field = "image"
hf_dataset_text_field = "caption"
hf_image_size = 224
tokenizer_path = "gpt2"
num_workers = 2
pin_memory = true
prefetch_factor = 2

[distributed]
dp_shard = -1
nccl_timeout_sec = 600

[checkpoint]
dir = "checkpoints/vlm_7b_moma"
interval = 1000
keep_last_n = 1

[metrics]
log_interval = 1
enable_wandb = false
enable_tensorboard = false
90 changes: 90 additions & 0 deletions configs/train/vlm_debug_moma.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# VLM smoke config — Mixture of Modality-Aware Experts (MoMa) arch,
# tiny LLM + random vision encoder.
#
# Mode: smoke.
#
# Runs end-to-end in <2 minutes on 1 GPU. Uses RandomVisionEncoder so no
# HF download is needed. Shared Q/K/V/O self-attention + per-modality
# MoE FFN groups with expert-choice + Sigmoid routing (Lin et al. 2024,
# arXiv:2407.21770). 2 text experts + 2 image experts at every layer.
#
# max_seq_len allocation: residual_image_tokens + max_text_len. Image
# tokens prepend the text sequence in the residual stream (same layout
# as Joint-Decoder / MoT).
#
# Note: MoMa v1 supports training only. Expert-choice routing is
# non-causal; autoregressive generation requires auxiliary routers
# (paper §2.4), deferred to a follow-up.
#
# Usage:
# uv run python scripts/train.py configs/train/vlm_debug_moma.toml \
# --data.hf_dataset_name=... --data.tokenizer_path=gpt2

[model]
dim = 256
n_layers = 4
n_heads = 4
n_kv_heads = 4
vocab_size = 50257 # gpt2 vocab
max_seq_len = 576 # 64 image + 512 text
norm_type = "rmsnorm"
activation = "silu"

[vision_encoder]
type = "random"
feature_dim = 384
num_tokens = 64

[vlm]
arch = "moma"

[vlm.moma_experts_per_modality]
image = 2
text = 2

[train]
batch_size = 2
seq_len = 576
max_steps = 50
grad_accum_steps = 1
grad_clip_norm = 1.0
seed = 42
compile_model = false
activation_checkpointing = "none"

[optimizer]
name = "adamw"
lr = 3e-4
weight_decay = 0.1
betas = [0.9, 0.95]
fused = false

[scheduler]
name = "cosine"
warmup_steps = 5
min_lr_ratio = 0.1

[data]
# 30-sample COCO val substitute (sayakpaul/coco-30-val-2014). For a real
# training run, override via CLI: --data.hf_dataset_name=<your-dataset>
hf_dataset_name = "sayakpaul/coco-30-val-2014"
hf_dataset_split = "train"
hf_dataset_image_field = "image"
hf_dataset_text_field = "caption"
hf_image_size = 224
tokenizer_path = "gpt2"
num_workers = 2
pin_memory = true

[distributed]
dp_shard = -1

[checkpoint]
dir = "checkpoints/vlm_debug_moma"
interval = 25
keep_last_n = 2

[metrics]
log_interval = 5
enable_wandb = false
enable_tensorboard = false
45 changes: 44 additions & 1 deletion kempnerforge/config/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from kempnerforge.config.scheduler import SchedulerConfig
from kempnerforge.config.training import TrainConfig
from kempnerforge.config.vision import VisionEncoderConfig
from kempnerforge.config.vlm import VLMConfig
from kempnerforge.config.vlm import MoMaConfig, VLMConfig

# Vision-encoder types whose builders load a HuggingFace model and probe
# feature_dim / num_tokens from the model's config. Setting these knobs
Expand Down Expand Up @@ -221,3 +221,46 @@ def validate(self, world_size: int = 1) -> None:
# the build + forward + backward path. CrossAttentionBlocks
# themselves remain dense MLP; MoE lives in the text
# TransformerBlocks (where moe_frequency selects).

if isinstance(self.vlm, MoMaConfig):
# MoMa carries its own per-modality expert counts on
# ``moma_experts_per_modality``; ``model.num_experts``
# would be a redundant second source of truth, so we
# reject it explicitly rather than silently ignoring.
if self.model.num_experts > 0:
raise ValueError(
"MoMa + model.num_experts > 0 is rejected. MoMa derives expert "
"counts per modality from vlm.moma_experts_per_modality; set "
"model.num_experts=0 (the dense-LLM default) when using arch='moma'."
)
if self.distributed.ep > 1:
raise ValueError(
"MoMa + Expert Parallelism is not supported in v1. Per-modality "
"expert groups need EP-aware dispatch that is not yet wired."
)
if self.train.compile_model:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MoMa branch of JobConfig.validate (job.py:225-247) already warns on compile_model (lines 241-247). The AC=full no-op fits the same precedent but is silent: a user authoring a fresh MoMa config without reading vlm_7b_moma.toml's NOTE block gets no warning that activation checkpointing has no effect.

Add a sibling warning right after the compile block (import logging is already in scope there):

  if self.train.activation_checkpointing == "full":
      logging.getLogger(__name__).warning(
          "AC=full is currently a no-op for MoMa (apply_ac matches "
          "TransformerBlock only; MoMaBlock is a sibling nn.Module). "
          "Use ac='selective' (wraps the Attention submodule, still works) "
          "or reduce moma_experts_per_modality until the apply_ac refactor lands."
      )

Trigger only on "full", not != "none": apply_ac's selective branch (parallel.py:120) uses isinstance(m, Attention), and MoMaBlock.attention is a standard Attention (moma.py:401), so selective mode does wrap MoMa correctly. Only full is broken.

Could ride on the upcoming apply_ac refactor PR, or land as a small patch here.

import logging

logging.getLogger(__name__).warning(
"torch.compile is not yet validated for MoMa dispatch "
"(modality_ids-based scatter/gather + expert-choice top-k cause "
"graph breaks). Set compile_model=false for MoMa models."
)

# AC=full silently no-ops on MoMa because apply_ac matches
# isinstance(m, TransformerBlock) only and MoMaBlock is a
# sibling nn.Module. The selective branch wraps
# isinstance(m, Attention), and MoMaBlock.attention is a
# vanilla Attention, so 'selective' DOES work on MoMa today
# — only 'full' is broken. The cross-arch apply_ac refactor
# will make 'full' work; until then surface the silent
# no-op explicitly so fresh MoMa configs don't trust it.
if self.train.activation_checkpointing == "full":
import logging

logging.getLogger(__name__).warning(
"AC=full is currently a no-op for MoMa (apply_ac matches "
"TransformerBlock only; MoMaBlock is a sibling nn.Module). "
"Use ac='selective' (wraps the Attention submodule, still works) "
"or reduce moma_experts_per_modality until the apply_ac refactor lands."
)
Loading
Loading