feat(vision): add Vision DP for parallel ViT computation across Ulysses SP ranks #357
Open
aoshen524 wants to merge 3 commits into alibaba:main from
Conversation
…es SP ranks

Distribute whole images across Ulysses SP ranks for parallelized ViT computation, reducing ViT peak memory by ~sp_size× (e.g. SP=4 → ~4× ViT memory reduction).

Key changes:
- Add roll/utils/context_parallel/vision_dp.py with image distribution utilities, a GatherVisionEmbeddings autograd function, and a model-agnostic VisionTransformer wrapper
- Add apply_vision_dp_patch() in monkey_patch.py for the Qwen2-VL, Qwen2.5-VL, Qwen3-VL, and Qwen3-VL-MoE VisionTransformer classes
- Integrate into the DeepSpeed strategy (both inference and training workers)
- Add 17 unit tests covering all utility functions, edge cases, and integration workflows

Ported from verl (verl-project/verl#5230).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
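The distribution idea can be sketched in pure Python (the helper names below are hypothetical, not the PR's actual API): each image's ViT patch count follows from its `grid_thw` entry as t·h·w, and whole images are split into contiguous per-rank bins whose total patch load is roughly balanced.

```python
def patch_counts(grid_thw):
    """Per-image ViT patch counts: t * h * w for each (t, h, w) entry."""
    return [t * h * w for t, h, w in grid_thw]

def balanced_contiguous_bins(counts, num_ranks):
    """Split images into num_ranks contiguous bins, greedily growing each
    bin while that keeps its load closer to the remaining average load."""
    bins, start, remaining = [], 0, sum(counts)
    for r in range(num_ranks):
        if r == num_ranks - 1:              # last rank takes whatever is left
            bins.append(list(range(start, len(counts))))
            break
        target = remaining / (num_ranks - r)
        load, end = 0, start
        while end < len(counts) and abs(load + counts[end] - target) <= abs(load - target):
            load += counts[end]
            end += 1
        bins.append(list(range(start, end)))
        start, remaining = end, remaining - load
    return bins

grid_thw = [(1, 4, 4), (1, 8, 8), (2, 4, 4), (1, 2, 2)]
counts = patch_counts(grid_thw)              # [16, 64, 32, 4]
print(balanced_contiguous_bins(counts, 2))   # [[0, 1], [2, 3]] -> loads 80 vs 36
```

Because bins are contiguous, each rank's images occupy one contiguous run of patch rows, which is what later lets the local inputs be taken as a single tensor slice.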
Integrate the upstream hf_flash_attention_patch for transformers >= 4.53.0 alongside the Vision DP patches.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…issues

Address reviewer comments (same fixes as verl PR #5230 and AReaL PR #929):

1. **Gradient routing fix (critical)**: Replace `grad_scaler * dp_size` with `all_reduce(SUM)` in GatherVisionEmbeddings.backward() to aggregate partial sequence gradients before slicing. Fixes silent gradient loss when vision tokens span multiple sequence-shard boundaries.
2. **Load-balanced assignment**: Replace count-based chunking with greedy contiguous bin-packing that balances total patch load across ranks.
3. **Remove unnecessary all_gather**: Pass the pre-computed `all_counts` from the caller instead of doing an all_gather in forward.
4. **Idempotency guard**: Extract a `_patch_vision_class()` helper with a `_vision_dp_patched` attribute check. Add `_unapply_vision_class()` to properly clear the flag on unapply.
5. **Remove Qwen3-VL-MoE dead code**: Remove unreachable qwen3_vl_moe blocks from apply/unapply (not yet in transformers' vl_model_mappings).
6. **GPU→CPU sync optimization**: Move `grid_thw.cpu()` to the dp_vision_forward entry point to avoid repeated `.tolist()` GPU→CPU syncs.
7. **Tensor slicing**: Replace the Python loop + list append in prepare_local_vision_inputs with a contiguous tensor slice using cumsum.
8. **Test improvements**: Rename tests, add a load-balancing test, add a gather_none_group test, use parametrize.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
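Why fix 1 matters can be demonstrated with plain Python lists standing in for tensors (a toy sketch, not the PR's code): each SP rank's backward produces a partial gradient over the full gathered embedding that is non-zero only on its own sequence shard, so when an image's rows straddle a shard boundary, slicing the local partial drops the part held by the other rank, while summing all partials first (the effect of `all_reduce(SUM)`) recovers the full gradient.

```python
# Toy: 2 SP ranks; the gathered vision embedding has 4 rows. The sequence
# is sharded as rows [0,1] on rank 0 and rows [2,3] on rank 1, but the
# image computed on rank 0 owns rows [1,2] -- it straddles the boundary.
partial_grads = [
    [1.0, 2.0, 0.0, 0.0],   # rank 0's grad: non-zero only on its shard
    [0.0, 0.0, 3.0, 4.0],   # rank 1's grad: non-zero only on its shard
]
rank0_rows = slice(1, 3)     # rows of the image assigned to rank 0

# Buggy: slice the local partial (however rescaled) -- row 2's
# contribution lives on rank 1 and is silently dropped.
buggy = partial_grads[0][rank0_rows]            # [2.0, 0.0]

# Fixed: sum the partials across ranks first (what all_reduce(SUM)
# does in the real distributed code), then slice.
summed = [sum(col) for col in zip(*partial_grads)]
fixed = summed[rank0_rows]                      # [2.0, 3.0]

print("buggy:", buggy, "fixed:", fixed)
```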
Vision Data Parallel: Distribute ViT computation across Ulysses SP ranks
Ported from verl PR #5230, adapted for ROLL's Ulysses SP infrastructure.
Motivation
When using Ulysses Sequence Parallelism (sp_size > 1), the VisionTransformer still processes ALL images on every rank, wasting memory. Vision DP distributes whole images across SP ranks, reducing ViT peak memory by roughly sp_size×.
Key changes
- roll/utils/context_parallel/vision_dp.py
- roll/utils/context_parallel/monkey_patch.py
- tests/utils/test_vision_dp_on_cpu.py

Fixes applied (addressing reviewer feedback)

- Replace `grad_scaler * dp_size` with `all_reduce(SUM)` before slicing — fixes silent gradient loss when vision tokens span sequence-shard boundaries
- Compute `all_counts` locally from the replicated `grid_thw` (removes an unnecessary all_gather)
- `_patch_vision_class()` / `_unapply_vision_class()` helpers with a `_vision_dp_patched` flag
- Remove unreachable qwen3_vl_moe blocks (not yet in transformers' vl_model_mappings)
- `grid_thw.cpu()` at the entry point (avoids repeated GPU→CPU syncs)
- `cumsum` + contiguous slice replaces Python loops

Qwen3-VL support
Handles Qwen3-VL's tuple return `(embeddings, deepstack_embeddings)` — all-gathers both the main embeddings and the deepstack outputs.

Tests
```
python -m pytest tests/utils/test_vision_dp_on_cpu.py -v   # 21 passed
```
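For intuition, the cumsum-based slicing fix listed above reduces to the following toy sketch (the helper name is hypothetical, not the PR's prepare_local_vision_inputs): because each rank's images form one contiguous run, the cumulative patch offsets give a single (start, end) pair, replacing a per-image loop with one slice.

```python
from itertools import accumulate

def local_slice_bounds(counts, image_ids):
    """Given per-image patch counts and a rank's contiguous image ids,
    return the (start, end) patch offsets of its single contiguous slice."""
    offsets = [0] + list(accumulate(counts))   # cumulative patch offsets
    return offsets[image_ids[0]], offsets[image_ids[-1] + 1]

counts = [16, 64, 32, 4]
pixel_rows = list(range(sum(counts)))          # stand-in for pixel_values rows
start, end = local_slice_bounds(counts, [2, 3])
local = pixel_rows[start:end]                  # one slice instead of a loop
print(start, end, len(local))                  # 80 116 36
```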