
3D sharding for triangle updates #5890

Draft

wujingyue wants to merge 11 commits into main from wjy/sp

Conversation


wujingyue (Collaborator) commented Jan 29, 2026

I got the "triangle updates incoming" test passing in this PR. Below are the key issues identified, the workarounds applied, and their current status:

1. Sharding Propagation Rework

2. Multi-Dimensional Sharding & getCommunicationInfo

  • Status: Implementation in progress: 2cfcd99
  • Details: The commit updates getCommunicationInfo to support multi-dimensional sharding. It reuses haveDifferentShardings to identify inconsistencies between input and output TensorView objects. The commit needs cleanup and further test verification before it can be merged.
  • Technical Debt: Per #3987 (Extend IdModel to map DIDs for certain patterns), haveDifferentShardings is currently bottlenecked by the expensive ExpressionSimplifier. We need to transition it to be IdModel-based in a future iteration.

3. Misaligned Memory Access in Transpose Kernels

4. Performance Bottleneck: AllGather memory

  • Area: Communication Optimization
  • Details: The current naive AllGather preceding the Einsum is functional but consumes too much memory for AlphaFold3 workloads due to long sequence lengths.
  • Proposed Fix: We need to implement stream-parallelization to enable either of the following (a rough sketch of the ring-based idea is given after this list):
    • Ring-based AllGather (with Swizzle), or
    • Broadcast-based communication (without Swizzle). AFAICT, fast broadcast requires multicasting and therefore symmetric memory.
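For illustration, here is a minimal sketch of the ring idea in plain torch.distributed, independent of nvFuser's stream-parallelization and without the Swizzle or broadcast variants. Each rank holds only one remote chunk of the gathered operand at a time, so the full AllGather result is never materialized; the einsum equation, shapes, and function name are made up for this example and are not taken from the PR.

```python
# Illustrative sketch only (not nvFuser code): ring-rotate chunks of the
# sharded operand and consume each chunk with a partial einsum, instead of
# materializing a full AllGather result before the einsum.
import torch
import torch.distributed as dist


def ring_allgather_einsum(a_local: torch.Tensor, b_local: torch.Tensor) -> torch.Tensor:
    """Computes einsum('...ik,...jk->...ij') where b is sharded over j across ranks."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size

    chunk = b_local.contiguous()
    out_chunks = [None] * world_size
    for step in range(world_size):
        # The chunk we currently hold originated from this source rank.
        src = (rank - step) % world_size
        out_chunks[src] = torch.einsum("...ik,...jk->...ij", a_local, chunk)
        if step < world_size - 1:
            recv_buf = torch.empty_like(chunk)
            reqs = dist.batch_isend_irecv([
                dist.P2POp(dist.isend, chunk, send_to),
                dist.P2POp(dist.irecv, recv_buf, recv_from),
            ])
            for req in reqs:
                req.wait()
            chunk = recv_buf
    # Reassemble the gathered j dimension in rank order.
    return torch.cat(out_chunks, dim=-1)
```

Peak memory per rank stays at one chunk of b plus one partial result, rather than the fully gathered operand, which is the property the proposed fix is after.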

cc @DejunL


github-actions bot commented Jan 29, 2026

Review updated until commit 08a5f45

Description

  • Add 3D sharding support for AlphaFold3 triangle updates with multi-device testing

  • Improve error handling in communication operations with better validation

  • Add debug logging for fusion transforms and segmented fusion output

  • Fix device dimension handling in vectorization helper calculations

  • Implement transpose operation tests with 3D device mesh

Changes walkthrough

Relevant files

Tests

test_alphafold3.py: AlphaFold3 triangle updates with 3D sharding
tests/python/multidevice/test_alphafold3.py (+225/-0)
  • Add comprehensive AlphaFold3 model building blocks, including layer_norm and gating functions
  • Implement triangle updates test with incoming/outgoing direction variants
  • Configure 3D device mesh sharding with cp_size=2 and dp_size partitioning
  • Add tensor sharding and execution validation for multi-GPU scenarios

test_multidevice.py: Multi-device transpose operation testing
tests/python/multidevice/test_multidevice.py (+58/-1)
  • Add transpose operation test with a 3D device mesh configuration
  • Implement complex tensor reshaping and allocation domain settings
  • Test multi-device transpose with cp_size=2 partitioning

Enhancement

lower_to_communication.cpp: Enhanced communication error handling and validation
csrc/host_ir/lower_to_communication.cpp (+24/-18)
  • Replace NVF_ERROR with NVF_ERROR_EQ for consistent error checking
  • Add input/output size validation in getCommunicationInfo
  • Implement the haveDifferentShardings check for producer-consumer pairs
  • Add device mesh comparison logic and improved error messages

propagate_shardings.cpp: Debug logging for sharding propagation transforms
csrc/preseg_passes/propagate_shardings.cpp (+7/-0)
  • Add debug dump logging for fusion transforms after pass execution
  • Include the PreSegmenterLogging option for transform debugging

fusion_segmenter.cpp: Segmented fusion debug output formatting
csrc/fusion_segmenter.cpp (+4/-3)
  • Replace printMath() with print() for fusion output
  • Improve formatting with proper newline handling in debug output

Bug fix

vectorize_helper.cpp: Device dimension handling in vectorization
csrc/scheduler/vectorize_helper.cpp (+7/-2)
  • Add special handling for device dimensions in the sharded extent calculation
  • Prevent division by device count for device-dimension IterDomains
  • Set the sharded extent to 1 for device dimensions

Miscellaneous

base.h: Header cleanup
csrc/base.h (+0/-1)
  • Remove an unnecessary #include

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Error Handling Robustness

The changes add stricter error checking with NVF_ERROR_EQ and NVF_THROW, but the new error handling in getCommunicationInfo could be too aggressive. The code now throws a "Not sharded on this parallel type" error, which might break legitimate use cases where some parallel types aren't used. This needs validation against the existing test suites.

    NVF_THROW("Not sharded on this parallel type: ", pt);

Sharding Validation Logic

The new logic using haveDifferentShardings() to filter parallel types before processing could introduce subtle bugs. The previous code continued processing all parallel types, while the new code skips parallel types whose shardings do not differ. This change in behavior should be thoroughly tested with various sharding configurations.

    if (!haveDifferentShardings(producer, consumer, {pt})) {
      continue;
    }

Test Coverage Completeness

While the AlphaFold3 test is comprehensive, it only tests successful execution without validating correctness against a reference implementation. The test should include torch.testing.assert_close() comparisons to ensure the 3D sharding produces mathematically correct results, especially given the complexity of triangle updates and the mentioned transpose kernel issues.

    (z_out,) = fd.execute(
        [
            z_in,
            w_norm_in,
            b_norm_in,
            w_p_in,
            w_g_in,
            w_norm_out,
            b_norm_out,
            w_p_out,
            w_g_out,
            mask,
        ]
    )
    assert z_out.shape == (batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z)
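To make the suggestion above concrete, here is a hedged sketch of what such a check could look like. `reference_triangle_update`, the slice arguments, and the tolerances are hypothetical placeholders; they are not part of this PR and would need to match the test's actual sharding layout.

```python
# Hypothetical sketch only: compare this rank's sharded output against an
# unsharded single-device reference implementation.
import torch


def check_against_reference(z_out, reference_triangle_update, full_inputs,
                            batch_slice, token_slice):
    # Run the (hypothetical) unsharded reference on the full, unsharded inputs.
    z_ref = reference_triangle_update(*full_inputs)
    # Keep only the shard of the reference output that this rank owns.
    z_ref_local = z_ref[batch_slice, token_slice]
    torch.testing.assert_close(z_out.cpu(), z_ref_local.cpu(), rtol=1e-3, atol=1e-3)
```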
