Skip to content

Broadcast-based allgather in host for-loop#5925

Draft
Priya2698 wants to merge 5 commits intomainfrom
pm/stream_broadcast
Draft

Broadcast-based allgather in host for-loop#5925
Priya2698 wants to merge 5 commits intomainfrom
pm/stream_broadcast

Conversation

@Priya2698
Copy link
Collaborator

No description provided.

@github-actions
Copy link

github-actions bot commented Feb 6, 2026

Review updated until commit 5ee8365

Description

  • Introduce StreamBroadcast communication type for broadcast-based allgather

  • Add host loop index parameter to communication conversion functions

  • Implement ring allgather detection from DIDx to Stream parallel types

  • Add test coverage for column-parallel linear forward with StreamBroadcast

Changes walkthrough

Relevant files
Enhancement
lower_to_communication.cpp
Implement StreamBroadcast communication and ring allgather detection

csrc/host_ir/lower_to_communication.cpp

  • Add new lowerToStreamBroadcast function for StreamBroadcast
    communication
  • Modify getCommunicationInfo to detect ring allgather patterns
    (DIDx->Stream)
  • Update convertSingleOpToCommunication to handle StreamBroadcast with
    host loop index
  • Add StreamBroadcast to communication layout compliance checks
  • +61/-5   
    lowering.cpp
    Integrate host loop index into communication lowering       

    csrc/host_ir/lowering.cpp

  • Pass innermost loop index as host_loop_index to communication
    conversion
  • Skip sharding checks for StreamBroadcast communications
  • +4/-2     
    convert_op_to_communication.cpp
    Update function call with new parameter                                   

    csrc/host_ir/pass/convert_op_to_communication.cpp

  • Update convertSingleOpToCommunication call to include host_loop_index
    parameter
  • +4/-1     
    communication.cpp
    Add StreamBroadcast support to communication infrastructure

    csrc/multidevice/communication.cpp

  • Add StreamBroadcast to CommunicationType enum string representation
  • Include StreamBroadcast in hasRoot() and isReduction() functions
  • Handle StreamBroadcast in postSingleCommunication using broadcast
    logic
  • +6/-0     
    lower_to_communication.h
    Update function signature for host loop index support       

    csrc/host_ir/lower_to_communication.h

  • Add host_loop_index parameter to convertSingleOpToCommunication
    function signature
  • +1/-0     
    communication.h
    Add StreamBroadcast to communication type enum                     

    csrc/multidevice/communication.h

  • Add StreamBroadcast to CommunicationType enum
  • Update documentation for StreamBroadcast communication type
  • +6/-1     
    Formatting
    ops.cpp
    Minor formatting improvement in error message                       

    csrc/host_ir/ops.cpp

    • Improve error message formatting in shardByStream function
    +1/-1     
    Tests
    test_overlap.py
    Add comprehensive test coverage for StreamBroadcast functionality

    tests/python/multidevice/test_overlap.py

  • Add column_parallel_linear_forward function for testing
    StreamBroadcast
  • Add test_column_parallel_linear_forward test with MPI and profiler
    checks
  • Add test_column_parallel_linear_forward_benchmark for performance
    testing
  • +119/-0 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Mesh Validation Logic

    The StreamBroadcast implementation includes mesh validation requiring sender and receiver meshes to be identical. This is checked with NVF_ERROR_EQ but the error message could be more descriptive about the implications.

    NVF_ERROR_EQ(
        sender_mesh,
        receiver_mesh,
        "StreamBroadcast sender and receiver meshes must be the same. Given ",
        sender_mesh,
        " and ",
        receiver_mesh);
    Stream ID Detection Logic

    The logic for detecting when to use StreamBroadcast involves checking if c_stream_logical_id equals p2c_map.at(p_logical_id). This condition should be documented more clearly as it determines the transition from DIDx -> Stream parallel types.

    if (c_stream_logical_id == p2c_map.at(p_logical_id)) {
      NVF_CHECK(
          same_mesh,
          "Broadcast based allgather in stream parallel requires same "
          "mesh.")
      fill_communication_info(
          CommunicationType::StreamBroadcast,
          p_logical_id,
          c_stream_logical_id);
      continue;
    }
    Test Dimension Requirements

    The test has specific requirements that (h * 4) % d == 0 and t % d == 0. These constraints should be documented in comments to help future maintainers understand why these specific divisibility requirements exist.

    if (h * 4) % d != 0:
        pytest.skip(
            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
        )
    if t % d != 0:
        pytest.skip(
            f"Column-parallel linear requires {t} to be divisible by world size {d}."
        )

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant