Commit 3d40d9d

Fix invalid Triton code for mixed scalar/block indexing in store operations when block dimension has size 1
Fixes #1256

stack-info: PR: #1258, branch: oulgen/stack/186
1 parent 28cc903 commit 3d40d9d
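
The failure mode, in short: when every block dimension of the tensor being stored to has size 1, those dimensions are dropped from the offset computation, so the generated pointer expression is a scalar while the stored value is still a block, and Triton rejects the rank mismatch. A minimal before/after sketch of the store, borrowing the names from the generated code in test_indexing.expected below (the "before" line is what the old codegen emitted without the reshape, reconstructed for illustration, not copied from the repo):

tl.store(scales + floordiv * 1, v_1, None)                    # before: scalar pointer, value of shape [1] -> invalid
tl.store(scales + floordiv * 1, tl.reshape(v_1, []), None)    # after: value squeezed to rank 0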

File tree

3 files changed: +155 −0 lines changed
helion/_compiler/indexing_strategy.py

Lines changed: 36 additions & 0 deletions

@@ -179,6 +179,42 @@ def codegen_store(
     ) -> ast.AST:
         indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
         name = state.device_function.tensor_arg(fake_tensor).name
+
+        # Compute the effective pointer shape (dimensions that contribute to the offset).
+        # Dimensions where fake_tensor.size(i) == 1 are skipped in the offset computation,
+        # so we need to reshape the value to match the effective pointer shape.
+        env = CompileEnvironment.current()
+        output_size = SubscriptIndexing.compute_shape(fake_tensor, subscript, state)
+
+        # Compute the effective shape after dropping size-1 tensor dimensions.
+        # This matches the logic in SubscriptIndexing.create that skips size-1 dims.
+        effective_shape = []
+        tensor_dim = 0
+        for k in subscript:
+            if k is None:
+                # None adds a dimension of size 1 to output, not from tensor
+                pass
+            elif isinstance(k, int):
+                # Scalar int eliminates the dimension
+                tensor_dim += 1
+            elif isinstance(k, (torch.SymInt, torch.Tensor, slice)):
+                # These consume a tensor dimension
+                if not env.known_equal(fake_tensor.size(tensor_dim), 1):
+                    # This dimension contributes to the pointer;
+                    # find the corresponding output dimension.
+                    if tensor_dim < len(output_size):
+                        effective_shape.append(output_size[tensor_dim])
+                tensor_dim += 1
+
+        # If effective_shape is empty but output_size is not all-1s, we need to reshape
+        # the value to be scalar. Skip reshaping for scalar constants which don't have shape.
+        if not effective_shape and output_size and not isinstance(value, ast.Constant):
+            # Pointer is scalar but value may have shape - squeeze to scalar
+            value = expr_from_string(
+                "tl.reshape({value}, [])",
+                value=value,
+            )
+
         return expr_from_string(
             f"tl.store({name} + {{offset}}, {{value}}, {{mask}})",
             value=value,
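
The loop above classifies each subscript entry the same way SubscriptIndexing.create does. Here is a standalone sketch of that rule with plain Python ints standing in for SymInts and fake tensors; the helper name and toy shapes are hypothetical, for illustration only:

def effective_shape(tensor_shape, subscript, output_size):
    # Mirrors the classification above: None adds an output dim that never
    # touches the pointer, an int consumes a dim without contributing, and a
    # slice/block index contributes only when its dim is not size 1.
    shape = []
    tensor_dim = 0
    for k in subscript:
        if k is None:
            continue
        if isinstance(k, int):
            tensor_dim += 1
            continue
        if tensor_shape[tensor_dim] != 1 and tensor_dim < len(output_size):
            shape.append(output_size[tensor_dim])
        tensor_dim += 1
    return shape

# scales[m_tile, scale_col_idx] with m == 1: the block dim has size 1 and the
# column index is a scalar, so nothing survives -> scalar pointer, and the
# reshape-to-[] branch above fires.
assert effective_shape((1, 2), [slice(None), 0], [1]) == []
# With m == 4 the row dim survives and no reshape is needed.
assert effective_shape((4, 2), [slice(None), 0], [4]) == [4]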

test/test_indexing.expected

Lines changed: 69 additions & 0 deletions

@@ -759,6 +759,75 @@ def masked_store(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_indexing.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestIndexing.test_mixed_scalar_block_store_size1_dim)
+from __future__ import annotations
+
+import torch
+import helion.language as hl
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel_with_mixed_store(x_data, out, scales, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
+    num_blocks_0 = 1
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_1 * _BLOCK_SIZE_0
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    tile_end = offset_0 + _BLOCK_SIZE_0
+    # src[test_indexing.py:N]: for n_tile_local in hl.tile(
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    # src[test_indexing.py:N]: ):
+    # src[test_indexing.py:N-N]: ...
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < tile_end
+        # src[test_indexing.py:N]: x_block = x_data[m_tile, n_tile_local]
+        x_block = tl.load(x_data + indices_2[None, :] * 1, mask_2[None, :], other=0)
+        # src[test_indexing.py:N]: row_max = x_block.abs().amax(dim=1)
+        v_0 = tl_math.abs(x_block)
+        _mask_to = tl.where(tl.broadcast_to(mask_2[None, :], [1, _BLOCK_SIZE_2]), v_0, tl.full([], float('-inf'), tl.float32))
+        row_max = tl.cast(tl.max(_mask_to, 1), tl.float32)
+        # src[test_indexing.py:N]: row_value = row_max.to(torch.uint8)
+        v_1 = tl.cast(row_max, tl.uint8)
+        # src[test_indexing.py:N]: out[m_tile, n_tile_local] = x_block * 2.0
+        v_2 = 2.0
+        v_3 = x_block * v_2
+        tl.store(out + indices_2[None, :] * 1, v_3, mask_2[None, :])
+        # src[test_indexing.py:N]: scale_col_idx = n_tile_local.begin // BLOCK_SIZE # scalar
+        floordiv = triton_helpers.div_floor_integer(offset_2, 32)
+        # src[test_indexing.py:N]: scales[m_tile, scale_col_idx] = row_value # row_value is block
+        tl.store(scales + floordiv * 1, tl.reshape(v_1, []), None)
+
+def kernel_with_mixed_store(x_data: torch.Tensor, BLOCK_SIZE: hl.constexpr, *, _launcher=_default_launcher):
+    # src[test_indexing.py:N]: m, n = x_data.shape
+    m, n = x_data.shape
+    # src[test_indexing.py:N]: n = hl.specialize(n)
+    n = 64
+    # src[test_indexing.py:N]: n_scale_cols = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
+    n_scale_cols = (n + 32 - 1) // 32
+    # src[test_indexing.py:N]: scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
+    scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
+    # src[test_indexing.py:N]: out = x_data.new_empty(x_data.shape, dtype=torch.float32)
+    out = x_data.new_empty(x_data.shape, dtype=torch.float32)
+    # src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
+    _BLOCK_SIZE_0 = 32
+    # src[test_indexing.py:N]: for n_tile_local in hl.tile(
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    # src[test_indexing.py:N]: ):
+    # src[test_indexing.py:N-N]: ...
+    _BLOCK_SIZE_2 = 32
+    # src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
+    # src[test_indexing.py:N]: for n_tile_local in hl.tile(
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    # src[test_indexing.py:N-N]: ...
+    _launcher(_helion_kernel_with_mixed_store, (1 * triton.cdiv(64, _BLOCK_SIZE_0),), x_data, out, scales, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_indexing.py:N]: return out, scales
+    return (out, scales)
+
 --- assertExpectedJournal(TestIndexing.test_non_consecutive_tensor_indexers_no_broadcast)
 from __future__ import annotations
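
The key line in the journal is the last store in the device kernel: tl.store(scales + floordiv * 1, tl.reshape(v_1, []), None). A self-contained sketch of the same pattern outside Helion (hypothetical kernel and names, assumes Triton and a CUDA device; not part of the commit):

import torch
import triton
import triton.language as tl

@triton.jit
def _row_max_to_scalar(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs[None, :])   # block of shape [1, BLOCK]
    row_max = tl.max(x, 1)               # reduction leaves shape [1]
    # Squeeze to rank 0 so the value matches the scalar out_ptr, as the fix does.
    tl.store(out_ptr, tl.reshape(row_max, []))

x = torch.randn(1, 16, device="cuda")
out = torch.empty((), device="cuda")
_row_max_to_scalar[(1,)](x, out, BLOCK=16)
assert out.item() == x.max().item()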

test/test_indexing.py

Lines changed: 50 additions & 0 deletions

@@ -2161,6 +2161,56 @@ def store_with_mixed_indices(
         torch.testing.assert_close(result, expected)
         self.assertExpectedJournal(code)
 
+    def test_mixed_scalar_block_store_size1_dim(self):
+        """Test store with mixed scalar/block indexing when block dimension has size 1.
+
+        This tests a bug fix where storing a block value with:
+        - One index being a tile/block (e.g., m_tile) over a size-1 dimension
+        - Another index being a scalar (e.g., computed from tile.begin)
+        would generate invalid Triton code because the pointer became scalar
+        but the value was still a block.
+        """
+
+        @helion.kernel(autotune_effort="none")
+        def kernel_with_mixed_store(
+            x_data: torch.Tensor, BLOCK_SIZE: hl.constexpr
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            m, n = x_data.shape
+            n = hl.specialize(n)
+            n_scale_cols = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
+            scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
+            out = x_data.new_empty(x_data.shape, dtype=torch.float32)
+
+            n_block = hl.register_block_size(BLOCK_SIZE, n)
+
+            for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
+                for n_tile_local in hl.tile(
+                    n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+                ):
+                    x_block = x_data[m_tile, n_tile_local]
+
+                    # Compute one value per row in m_tile
+                    row_max = x_block.abs().amax(dim=1)
+                    row_value = row_max.to(torch.uint8)
+
+                    out[m_tile, n_tile_local] = x_block * 2.0
+
+                    # Mixed indexing: block row index + scalar column index
+                    scale_col_idx = n_tile_local.begin // BLOCK_SIZE  # scalar
+                    scales[m_tile, scale_col_idx] = row_value  # row_value is block
+
+            return out, scales
+
+        # Test with m=1 (single row - this was the failing case before the fix).
+        # The fix ensures tl.reshape is applied to squeeze the value to scalar
+        # when the pointer is scalar due to size-1 dimensions being dropped.
+        x1 = torch.randn(1, 64, device=DEVICE, dtype=torch.float32)
+        code, (out1, scales1) = code_and_output(kernel_with_mixed_store, (x1, 32))
+        expected_out1 = x1 * 2.0
+        torch.testing.assert_close(out1, expected_out1)
+        self.assertEqual(scales1.shape, (1, 2))
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
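
To reproduce just this case from a checkout (assuming pytest is installed and a device that DEVICE resolves to is available; the invocation below is standard pytest usage, not something the commit documents):

python -m pytest test/test_indexing.py -k test_mixed_scalar_block_store_size1_dim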
