feat(vision): add Vision DP for parallel ViT computation across Ulysses SP ranks #357

Open
aoshen524 wants to merge 3 commits into alibaba:main from aoshen524:feat/vision-dp-ulysses

Conversation


@aoshen524 aoshen524 commented Feb 16, 2026

Vision Data Parallel: Distribute ViT computation across Ulysses SP ranks

Ported from verl PR #5230, adapted for ROLL's Ulysses SP infrastructure.

Motivation

When using Ulysses Sequence Parallelism (sp_size > 1), the VisionTransformer still processes ALL images on every rank, duplicating work and wasting memory. Vision DP distributes whole images across the SP ranks, reducing ViT peak memory by roughly sp_size×.

Key changes

| File | Change |
| --- | --- |
| `roll/utils/context_parallel/vision_dp.py` | Core utilities: load-balanced assignment, tensor slicing, all-gather with gradient fix |
| `roll/utils/context_parallel/monkey_patch.py` | Integration with idempotency guard, clean unapply support |
| `tests/utils/test_vision_dp_on_cpu.py` | CPU-only unit tests (21 tests) |

Fixes applied (addressing reviewer feedback)

  1. Gradient routing (critical): Replace grad_scaler * dp_size with all_reduce(SUM) before slicing — fixes silent gradient loss when vision tokens span sequence shard boundaries
  2. Load-balanced assignment: Greedy contiguous bin-packing by patch count (not image count)
  3. Remove unnecessary all_gather: Compute all_counts locally from replicated grid_thw
  4. Idempotency guard: _patch_vision_class() / _unapply_vision_class() helpers with _vision_dp_patched flag
  5. Remove Qwen3-VL-MoE dead code: Not yet in transformers vl_model_mappings
  6. GPU→CPU sync optimization: Single grid_thw.cpu() at entry point
  7. Tensor slicing: cumsum + contiguous slice replaces Python loops
  8. Test improvements: Parametrize, rename, add load balancing + gather tests
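The critical gradient fix (item 1) can be simulated on CPU without a process group. This is an illustrative sketch, not the actual `vision_dp.py` code: `shard_grads` and `all_counts` are hypothetical names, and the elementwise sum stands in for `all_reduce(SUM)`:

```python
def backward_route_grads(shard_grads, all_counts, rank):
    """CPU simulation of the gather's backward pass for one rank.

    Every rank holds a gradient buffer for the FULL gathered embedding,
    but with Ulysses sequence sharding each buffer is only nonzero for
    the vision tokens inside that rank's sequence shard. Summing the
    buffers elementwise (what all_reduce(SUM) does across ranks)
    recovers the complete gradient; only then is it safe to slice out
    this rank's own image patches. The old `grad * dp_size` scaling
    silently dropped contributions that lived on other shards.
    """
    # simulated all_reduce(SUM): elementwise sum over all ranks' buffers
    full_grad = [sum(col) for col in zip(*shard_grads)]
    # contiguous slice of the patches this rank computed in forward
    start = sum(all_counts[:rank])
    return full_grad[start:start + all_counts[rank]]
```

With two ranks, `all_counts=[2, 1]` and buffers `[[1, 0, 3], [0, 2, 0]]`, rank 0 gets `[1, 2]` and rank 1 gets `[3]`; scaling rank 0's own buffer by `dp_size` instead would have produced `[2, 0]`, losing the gradient held by rank 1's shard.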

Qwen3-VL support

Handles Qwen3-VL's tuple return (embeddings, deepstack_embeddings) — all-gathers both the main embeddings and deepstack outputs.
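The tuple handling can be sketched as a model-agnostic gather step. This is a hand-rolled illustration, not the PR's wrapper: `all_gather` stands in for the concat-across-ranks collective, and treating `deepstack_embeddings` as a sequence of per-level tensors is an assumption:

```python
def gather_vision_outputs(local_out, all_gather):
    """Gather step over the vision tower's output (illustrative sketch).

    Qwen2-VL / Qwen2.5-VL vision towers return a single embeddings
    tensor; Qwen3-VL returns (embeddings, deepstack_embeddings). Both
    the main embeddings and every deepstack level must be all-gathered,
    or each rank would only see features for its own slice of images.
    """
    if isinstance(local_out, tuple):
        embeds, deepstack = local_out
        return all_gather(embeds), [all_gather(d) for d in deepstack]
    return all_gather(local_out)
```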

Tests

```
python -m pytest tests/utils/test_vision_dp_on_cpu.py -v
# 21 passed
```
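Fix 2's load-balanced assignment (greedy contiguous bin-packing by patch count) could look roughly like the following; the function name and the range-based return value are illustrative, not the actual `vision_dp.py` signature:

```python
def assign_images_to_ranks(patch_counts, dp_size):
    """Split images (kept in order) into dp_size contiguous chunks
    whose total patch counts are as balanced as one greedy pass allows.
    Returns half-open (start, end) image-index ranges, one per rank."""
    n = len(patch_counts)
    remaining = sum(patch_counts)
    ranges, start = [], 0
    for rank in range(dp_size):
        ranks_after = dp_size - rank - 1
        target = remaining / (ranks_after + 1)  # fair share of what's left
        end, acc = start, 0
        while end < n - ranks_after:  # keep one image per remaining rank
            nxt = patch_counts[end]
            # take the next image only if it moves us closer to target
            # (but always take at least one image while any remain)
            if end > start and abs(acc + nxt - target) >= abs(acc - target):
                break
            acc += nxt
            end += 1
        ranges.append((start, end))
        remaining -= acc
        start = end
    return ranges
```

Chunks stay contiguous so each rank's images occupy one slice of the flattened patch tensor; balancing by patch count rather than image count matters because image resolutions (and hence patch counts) can vary widely.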

…es SP ranks

Distribute whole images across Ulysses SP ranks for parallelized ViT computation,
reducing ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x ViT memory reduction).

Key changes:
- Add roll/utils/context_parallel/vision_dp.py with image distribution utilities,
  GatherVisionEmbeddings autograd function, and model-agnostic VisionTransformer wrapper
- Add apply_vision_dp_patch() in monkey_patch.py for Qwen2-VL, Qwen2.5-VL, Qwen3-VL,
  Qwen3-VL-MoE VisionTransformer classes
- Integrate into DeepSpeed strategy (both inference and training workers)
- Add 17 unit tests covering all utility functions, edge cases, and integration workflows

Ported from verl (verl-project/verl#5230).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

CLAassistant commented Feb 16, 2026

CLA assistant check
All committers have signed the CLA.

aoshen524 and others added 2 commits February 16, 2026 07:11
Integrate upstream hf_flash_attention_patch for transformers>=4.53.0
alongside Vision DP patches.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…issues

Address reviewer comments (same fixes as verl PR #5230 and AReaL PR #929):

1. **Gradient routing fix (critical)**: Replace `grad_scaler * dp_size` with
   `all_reduce(SUM)` in GatherVisionEmbeddings.backward() to aggregate
   partial sequence gradients before slicing. Fixes silent gradient loss
   when vision tokens span multiple sequence shard boundaries.

2. **Load-balanced assignment**: Replace count-based chunking with greedy
   contiguous bin-packing that balances total patch load across ranks.

3. **Remove unnecessary all_gather**: Pass pre-computed `all_counts` from
   caller instead of doing all_gather in forward.

4. **Idempotency guard**: Extract `_patch_vision_class()` helper with
   `_vision_dp_patched` attribute check. Add `_unapply_vision_class()` to
   properly clear the flag on unapply.

5. **Remove Qwen3-VL-MoE dead code**: Remove unreachable qwen3_vl_moe
   blocks from apply/unapply (not yet in transformers vl_model_mappings).

6. **GPU→CPU sync optimization**: Move `grid_thw.cpu()` to dp_vision_forward
   entry point to avoid repeated `.tolist()` GPU→CPU syncs.

7. **Tensor slicing**: Replace Python loop + list append in
   prepare_local_vision_inputs with contiguous tensor slice using cumsum.

8. **Test improvements**: Rename tests, add load balancing test, add
   gather_none_group test, use parametrize.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
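Item 7's cumsum-based slicing replaces the per-image Python loop with a single contiguous slice. The sketch below uses plain lists and `itertools.accumulate` in place of torch tensors, and the names are illustrative:

```python
from itertools import accumulate

def slice_local_patches(pixel_values, patch_counts, start_img, end_img):
    """Select all patch rows for images [start_img, end_img) in one
    contiguous slice. pixel_values is flattened to (total_patches, ...),
    so per-image offsets are just a cumulative sum of patch counts."""
    offsets = [0] + list(accumulate(patch_counts))  # 0-prefixed cumsum
    return pixel_values[offsets[start_img]:offsets[end_img]]
```

This works because each rank's images form a contiguous run (a consequence of the contiguous bin-packing in fix 2); in the real code the slice would be a view into the pixel tensor rather than a list copy.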