
Conversation

@cuichenx
Contributor

@cuichenx cuichenx commented Jan 6, 2026

Description

THD sink attention is supported in cuDNN 9.18.0 and newer.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Jan 6, 2026

Greptile Summary

This PR enables THD (total tokens, heads, head dim) format support for sink attention (off-by-one and learnable softmax types) when using cuDNN 9.18.0 or higher. Previously, the combination of THD format with sink attention was unconditionally disabled.

Key changes:

  • Modified utils.py to conditionally disable FusedAttention for THD + sink attention only when cuDNN < 9.18.0, allowing the feature for newer versions
  • Updated context_parallel.py to add version-gated assertion that permits THD + sink attention for cuDNN >= 9.18.0
  • Fixed f-string formatting in several assertion messages (added missing f prefixes)
  • Added comprehensive unit test test_dpa_softmax_thd that validates sink attention with THD format on cuDNN 9.18.0+

The changes are well-aligned with cuDNN's feature support timeline and maintain backward compatibility by preserving the restriction for older cuDNN versions.
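
A minimal sketch of the version-gated filtering pattern described above; the function and variable names here are illustrative, not the exact code in utils.py:

```python
# Illustrative only: mirrors the summary's description of the utils.py change,
# not the exact upstream code. cudnn_version is a (major, minor, patch) tuple.
def filter_thd_sink_backend(use_fused_attention: bool,
                            qkv_format: str,
                            softmax_type: str,
                            cudnn_version: tuple) -> bool:
    """Disable FusedAttention for THD + sink attention only on cuDNN < 9.18.0."""
    if softmax_type != "vanilla" and qkv_format == "thd" and cudnn_version < (9, 18, 0):
        # Older cuDNN cannot run sink attention (off-by-one / learnable softmax)
        # on packed THD tensors, so fall back to another backend.
        return False
    return use_fused_attention
```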

Confidence Score: 5/5

  • This PR is safe to merge with no identified issues
  • The changes correctly enable a new feature (THD sink attention) for cuDNN >= 9.18.0 while maintaining backward compatibility. The version checks are consistent across all files, f-string bugs are fixed, and comprehensive tests are added. The logic aligns with existing patterns in the codebase.
  • No files require special attention

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Conditionally disables FusedAttention for THD format with sink attention only for cuDNN < 9.18.0, enabling the feature for newer versions |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | Adds version check to allow THD format with sink attention for cuDNN >= 9.18.0, and fixes f-string formatting in assertion messages |
| tests/pytorch/attention/test_attention.py | Adds new test for sink attention with THD format, requiring cuDNN 9.18.0+ |

Sequence Diagram

sequenceDiagram
    participant Test as test_dpa_softmax_thd
    participant DPA as DotProductAttention
    participant Utils as get_attention_backend
    participant CP as attn_forward_func_with_cp
    participant cuDNN as cuDNN Backend

    Test->>DPA: Call with qkv_format="thd_thd_thd"<br/>softmax_type="sink"
    DPA->>Utils: Check backend availability
    Utils->>Utils: Check cudnn_version >= (9, 18, 0)
    alt cuDNN >= 9.18.0
        Utils->>Utils: Keep use_fused_attention=True
        Utils-->>DPA: FusedAttention backend enabled
    else cuDNN < 9.18.0
        Utils->>Utils: Set use_fused_attention=False
        Utils-->>DPA: FusedAttention disabled
    end
    
    alt Context Parallelism enabled
        DPA->>CP: attn_forward_func_with_cp
        CP->>CP: Check cudnn_version >= (9, 18, 0)
        alt cuDNN >= 9.18.0
            CP->>CP: Allow softmax_type != "vanilla"<br/>with qkv_format="thd"
            CP->>cuDNN: Execute attention with sink
        else cuDNN < 9.18.0
            CP->>CP: Assert fails for sink + THD
            CP-->>DPA: Error: Not supported
        end
    end
    
    cuDNN-->>DPA: Attention output
    DPA-->>Test: Test result

@KshitijLakhani
Collaborator

/te-ci pytorch

@KshitijLakhani KshitijLakhani requested a review from pggPL January 7, 2026 22:29
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR enables THD (total tokens, heads, head dim) format support for sink attention (non-vanilla softmax types like "off-by-one" and "learnable") when using cuDNN version 9.18.0 or higher.

Key changes:

  • Replaced blanket disablement of FusedAttention for THD + sink attention with version-gated logic
  • Added version check in context parallelism to allow THD + sink attention on cuDNN >= 9.18.0
  • Fixed f-string formatting bugs in assertion messages ({cp_comm_type=}, {softmax_type=} were not being interpolated)
  • Added comprehensive test coverage for THD format with various softmax types

Technical details:

  • Previously, THD format with sink attention was completely disabled for both FusedAttention and UnfusedDotProductAttention backends
  • With cuDNN 9.18.0+, the limitation is lifted, allowing FusedAttention to work with this configuration
  • The change properly preserves UnfusedDotProductAttention availability for THD + sink attention (as per the support matrix in comments)
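
A hedged sketch of the version-gated assertion and the f-string fix called out above; the function and argument names are assumptions, not the exact context_parallel.py source:

```python
# Illustrative only: shows the assertion pattern and the f-string fix
# (previously the messages lacked the f prefix, so {softmax_type=} and
# {cp_comm_type=} appeared literally instead of being interpolated).
def check_thd_sink_support(softmax_type: str,
                           qkv_format: str,
                           cp_comm_type: str,
                           cudnn_version: tuple) -> None:
    if softmax_type != "vanilla":
        assert cp_comm_type == "a2a", (
            "Sink attention with context parallelism requires cp_comm_type='a2a' "
            f"(got {cp_comm_type=})."
        )
        if qkv_format == "thd":
            assert cudnn_version >= (9, 18, 0), (
                f"THD format with {softmax_type=} requires cuDNN 9.18.0 or newer."
            )
```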

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are well-scoped, properly version-gated, include test coverage, and fix pre-existing f-string bugs. The logic correctly enables a new feature path without affecting existing behavior for older cuDNN versions.
  • No files require special attention

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| tests/pytorch/attention/test_attention.py | 5/5 | Added test for THD format with sink attention (non-vanilla softmax types), gated by cuDNN 9.18.0+ requirement |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 5/5 | Conditionally enables FusedAttention for THD format with sink attention on cuDNN >= 9.18.0 by removing blanket disablement |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 5/5 | Fixed f-string formatting in assertions and added version-gated check for THD + sink attention support with context parallelism |

Sequence Diagram

sequenceDiagram
    participant Test as test_dpa_softmax_thd
    participant DPA as DotProductAttention
    participant Backend as get_attention_backend
    participant Version as get_cudnn_version
    participant CP as attn_forward_func_with_cp
    participant Fused as FusedAttention
    
    Test->>Version: Check cuDNN version
    alt cuDNN < 9.18.0
        Version-->>Test: Skip test
    else cuDNN >= 9.18.0
        Test->>DPA: Run with THD format + sink attention
        DPA->>Backend: Determine backend (softmax_type, qkv_format="thd")
        Backend->>Version: get_cudnn_version()
        Version-->>Backend: Return version
        
        alt cuDNN < 9.18.0
            Backend->>Backend: Disable FusedAttention for THD
            Backend-->>DPA: Use alternate backend
        else cuDNN >= 9.18.0
            Backend->>Backend: Keep FusedAttention enabled
            Backend-->>DPA: Use FusedAttention
        end
        
        alt context_parallel enabled
            DPA->>CP: attn_forward_func_with_cp
            CP->>Version: get_cudnn_version()
            Version-->>CP: Return version
            
            alt cuDNN < 9.18.0 && softmax_type != "vanilla" && qkv_format == "thd"
                CP->>CP: Assertion fails
                CP-->>DPA: Error
            else cuDNN >= 9.18.0 || valid config
                CP->>Fused: Execute attention with CP
                Fused-->>CP: Result
                CP-->>DPA: Return output
            end
        else no context_parallel
            DPA->>Fused: Execute attention
            Fused-->>DPA: Return output
        end
        
        DPA-->>Test: Test passes
    end

@KshitijLakhani KshitijLakhani changed the title Update THD sink attention logic for cudnn >=9.18.0 [PyT] Update THD sink attention logic for cudnn >=9.18.0 Jan 8, 2026
@cuichenx cuichenx requested a review from cyanguwa January 8, 2026 23:43
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Enables THD (total tokens, heads, head dim) format support for sink attention (non-vanilla softmax types like off-by-one and learnable) with cuDNN 9.18.0+. Updates backend selection logic in utils.py to conditionally enable FusedAttention based on cuDNN version, and adds a version-gated assertion in context_parallel.py to allow the feature on newer cuDNN versions. Includes test coverage for the new functionality and fixes f-string formatting in several assertion messages.
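
For context, a hedged usage sketch of the configuration this enables. The argument names (qkv_format, softmax_type, cu_seqlens_q, cu_seqlens_kv) follow the sequence diagrams in this thread; the exact DotProductAttention signature may differ:

```python
# Hedged usage sketch, not taken from the PR: parameter placement is an
# assumption based on the diagrams above, not a verified API reference.
import torch
import transformer_engine.pytorch as te

total_tokens, heads, head_dim = 512, 16, 64
dpa = te.DotProductAttention(
    num_attention_heads=heads,
    kv_channels=head_dim,
    qkv_format="thd",          # packed [total tokens, heads, head dim] layout
    softmax_type="learnable",  # sink attention; fused THD path needs cuDNN >= 9.18.0
)
q = torch.randn(total_tokens, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
# Cumulative sequence lengths describing the packed sequences.
cu_seqlens = torch.tensor([0, 128, 384, 512], dtype=torch.int32, device="cuda")
out = dpa(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens)
```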

Confidence Score: 4/5

  • Safe to merge with minor considerations for testing coverage
  • The changes are well-structured and properly gated by version checks. The logic correctly enables THD sink attention for cuDNN >= 9.18.0 while maintaining backward compatibility. The f-string fixes improve code quality. However, the removed code that disabled UnfusedDotProductAttention for THD format with non-vanilla softmax may allow fallback to UnfusedDotProductAttention on older cuDNN versions, which wasn't explicitly tested.
  • No files require special attention

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 4/5 | Updates THD sink attention backend selection to conditionally enable FusedAttention for cuDNN >= 9.18.0; removes redundant UnfusedDotProductAttention disabling logic |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 4/5 | Adds version check to allow THD format with non-vanilla softmax for cuDNN >= 9.18.0 in context parallelism; fixes f-string formatting in assertion messages |

Sequence Diagram

sequenceDiagram
    participant App as Application
    participant Utils as Backend Selection (utils.py)
    participant CP as Context Parallel (context_parallel.py)
    
    App->>Utils: get_attention_backend(softmax_type, qkv_format)
    
    alt softmax_type != vanilla AND qkv_format == thd
        Utils->>Utils: Check cuDNN version
        alt cuDNN >= 9.18.0
            Utils-->>App: FusedAttention enabled
        else cuDNN < 9.18.0
            Utils-->>App: FusedAttention disabled
        end
    else other configurations
        Utils-->>App: Standard backend selection
    end
    
    App->>CP: attn_forward_func_with_cp()
    CP->>CP: Validate THD sink attention support
    alt cuDNN >= 9.18.0
        CP-->>App: THD sink attention allowed
    else cuDNN < 9.18.0
        CP-->>App: Assert error if THD + non-vanilla softmax
    end

Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR enables THD (total tokens, heads, head dim) format support for sink attention (off-by-one and learnable softmax types) when using cuDNN 9.18.0 or newer.

Key changes:

  • Fixed f-string formatting bugs in error messages (missing f prefix on 5 assertions)
  • Conditionally enables FusedAttention backend for THD + sink attention when cuDNN >= 9.18.0
  • Updates context parallelism validation to allow THD + sink attention with cuDNN >= 9.18.0 (requires cp_comm_type='a2a')
  • Removes redundant checks for UnfusedDotProductAttention (already covered by general context parallelism filter)
  • Adds test coverage for THD format with various softmax types

The changes are well-structured and maintain backward compatibility by keeping restrictions in place for older cuDNN versions.
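
A minimal sketch of how such a test can be gated on the cuDNN version. Only the test name test_dpa_softmax_thd comes from the summaries above; the import location, skip condition, and parametrization are assumptions:

```python
# Illustrative sketch: gating a test on cuDNN >= 9.18.0. A get_cudnn_version()
# helper returning a (major, minor, patch) tuple is assumed, as in the diagrams.
import pytest
from transformer_engine.pytorch.utils import get_cudnn_version  # assumed location

@pytest.mark.skipif(
    get_cudnn_version() < (9, 18, 0),
    reason="THD sink attention requires cuDNN 9.18.0+",
)
@pytest.mark.parametrize("softmax_type", ["off-by-one", "learnable"])
def test_dpa_softmax_thd(softmax_type):
    # Run DotProductAttention with qkv_format="thd" and a sink softmax type,
    # then compare the output against a reference (unfused) path.
    ...
```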

Confidence Score: 5/5

  • This PR is safe to merge with no blocking issues found
  • The changes are well-implemented with proper version gating, fix existing bugs (f-string formatting), remove redundant code, and include test coverage. All logic paths were verified to be consistent across files, and the conditional enablement of features based on cuDNN version is correctly implemented.
  • No files require special attention

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| tests/pytorch/attention/test_attention.py | 5/5 | Added test for THD format with sink attention (off-by-one/learnable softmax), properly gated by cuDNN 9.18.0+ requirement |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 5/5 | Fixed f-string formatting bugs and added conditional version check to allow THD sink attention with cuDNN 9.18.0+ |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 5/5 | Conditionally enables FusedAttention for THD + sink attention with cuDNN 9.18.0+, removes redundant UnfusedDotProductAttention checks |

Sequence Diagram

sequenceDiagram
    participant User
    participant DPA as DotProductAttention
    participant Backend as get_attention_backend
    participant CP as attn_forward_func_with_cp
    participant FusedAttn as FusedAttention

    User->>DPA: Call with thd format + sink attention
    DPA->>Backend: Check backend support
    
    alt cuDNN >= 9.18.0
        Backend->>Backend: Allow FusedAttention for thd + sink
        Backend-->>DPA: FusedAttention enabled
        
        alt Context Parallelism
            DPA->>CP: Forward with cp_comm_type=a2a
            CP->>CP: Validate: softmax_type requires a2a
            CP->>CP: Validate: thd + sink OK for cuDNN >= 9.18.0
            CP->>FusedAttn: Execute attention with sink
            FusedAttn-->>CP: Results
            CP-->>DPA: Output
        else No Context Parallelism
            DPA->>FusedAttn: Execute attention with sink
            FusedAttn-->>DPA: Output
        end
    else cuDNN < 9.18.0
        Backend->>Backend: Disable FusedAttention for thd + sink
        Backend-->>DPA: Fallback to FlashAttention
        DPA->>DPA: Execute with FlashAttention
    end
    
    DPA-->>User: Attention output
