
Conversation

@xiaoxi-wangfj
Contributor

Description

This PR fixes a NaN issue in the fused permute+pad path when handling Float8BlockwiseQTensor inputs.

The permuted_scale buffer was allocated with torch.empty, which does not initialize memory, so the padded regions of the buffer could contain NaN values.

When the permute input is a Float8BlockwiseQTensor and the permuted_scale entries in the padded region contain NaNs, those NaNs propagate through the subsequent dequantization and requantization path in GroupedLinear, eventually producing a NaN forward loss, e.g.:

```
ERROR:megatron.core.rerun_state_machine:Unexpected result nan on rank 1 at iteration #2 invocation #1 (message='found NaN in local forward loss calculation')
```
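
For context, a minimal PyTorch sketch of the failure mode and of the fix pattern; the shapes and names below are illustrative only and are not taken from permute_with_mask_map itself:

```python
import torch

# The Triton permute kernel only writes scale entries for real (non-padded) tokens,
# so whatever torch.empty left in the padded rows survives into the output buffer.
n_tokens, n_padded, scale_cols = 6, 2, 4  # illustrative sizes, not the real shapes

# Buggy allocation: the padded rows keep uninitialized memory, which may already hold NaN.
permuted_scale = torch.empty(n_tokens + n_padded, scale_cols)

# Fixed allocation when padding is enabled: padded rows are guaranteed to be zero,
# so the later dequantize/requantize steps in GroupedLinear never see NaN there.
permuted_scale = torch.zeros(n_tokens + n_padded, scale_cols)
```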

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Initialize permuted_scale with torch.zeros instead of torch.empty in permute_with_mask_map

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Dec 29, 2025

Greptile Summary

Fixed NaN propagation bug in fused permute+pad path by ensuring permuted_scale buffer is zero-initialized when padding is enabled.

  • Changed permuted_scale allocation from torch.empty to the alloc variable (which becomes torch.zeros when pad_offsets is provided); a sketch of this pattern follows the list below
  • Previously, uninitialized memory in padded regions could contain NaN values that would propagate through dequantization/requantization in GroupedLinear
  • The fix aligns permuted_scale allocation behavior with output and permuted_probs buffers (lines 163-166), which already use the same pattern
  • This is a minimal, surgical fix that specifically addresses the root cause: the Triton kernel only writes to non-padded positions, so padded positions must be explicitly zeroed
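
A minimal sketch of the allocation pattern described above, assuming a pad_offsets argument as in permute_with_mask_map; the permuted_scale_like helper below is hypothetical and only illustrates how the allocator choice decides what padded positions contain:

```python
import torch

def permuted_scale_like(shape, pad_offsets=None, device=None, dtype=torch.float32):
    """Hypothetical helper illustrating the allocator choice described above:
    zero-fill when padding is enabled so positions the kernel never writes stay 0,
    otherwise keep the cheaper uninitialized allocation."""
    alloc = torch.zeros if pad_offsets is not None else torch.empty
    return alloc(shape, device=device, dtype=dtype)

# With this PR, permuted_scale follows the same allocator as output and permuted_probs.
permuted_scale = permuted_scale_like((8, 4), pad_offsets=[0])  # pad_offsets value is illustrative
# Padded rows now start out as zeros instead of uninitialized memory.
```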

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The fix is a one-line change that directly addresses a critical bug where uninitialized memory caused NaN propagation. The change follows the exact same pattern already used for output and permuted_probs buffers in the same function, ensuring consistency. The fix is minimal, well-justified by the PR description, and has no negative performance or correctness implications
  • No files require special attention

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/triton/permutation.py | Fixed uninitialized memory in the permuted_scale buffer by using torch.zeros when padding is enabled |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant permute_with_mask_map
    participant alloc as torch.zeros/empty
    participant _permute_kernel as Triton Kernel
    participant Float8BlockwiseQTensor

    Caller->>permute_with_mask_map: permute Float8BlockwiseQTensor with padding
    Note over permute_with_mask_map: pad_offsets != None
    permute_with_mask_map->>alloc: alloc = torch.zeros (when padding enabled)
    permute_with_mask_map->>alloc: Allocate permuted_scale buffer
    alloc-->>permute_with_mask_map: Zero-initialized buffer
    permute_with_mask_map->>_permute_kernel: Execute permutation
    Note over _permute_kernel: Writes scales only to<br/>non-padded positions
    Note over _permute_kernel: Padded regions remain<br/>zero (no NaN propagation)
    _permute_kernel-->>permute_with_mask_map: Permuted data + scale
    permute_with_mask_map->>Float8BlockwiseQTensor: Create with permuted_scale
    Float8BlockwiseQTensor-->>Caller: Dequantize/Requantize without NaN
```

@xiaoxi-wangfj
Contributor Author

@tdophung
Apologies for the oversight. In the previous change, when switching the permuted_scale initialization to torch.empty, I didn't re-validate the code path where the permute input is a Float8BlockwiseQTensor.
Today I reran the FP8-Flow setup (with Float8BlockwiseQTensor inputs) and was able to reproduce the issue. This PR fixes the problem; thanks for the review.

@tdophung
Collaborator

tdophung commented Jan 2, 2026

> @tdophung Apologies for the oversight. In the previous change, when switching the permuted_scale initialization to torch.empty, I didn't re-validate the code path where the permute input is a Float8BlockwiseQTensor. Today I reran the FP8-Flow setup (with Float8BlockwiseQTensor inputs) and was able to reproduce the issue. This PR fixes the problem; thanks for the review.

Thanks for this. I actually also saw this NaN issue on the JAX side when I tried not initializing the permuted scales, and eventually fixed it with zero initialization. I forgot to change the PyTorch side accordingly.

@tdophung
Collaborator

tdophung commented Jan 2, 2026

/te-ci pytorch

@tdophung tdophung merged commit c988548 into NVIDIA:main Jan 2, 2026
2 of 3 checks passed
