64 changes: 64 additions & 0 deletions helion/_compiler/indexing_strategy.py
@@ -179,6 +179,70 @@ def codegen_store(
) -> ast.AST:
indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
name = state.device_function.tensor_arg(fake_tensor).name

# Check if the pointer is effectively scalar but the value has dimensions.
# This happens when all block-indexed dimensions have size 1 in the target tensor.
# In this case, we need to reshape the value to scalar to match the pointer.
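# Example from the expected journals updated below: a store whose pointer is
# scalar now squeezes its block-shaped value, e.g.
#   tl.store(result + tl.zeros([1, 1], tl.int32), tl.reshape(_associative_scan, []), None)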
env = CompileEnvironment.current()
output_size = SubscriptIndexing.compute_shape(fake_tensor, subscript, state)

# Determine if pointer has any block dimensions by checking if any block index
# targets a non-size-1 tensor dimension. We need to match the logic in
# SubscriptIndexing.create which skips dimensions where fake_tensor.size(i) == 1.
pointer_has_block_dims = False
tensor_dim = 0
k_index = 0
for k in subscript:
if k is None:
# None adds a dimension to output, not from tensor
pass
elif isinstance(k, int):
# Scalar int index - consumes tensor dim but adds scalar to pointer
tensor_dim += 1
elif _get_tile_with_offset_info(
k, state, k_index
) is not None or isinstance(k, torch.Tensor):
# Tensor index (tile.index + offset or regular tensor) - block index
if not env.known_equal(fake_tensor.size(tensor_dim), 1):
pointer_has_block_dims = True
tensor_dim += 1
k_index += 1
elif isinstance(k, torch.SymInt):
# SymInt can be block index (with BlockSizeOrigin) or scalar
symbol = k._sympy_()
origin = None
if isinstance(symbol, sympy.Symbol):
origin = HostFunction.current().expr_to_origin.get(symbol)
if origin and isinstance(origin.origin, BlockSizeOrigin):
# Block index
if not env.known_equal(fake_tensor.size(tensor_dim), 1):
pointer_has_block_dims = True
# Both block and scalar SymInt consume a tensor dimension
tensor_dim += 1
k_index += 1
elif isinstance(k, slice):
# Slice - adds block dimension if slice_size > 1
size = fake_tensor.size(tensor_dim)
slice_size = compute_slice_size(k, size)
if not env.known_equal(slice_size, 1):
if not env.known_equal(fake_tensor.size(tensor_dim), 1):
pointer_has_block_dims = True
tensor_dim += 1
k_index += 1

# If pointer is scalar but output_size has dimensions, reshape value to scalar.
# Skip reshaping for scalar constants which don't have shape.
if (
not pointer_has_block_dims
and output_size
and not isinstance(value, ast.Constant)
):
# Pointer is scalar but value may have shape - squeeze to scalar
value = expr_from_string(
"tl.reshape({value}, [])",
value=value,
)

return expr_from_string(
f"tl.store({name} + {{offset}}, {{value}}, {{mask}})",
value=value,
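For context, a minimal standalone sketch of the pattern this hunk emits (the kernel name, the stored value, and the CUDA device are illustrative assumptions, not part of the change): a [1]-shaped block value is squeezed with tl.reshape(value, []) so it matches a scalar store pointer.

import torch
import triton
import triton.language as tl

@triton.jit
def _store_scalar_example(out_ptr):
    # A block-shaped value of shape [1]...
    val = tl.full([1], 3.0, tl.float32)
    # ...is squeezed to a 0-d scalar so it matches the scalar pointer.
    tl.store(out_ptr, tl.reshape(val, []))

out = torch.empty((), device="cuda", dtype=torch.float32)
_store_scalar_example[(1,)](out)
assert out.item() == 3.0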
2 changes: 1 addition & 1 deletion test/test_associative_scan.expected
@@ -753,7 +753,7 @@ def _helion_test_single_element(x, result):
row_data = tl.load(x + tl.zeros([1, 1], tl.int32), None)
# src[test_associative_scan.py:N]: result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1)
_associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0)
-tl.store(result + tl.zeros([1, 1], tl.int32), _associative_scan, None)
+tl.store(result + tl.zeros([1, 1], tl.int32), tl.reshape(_associative_scan, []), None)

def test_single_element(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_associative_scan.py:N]: result = torch.empty_like(x)
2 changes: 2 additions & 0 deletions test/test_associative_scan.py
@@ -9,6 +9,7 @@
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfCpu
from helion._testing import skipIfRefEager
import helion.language as hl

@@ -381,6 +382,7 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
# Verify reverse parameter is in generated code
self.assertIn("reverse=True", code)

@skipIfCpu("")
def test_associative_scan_edge_cases(self):
"""Test associative_scan edge cases."""

69 changes: 69 additions & 0 deletions test/test_indexing.expected
@@ -759,6 +759,75 @@ def masked_store(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_indexing.py:N]: return out
return out

--- assertExpectedJournal(TestIndexing.test_mixed_scalar_block_store_size1_dim)
from __future__ import annotations

import torch
import helion.language as hl
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel_with_mixed_store(x_data, out, scales, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
# src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
num_blocks_0 = 1
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_1 * _BLOCK_SIZE_0
# src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
tile_end = offset_0 + _BLOCK_SIZE_0
# src[test_indexing.py:N]: for n_tile_local in hl.tile(
# src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
# src[test_indexing.py:N]: ):
# src[test_indexing.py:N-N]: ...
for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < tile_end
# src[test_indexing.py:N]: x_block = x_data[m_tile, n_tile_local]
x_block = tl.load(x_data + indices_2[None, :] * 1, mask_2[None, :], other=0)
# src[test_indexing.py:N]: row_max = x_block.abs().amax(dim=1)
v_0 = tl_math.abs(x_block)
_mask_to = tl.where(tl.broadcast_to(mask_2[None, :], [1, _BLOCK_SIZE_2]), v_0, tl.full([], float('-inf'), tl.float32))
row_max = tl.cast(tl.max(_mask_to, 1), tl.float32)
# src[test_indexing.py:N]: row_value = row_max.to(torch.uint8)
v_1 = tl.cast(row_max, tl.uint8)
# src[test_indexing.py:N]: out[m_tile, n_tile_local] = x_block * 2.0
v_2 = 2.0
v_3 = x_block * v_2
tl.store(out + indices_2[None, :] * 1, v_3, mask_2[None, :])
# src[test_indexing.py:N]: scale_col_idx = n_tile_local.begin // BLOCK_SIZE # scalar
floordiv = triton_helpers.div_floor_integer(offset_2, 32)
# src[test_indexing.py:N]: scales[m_tile, scale_col_idx] = row_value # row_value is block
tl.store(scales + floordiv * 1, tl.reshape(v_1, []), None)

def kernel_with_mixed_store(x_data: torch.Tensor, BLOCK_SIZE: hl.constexpr, *, _launcher=_default_launcher):
# src[test_indexing.py:N]: m, n = x_data.shape
m, n = x_data.shape
# src[test_indexing.py:N]: n = hl.specialize(n)
n = 64
# src[test_indexing.py:N]: n_scale_cols = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
n_scale_cols = (n + 32 - 1) // 32
# src[test_indexing.py:N]: scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
# src[test_indexing.py:N]: out = x_data.new_empty(x_data.shape, dtype=torch.float32)
out = x_data.new_empty(x_data.shape, dtype=torch.float32)
# src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
_BLOCK_SIZE_0 = 32
# src[test_indexing.py:N]: for n_tile_local in hl.tile(
# src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
# src[test_indexing.py:N]: ):
# src[test_indexing.py:N-N]: ...
_BLOCK_SIZE_2 = 32
# src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
# src[test_indexing.py:N]: for n_tile_local in hl.tile(
# src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
# src[test_indexing.py:N-N]: ...
_launcher(_helion_kernel_with_mixed_store, (1 * triton.cdiv(64, _BLOCK_SIZE_0),), x_data, out, scales, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
# src[test_indexing.py:N]: return out, scales
return (out, scales)

--- assertExpectedJournal(TestIndexing.test_non_consecutive_tensor_indexers_no_broadcast)
from __future__ import annotations

52 changes: 52 additions & 0 deletions test/test_indexing.py
@@ -972,6 +972,7 @@ def test_reduction_tensor_descriptor_indexing_reduction_loop(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfCpu("")
def test_2d_slice_index(self):
"""Test both setter from scalar and getter for [:,i]"""

@@ -2161,6 +2162,57 @@ def store_with_mixed_indices(
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfCpu("")
def test_mixed_scalar_block_store_size1_dim(self):
"""Test store with mixed scalar/block indexing when block dimension has size 1.

This tests a bug fix where storing a block value with:
- One index being a tile/block (e.g., m_tile) over a size-1 dimension
- Another index being a scalar (e.g., computed from tile.begin)
would generate invalid Triton code because the pointer became scalar
but the value was still a block.
"""

@helion.kernel(autotune_effort="none")
def kernel_with_mixed_store(
x_data: torch.Tensor, BLOCK_SIZE: hl.constexpr
) -> tuple[torch.Tensor, torch.Tensor]:
m, n = x_data.shape
n = hl.specialize(n)
n_scale_cols = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
out = x_data.new_empty(x_data.shape, dtype=torch.float32)

n_block = hl.register_block_size(BLOCK_SIZE, n)

for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
for n_tile_local in hl.tile(
n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
):
x_block = x_data[m_tile, n_tile_local]

# Compute one value per row in m_tile
row_max = x_block.abs().amax(dim=1)
row_value = row_max.to(torch.uint8)

out[m_tile, n_tile_local] = x_block * 2.0

# Mixed indexing: block row index + scalar column index
scale_col_idx = n_tile_local.begin // BLOCK_SIZE # scalar
scales[m_tile, scale_col_idx] = row_value # row_value is block

return out, scales

# Test with m=1 (single row - this was the failing case before the fix)
# The fix ensures tl.reshape is applied to squeeze the value to scalar
# when the pointer is scalar due to size-1 dimensions being dropped.
x1 = torch.randn(1, 64, device=DEVICE, dtype=torch.float32)
code, (out1, scales1) = code_and_output(kernel_with_mixed_store, (x1, 32))
expected_out1 = x1 * 2.0
torch.testing.assert_close(out1, expected_out1)
self.assertEqual(scales1.shape, (1, 2))
self.assertExpectedJournal(code)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/test_reduce.expected
@@ -198,7 +198,7 @@ def _helion_test_reduce_codegen_kernel(x, result, _RDIM_SIZE_1: tl.constexpr):
row_data = tl.load(x + indices_1[None, :] * 1, mask_1[None, :], other=0)
# src[test_reduce.py:N]: result[i] = hl.reduce(add_combine_fn, row_data, dim=1)
_reduce = tl.reduce(row_data, 1, add_combine_fn_0)
-tl.store(result + tl.zeros([1], tl.int32), _reduce, None)
+tl.store(result + tl.zeros([1], tl.int32), tl.reshape(_reduce, []), None)

def test_reduce_codegen_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_reduce.py:N]: result = torch.empty([x.size(0)], dtype=x.dtype, device=x.device)
2 changes: 2 additions & 0 deletions test/test_reduce.py
@@ -9,6 +9,7 @@
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfCpu
import helion.language as hl


@@ -500,6 +501,7 @@ def test_argmax_negative_kernel(
self.assertIn("tl.reduce", code)
self.assertIn("argmax_combine_fn_", code)

@skipIfCpu("")
def test_reduce_code_generation(self):
"""Test that reduce generates correct Triton code."""
