
Replay swizzle1D #5924

Open
Priya2698 wants to merge 2 commits into main from pm/swizzle1d_replay

Conversation

@Priya2698
Collaborator

Enables replaying swizzle1D as required by shardByStream

@Priya2698
Collaborator Author

!test

@github-actions

github-actions bot commented Feb 6, 2026

Review updated until commit 987672b

Description

  • Add support for replaying Swizzle1D transforms in multi-device sharding

  • Update transform traversal to handle Swizzle1D alongside existing Split transforms

  • Implement Swizzle1D replay in ReplaySelf class for proper transform propagation

  • Modify test to use hir::shardByStream which leverages the new Swizzle1D replay capability
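
The traversal change described above can be sketched in isolation. The following is a minimal stand-alone sketch, not nvFuser's actual code: `Transform`, `Split`, `Swizzle1D`, and `Merge` are hypothetical stand-ins for the real IR classes, and the exception mirrors the PR's NVF_THROW on unsupported transform types.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for nvFuser's transform expressions; the real
// Split/Swizzle1D classes operate on IterDomain operands.
struct Transform { virtual ~Transform() = default; };
struct Split : Transform {};
struct Swizzle1D : Transform {};
struct Merge : Transform {};

// Mirrors the walkthrough's traversal: handle Swizzle1D alongside Split,
// and fail loudly on anything else (as the PR's NVF_THROW does).
std::string replayTransform(const Transform* e) {
  if (dynamic_cast<const Swizzle1D*>(e) != nullptr) {
    return "replayed swizzle1d";
  }
  if (dynamic_cast<const Split*>(e) != nullptr) {
    return "replayed split";
  }
  throw std::runtime_error("Expected a swizzle1d or split transform.");
}
```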

Changes walkthrough

Relevant files

Enhancement

allocation_utils.cpp: Add Swizzle1D support in allocation sharding
csrc/multidevice/allocation_utils.cpp (+17/-10)
  • Add handling for Swizzle1D transforms in the shard allocation loop
  • Update transform processing to support both Swizzle1D and Split operations
  • Add error handling for unsupported transform types

propagation.cpp: Support Swizzle1D in loop domain transforms
csrc/multidevice/propagation.cpp (+56/-41)
  • Extend transform loop domain handling to include Swizzle1D transforms
  • Add replay logic for Swizzle1D parallel types in sharding propagation
  • Maintain existing Split transform handling with a new conditional structure

transform_iter.cpp: Add Swizzle1D to transform replay dispatch
csrc/transform_iter.cpp (+10/-2)
  • Update the dispatch method to include Swizzle1D in the supported expression types
  • Add a handle method for Swizzle1D that redirects to ReplaySelf
  • Improve error messaging for unsupported transform expressions

transform_replay.cpp: Implement Swizzle1D replay in ReplaySelf
csrc/transform_replay.cpp (+25/-0)
  • Implement Swizzle1D replay functionality in the ReplaySelf class
  • Add transform mapping from Swizzle1D input to output domains
  • Ensure proper loop ID tracking for Swizzle1D operations

dispatch.h: Add Swizzle1D to dispatch macros
csrc/dispatch.h (+1/-0)
  • Add Swizzle1D to the dispatch macro for expression type support

transform_iter.h: Declare Swizzle1D handler method
csrc/transform_iter.h (+2/-0)
  • Add a declaration for the Swizzle1D handler method in the ReplayTransformations class

Tests

test_multidevice_host_ir.cpp: Update test to use shardByStream with Swizzle1D
tests/cpp/test_multidevice_host_ir.cpp (+5/-19)
  • Replace manual sharding operations with hir::shardByStream calls
  • Simplify test setup by leveraging the new Swizzle1D replay capability
  • Remove redundant tensor creation and parallelization setup

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Error handling robustness

The new code throws NVF_THROW for unexpected transform types, but the old code used NVF_ERROR. Consider whether this change in error-handling behavior is intentional and consistent with the project's error-handling patterns.

    NVF_THROW("Expected a swizzle1d or split transform. Got: ", e);

Transform validation completeness

The new code handles both Swizzle1D and Split transforms, but the error message suggests only split or swizzle1d transforms should appear. Verify that the transform traversal logic correctly handles all possible transform types that could appear in the dependency chain.

    NVF_THROW("Expected a split or swizzle1d transform. Got: ", transform);

API consistency

The ReplayTransformations::handle(Swizzle1D*) method throws an error directing users to use ReplaySelf instead, but ReplaySelf lives in a different file. Consider whether this API design is clear enough for users, or whether additional documentation or guidance would be helpful.

    void ReplayTransformations::handle(Swizzle1D* swizzle1d) {
      NVF_THROW(
          "Swizzle1D replay not supported in ReplayTransformations, use ReplaySelf "
          "instead: ",
          swizzle1d->toString());
    }

Test failures

• (Medium, 3) Shape mismatch in thunder.tests.test_update_aliases higher-order inplace alias update (thunderfx path)
  Affected platforms: A100, GB200, H100
  Test: thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32

• (Medium, 1) Large scalar mismatch in Thunder nanoGPT autograd nvFuser CUDA test (test_networks)
  Affected platform: H100
  Test: thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

    @Priya2698 Priya2698 marked this pull request as ready for review February 6, 2026 18:16
    @Priya2698 Priya2698 requested a review from wujingyue February 6, 2026 18:16
    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 6, 2026

    Greptile Overview

    Greptile Summary

    This PR enables replay support for Swizzle1D transforms, which is required by the shardByStream functionality in multi-device operations.

    Key Changes:

    • Added Swizzle1D to the dispatch macro list in dispatch.h to enable proper expression handling
    • Implemented Swizzle1D replay logic in ReplaySelf class within transform_replay.cpp, which creates a replayed iteration domain while preserving the parallel type
    • Added a handler in ReplayTransformations that throws an error directing users to use ReplaySelf instead of the base replay class
    • Extended transformLoopDomain in propagation.cpp to handle Swizzle1D transforms during domain propagation, properly mapping reference to target IDs
    • Updated shardAllocationAsLoop in allocation_utils.cpp to process Swizzle1D transforms by preserving contiguity through the swizzle operation
    • Simplified the test in test_multidevice_host_ir.cpp by using the hir::shardByStream helper function instead of manually creating shard tensors

    The implementation follows the same pattern as existing transform types (Split, Merge, Resize), with proper error handling and domain mapping. The test validates that swizzle1d replay works correctly with stream parallelization.
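
    As a rough illustration of the replay pattern the summary describes (map the swizzle's input ID to a freshly created output ID and carry the parallel type along), here is a toy model. MiniReplay, plain-int IDs, and the string parallel-type tag are all hypothetical simplifications for exposition; the real logic in csrc/transform_replay.cpp operates on IterDomain objects.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical miniature of ReplaySelf's Swizzle1D handling; plain ints stand
// in for IterDomain IDs so the mapping logic is visible in isolation.
struct MiniReplay {
  std::unordered_map<int, int> id_map;  // reference ID -> replayed ID
  std::unordered_map<int, std::string> parallel_type;  // replayed ID -> ptype
  int next_id = 100;

  // Replaying a swizzle1d maps its input ID to a freshly created output ID
  // and preserves the parallel type, per the PR summary.
  int handleSwizzle1D(int in_id, const std::string& ptype) {
    int out_id = next_id++;
    id_map[in_id] = out_id;
    parallel_type[out_id] = ptype;
    return out_id;
  }
};
```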

    Confidence Score: 5/5

    • This PR is safe to merge with minimal risk
    • The implementation is straightforward and follows established patterns for other transform types (Split, Merge, Resize). All changes are consistent with the codebase architecture, proper error handling is in place, and the test validates the functionality. The code is well-structured with clear separation between base replay (ReplayTransformations) and self-replay (ReplaySelf) classes.
    • No files require special attention

    Important Files Changed

    Filename Overview
    csrc/transform_replay.cpp Implemented Swizzle1D replay logic in ReplaySelf class, creating replayed iter domain with preserved parallel type
    csrc/multidevice/allocation_utils.cpp Extended shardAllocationAsLoop to handle Swizzle1D transforms by preserving contiguity through the swizzle operation
    csrc/multidevice/propagation.cpp Added Swizzle1D handling in transformLoopDomain, replaying swizzle1d transforms with parallel type preservation during domain propagation
    tests/cpp/test_multidevice_host_ir.cpp Simplified test by using hir::shardByStream helper instead of manually creating shard tensors, validates swizzle1d replay functionality
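
    The contiguity bookkeeping described for allocation_utils.cpp can be modeled abstractly. The sketch below is a hypothetical simplification (strings stand in for IterDomains, and the outer-output-contiguous rule for Split is a stated assumption), meant only to contrast how a swizzle passes contiguity through unchanged while a split redistributes it.

```cpp
#include <map>
#include <string>

// Hypothetical model of shardAllocationAsLoop's contiguity tracking.
using ContigMap = std::map<std::string, bool>;

// A Swizzle1D reorders elements within a single domain, so its output
// inherits the input's contiguity flag unchanged (the "preserve contiguity
// through the swizzle" behavior described above).
void applySwizzle1D(ContigMap& m, const std::string& in, const std::string& out) {
  bool contig = m.at(in);
  m.erase(in);
  m[out] = contig;
}

// A Split produces outer and inner outputs; here the inner output inherits
// the input's flag and the outer is marked contiguous, as a simplifying
// assumption for illustration.
void applySplit(ContigMap& m, const std::string& in,
                const std::string& outer, const std::string& inner) {
  bool contig = m.at(in);
  m.erase(in);
  m[outer] = true;
  m[inner] = contig;
}
```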

    Sequence Diagram

    sequenceDiagram
        participant Client
        participant HostIR as Host IR Container
        participant Shard as shardByStream
        participant Replay as ReplaySelf
        participant PropLoop as transformLoopDomain
        participant AllocUtil as shardAllocationAsLoop
        
        Client->>HostIR: Create tensors with swizzle1d transforms
        Note over HostIR: tv->outer_split(1, d)<br/>tv->axis(1)->parallelize(DIDx)<br/>tv->outer_split(0, d)<br/>tv->swizzle1d(0, DIDx)<br/>tv->axis(0)->parallelize(Stream)
        
        Client->>Shard: shardByStream(tv, stream_index, expr)
        Shard->>PropLoop: transformLoopDomain()
        Note over PropLoop: Get transforms between domains
        
        PropLoop->>PropLoop: Process Split transforms
        Note over PropLoop: Replay splits on target domain
        
        PropLoop->>PropLoop: Process Swizzle1D transforms
        Note over PropLoop: Create replayed_id = IterDomain::swizzle1d()<br/>Preserve parallelType<br/>Update ref2target mapping
        
        PropLoop-->>Shard: Return transformed loop domain
        
        Shard->>AllocUtil: shardAllocationAsLoop()
        Note over AllocUtil: Process allocation transforms
        
        AllocUtil->>AllocUtil: Handle Swizzle1D in allocation
        Note over AllocUtil: Preserve contiguity through swizzle<br/>allocation_to_contiguity.insert()
        
        AllocUtil->>AllocUtil: Handle Split in allocation
        Note over AllocUtil: Split contiguity for inner/outer
        
        AllocUtil-->>Shard: Return updated allocation domain
        
        Shard->>Replay: Replay transforms using ReplaySelf
        Note over Replay: handle(Swizzle1D*)<br/>Map input to output ID<br/>Create replayed swizzle1d<br/>Update loop_ids_
        
        Replay-->>Shard: Return replayed shard
        Shard-->>Client: Return sharded tensor view
    

    @greptile-apps greptile-apps bot left a comment

    5 files reviewed, no comments

    Comment on lines 96 to 97
    if (e->isA<Swizzle1D>()) {
    auto* swizzle1d = e->as<Swizzle1D>();
    Collaborator

    Suggested change
    if (e->isA<Swizzle1D>()) {
    auto* swizzle1d = e->as<Swizzle1D>();
    if (auto* swizzle_1d = dynamic_cast<Swizzle1D*>(e)) {

    Comment on lines 101 to 102
    } else if (e->isA<Split>()) {
    auto* split = e->as<Split>();
    Collaborator

    ditto

    allocation_to_contiguity.insert(
    split_i, split->inner(), inner_contiguity);
    } else {
    NVF_THROW("Expected a swizzle1d or split transform. Got: ", e);
    Collaborator

    for (...) {
      if (a) {
        ...
      } else if (b) {
        ...
      } else {
        ...
      }
    }
    

    =>

    for (...) {
      if (a) {
        ...
        continue;
      }
    
      if (b) {
        ...
        continue;
      }
    
      ...
    }
    

    Collaborator Author

    What is the reason for preferring if .. continue over if-else blocks?

    Collaborator

    The same reason why I prefer early exit.

    1. Once I see continue, I know the rest of the for loop has nothing to do with condition a. Otherwise, I have to scroll down to check where the end-if is.
    2. Less indentation for the final else block.

    if (hasRootToLogicalTransform(consumer_id, consumer_tv)) {
    validate_split(split, target_id);
    }
    if (transform->isA<Swizzle1D>()) {
    Collaborator

    ditto

    auto is_supported_expr =
    e->isOneOf<Split, Merge, Swizzle, Swizzle1D, Resize>();
    NVF_ERROR(
    is_supported_expr, "Invalid expr type found in transform traversal.");
    Collaborator

    Suggested change
    is_supported_expr, "Invalid expr type found in transform traversal.");
    (e->isOneOf<Split, Merge, Swizzle, Swizzle1D, Resize>()));

    The extra pair of parentheses is needed to disambiguate the commas.
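
    The comma issue the reviewer points out can be reproduced with any function-like macro. The sketch below uses a hypothetical MY_CHECK macro and isOneOf helper (not nvFuser's NVF_ERROR or Expr::isOneOf) to show why the template-argument commas must be protected by parentheses: the preprocessor splits macro arguments at every top-level comma.

```cpp
#include <cassert>

// Hypothetical two-argument stand-in for NVF_ERROR (condition, message).
#define MY_CHECK(cond, msg) assert((cond) && msg)

// Hypothetical helper standing in for Expr::isOneOf<...>().
template <typename A, typename B>
bool isOneOf(int tag) {
  return tag == 1 || tag == 2;
}

bool checkedDispatch(int tag) {
  // Writing MY_CHECK(isOneOf<int, double>(tag), "msg") would not compile:
  // the comma inside <int, double> makes the preprocessor see three macro
  // arguments. Wrapping the condition in parentheses keeps it as one.
  MY_CHECK((isOneOf<int, double>(tag)), "Invalid expr type");
  return true;
}
```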

    @Priya2698
    Collaborator Author

    !test

    @greptile-apps greptile-apps bot left a comment

    4 files reviewed, no comments

