27 changes: 27 additions & 0 deletions test/test_associative_scan.py
@@ -9,7 +9,9 @@
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfMTIA
from helion._testing import skipIfRefEager

import helion.language as hl


@@ -99,6 +101,7 @@ def jit_add_combine_fn(x, y):


class TestAssociativeScan(RefEagerTestBase, TestCase):
@skipIfMTIA("MTIA error: Expected all tensor inputs to be aligned and/or padded according to the MTIA HW requirements")
def test_associative_scan_basic_addition(self):
"""Test basic associative_scan functionality with prefix sum."""

@@ -132,6 +135,7 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor:
self.assertIn("param_0 + param_1", code)
self.assertIn("tl.associative_scan", code)

@skipIfMTIA("MTIA error: specify_sizes must have a size for each dimension of the input memref")
def test_associative_scan_maximum(self):
"""Test associative_scan with maximum combine function."""

@@ -164,6 +168,7 @@ def test_max_kernel(x: torch.Tensor) -> torch.Tensor:
"tl.maximum" in code or "triton_helpers.maximum" in code
)

@skipIfMTIA("MTIA error: In Triton-MTIA, subviews are handled case by case. Only a subset of operations support subviews. Consider extending support for subviews in the compiler or revisit the algorithm.")
def test_associative_scan_multiplication(self):
"""Test associative_scan with multiplication combine function."""

@@ -194,6 +199,7 @@ def test_mul_kernel(x: torch.Tensor) -> torch.Tensor:
# Verify the generated code contains multiplication
self.assertIn("param_0 * param_1", code)

@skipIfMTIA("MTIA error: MLIR compilation failed: specify_sizes must have a size for each dimension of the input memref")
def test_associative_scan_minimum(self):
"""Test associative_scan with minimum combine function."""

@@ -226,6 +232,7 @@ def test_min_kernel(x: torch.Tensor) -> torch.Tensor:
"tl.minimum" in code or "triton_helpers.minimum" in code
)

@skipIfMTIA("MTIA error: Expected all tensor inputs to be aligned and/or padded according to the MTIA HW requirementss")
def test_associative_scan_multiple_functions(self):
"""Test using multiple different combine functions in one kernel."""

@@ -286,6 +293,7 @@ def test_type_kernel(x: torch.Tensor) -> torch.Tensor:
# Use relaxed tolerance for large tensors due to accumulated floating-point errors
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

@skipIfMTIA("MTIA error: Expected all tensor inputs to be aligned and/or padded according to the MTIA HW requirements")
def test_associative_scan_different_dtypes(self):
"""Test associative_scan with different data types."""

@@ -320,6 +328,7 @@ def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
expected = expected.to(result.dtype)
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_different_sizes(self):
"""Test associative_scan with different tensor sizes."""

@@ -356,6 +365,7 @@ def test_size_kernel(x: torch.Tensor) -> torch.Tensor:
expected = torch.cumsum(x, dim=1)
torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_reverse(self):
"""Test associative_scan with reverse=True parameter."""

@@ -381,6 +391,7 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
# Verify reverse parameter is in generated code
self.assertIn("reverse=True", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_edge_cases(self):
"""Test associative_scan edge cases."""

@@ -431,6 +442,7 @@ def test_large_kernel(x: torch.Tensor) -> torch.Tensor:
self.assertEqual(result.shape, x.shape)
self.assertEqual(result.dtype, x.dtype)

@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_torch_hops_mapping(self):
"""Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan."""

@@ -466,6 +478,7 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor:
self.assertIn("tl.associative_scan", code)
self.assertIn("param_0 + param_1", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_code_generation(self):
"""Test that the generated code structure is correct."""

@@ -498,6 +511,7 @@ def test_codegen_kernel(x: torch.Tensor) -> torch.Tensor:
@skipIfRefEager(
"torch._higher_order_ops.associative_scan with nested @helion.kernel is not supported by ref eager mode yet"
)
@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_jit_decorator_ignored(self):
"""Test that @helion.kernel decorator on combine functions is ignored."""

@@ -527,6 +541,7 @@ def test_jit_kernel(x: torch.Tensor) -> torch.Tensor:
@skipIfRefEager(
"torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
)
@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_tuple_args(self):
"""Test associative_scan with tuple arguments (matching GitHub issue #237 pattern)."""

@@ -579,6 +594,7 @@ def test_segmented_kernel(
@skipIfRefEager(
"torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
)
@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_segmented_reduction(self):
"""Test associative_scan for segmented reduction use case."""

@@ -636,6 +652,7 @@ def segmented_scan_kernel(
@skipIfRefEager(
"torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
)
@skipIfMTIA("MTIA error: specify_sizes must have a size for each dimension of the input memref")
def test_associative_scan_cumulative_argmax(self):
"""Test cumulative argmax using tuple args with (float, int) types."""

@@ -705,6 +722,7 @@ def cumulative_argmax_kernel(
self.assertIn("def argmax_combine_fn_", code)
self.assertIn("tl.associative_scan", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_in_helper_function(self):
"""Test calling a function that internally uses hl.associative_scan."""

@@ -739,6 +757,7 @@ def test_helper_kernel(x: torch.Tensor) -> torch.Tensor:
self.assertIn("tl.associative_scan", code)
self.assertIn("param_0 + param_1", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_cumsum_basic(self):
"""Test basic cumsum functionality."""

@@ -766,6 +785,7 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor:
self.assertIn("param_0 + param_1", code)
self.assertIn("tl.associative_scan", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_cumsum_reverse(self):
"""Test cumsum with reverse=True."""

@@ -789,6 +809,7 @@ def test_cumsum_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
# Verify reverse parameter is used
self.assertIn("reverse=True", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_cumsum_different_dtypes(self):
"""Test cumsum with different data types."""

@@ -820,6 +841,7 @@ def test_cumsum_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
expected = expected.to(result.dtype)
torch.testing.assert_close(result, expected)

@skipIfMTIA("Not supported on MTIA yet.")
def test_cumprod_basic(self):
"""Test basic cumprod functionality."""

@@ -847,6 +869,7 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor:
self.assertIn("param_0 * param_1", code)
self.assertIn("tl.associative_scan", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_cumprod_reverse(self):
"""Test cumprod with reverse=True."""

@@ -870,6 +893,7 @@ def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor:
# Verify reverse parameter is used
self.assertIn("reverse=True", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_cumprod_different_dtypes(self):
"""Test cumprod with different data types."""

@@ -901,6 +925,7 @@ def test_cumprod_dtype_kernel(x: torch.Tensor) -> torch.Tensor:
expected = expected.to(result.dtype)
torch.testing.assert_close(result, expected)

@skipIfMTIA("Not supported on MTIA yet.")
def test_cumsum_cumprod_mixed(self):
"""Test using both cumsum and cumprod in the same kernel."""

@@ -937,6 +962,7 @@ def test_mixed_kernel(x: torch.Tensor) -> torch.Tensor:
@skipIfRefEager(
"torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet"
)
@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_tuple_format(self):
"""Test associative_scan with tuple format combine function (like reduce format)."""

@@ -988,6 +1014,7 @@ def test_segmented_tuple_kernel(
self.assertIn("def helion_combine_tuple_fn_", code)
self.assertIn("tl.associative_scan", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_associative_scan_argmax_tuple_format(self):
"""Test cumulative argmax using tuple format combine function."""

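The tests above import `skipIfMTIA` from `helion._testing`, but the helper's definition is not part of this diff. A minimal sketch of what such a decorator might look like, assuming MTIA-enabled PyTorch builds expose `torch.mtia.is_available()`, is:

```python
# Illustrative sketch only; the real helper lives in helion/_testing.py and is
# not shown in this diff. Assumes torch.mtia.is_available() exists on builds
# that ship the MTIA backend.
import unittest

import torch


def skipIfMTIA(reason: str):
    """Return a decorator that skips a test when an MTIA device is available."""
    mtia = getattr(torch, "mtia", None)  # absent on builds without MTIA support
    on_mtia = mtia is not None and mtia.is_available()
    return unittest.skipIf(on_mtia, reason)
```

Usage then matches the diff: place `@skipIfMTIA("Not supported on MTIA yet.")` (or a decorator carrying the specific MTIA compiler error) directly above the test method, stacked under any existing `skipIfRefEager` or `skipIfRocm` decorators.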
16 changes: 16 additions & 0 deletions test/test_atomic_ops.py
@@ -9,6 +9,7 @@
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfMTIA
from helion._testing import skipIfRefEager
from helion._testing import skipIfRocm
import helion.language as hl
@@ -131,6 +132,7 @@ def atomic_cas_kernel(


class TestAtomicOperations(RefEagerTestBase, TestCase):
@skipIfMTIA("Not supported on MTIA yet.")
def test_basic_atomic_add(self):
x = torch.zeros(10, device=DEVICE)
y = torch.ones(10, device=DEVICE)
@@ -146,6 +148,7 @@ def test_basic_atomic_add(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_add_1d_tensor(self):
M, N = 32, 64
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
@@ -162,6 +165,7 @@ def test_atomic_add_1d_tensor(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_add_returns_prev(self):
@helion.kernel()
def k(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@ -178,6 +182,7 @@ def k(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
torch.testing.assert_close(prev, torch.zeros_like(x))
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_overlapping_atomic_add(self):
# Test with overlapping indices
x = torch.zeros(5, device=DEVICE)
@@ -195,6 +200,7 @@ def test_overlapping_atomic_add(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_2d_atomic_add(self):
"""Test atomic_add with 2D tensor indexing."""
x = torch.zeros(3, 4, device=DEVICE)
@@ -211,6 +217,7 @@ def test_2d_atomic_add(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_add_code_generation(self):
"""Test that the generated code contains atomic_add."""
x = torch.zeros(10, device=DEVICE)
@@ -222,6 +229,7 @@ def test_atomic_add_code_generation(self):
torch.testing.assert_close(result, expected)
self.assertIn("atomic_add", code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_add_float(self):
"""Test that atomic_add works with float constants."""
x = torch.zeros(5, device=DEVICE, dtype=torch.float32)
@@ -263,6 +271,7 @@ def bad_atomic_add_kernel(x: torch.Tensor, y: torch.Tensor):
@skipIfRefEager(
"Test is block size dependent which is not supported in ref eager mode"
)
@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_add_w_tile_attr(self):
"""Test atomic_add where the index is a symbolic int"""
x = torch.randn(20, device=DEVICE)
@@ -293,6 +302,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
kernel(x)

# New tests for other atomics (correctness only; no journal asserts)
@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_and(self):
x0 = torch.full((8,), 0b1111, device=DEVICE, dtype=torch.int32)
y = torch.tensor([0b1010] * 8, device=DEVICE, dtype=torch.int32)
@@ -301,6 +311,7 @@ def test_atomic_and(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_or(self):
x0 = torch.zeros(8, device=DEVICE, dtype=torch.int32)
y = torch.tensor([0b1010] * 8, device=DEVICE, dtype=torch.int32)
@@ -309,6 +320,7 @@ def test_atomic_or(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_xor(self):
x0 = torch.tensor([0b1010] * 8, device=DEVICE, dtype=torch.int32)
y = torch.tensor([0b1100] * 8, device=DEVICE, dtype=torch.int32)
@@ -318,6 +330,7 @@ def test_atomic_xor(self):
self.assertExpectedJournal(code)

@skipIfRocm("ROCm backend currently lacks support for these atomics")
@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_xchg(self):
x0 = torch.zeros(8, device=DEVICE, dtype=torch.int32)
y = torch.arange(8, device=DEVICE, dtype=torch.int32)
@@ -326,6 +339,7 @@ def test_atomic_xchg(self):
self.assertExpectedJournal(code)

@skipIfRocm("ROCm backend currently lacks support for these atomics")
@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_max(self):
x = torch.tensor([1, 5, 3, 7], device=DEVICE, dtype=torch.int32)
y = torch.tensor([4, 2, 9, 1], device=DEVICE, dtype=torch.int32)
@@ -335,6 +349,7 @@ def test_atomic_max(self):
self.assertExpectedJournal(code)

@skipIfRocm("ROCm backend currently lacks support for these atomics")
@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_min(self):
x = torch.tensor([1, 5, 3, 7], device=DEVICE, dtype=torch.int32)
y = torch.tensor([4, 2, 9, 1], device=DEVICE, dtype=torch.int32)
@@ -343,6 +358,7 @@ def test_atomic_min(self):
torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

@skipIfMTIA("Not supported on MTIA yet.")
def test_atomic_cas(self):
x = torch.tensor([1, 5, 3, 7], device=DEVICE, dtype=torch.int32)
expect = torch.tensor([1, 6, 3, 0], device=DEVICE, dtype=torch.int32)
7 changes: 7 additions & 0 deletions test/test_breakpoint.py
@@ -15,6 +15,7 @@
from helion import exc
from helion._testing import DEVICE
from helion._testing import RefEagerTestDisabled
from helion._testing import skipIfMTIA
from helion._testing import TestCase
import helion.language as hl

@@ -92,16 +93,19 @@ def _run_device_breakpoint_test(
out = bound(x)
torch.testing.assert_close(out, x)

@skipIfMTIA("Not supported on MTIA yet.")
def test_device_breakpoint_no_interpret(self) -> None:
self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=0)

@unittest.skipUnless(
hasattr(triton_interpreter, "_MISSING"),
"https://github.com/triton-lang/triton/pull/8735",
)
@skipIfMTIA("Not supported on MTIA yet.")
def test_device_breakpoint_triton_interpret(self) -> None:
self._run_device_breakpoint_test(triton_interpret=1, helion_interpret=0)

@skipIfMTIA("Not supported on MTIA yet.")
def test_device_breakpoint_helion_interpret(self) -> None:
self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=1)

@@ -122,16 +126,19 @@ def _run_host_breakpoint_test(
out = bound(x)
torch.testing.assert_close(out, x)

@skipIfMTIA("Not supported on MTIA yet.")
def test_host_breakpoint_no_interpret(self) -> None:
self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=0)

@unittest.skipUnless(
hasattr(triton_interpreter, "_MISSING"),
"https://github.com/triton-lang/triton/pull/8735",
)
@skipIfMTIA("Not supported on MTIA yet.")
def test_host_breakpoint_triton_interpret(self) -> None:
self._run_host_breakpoint_test(triton_interpret=1, helion_interpret=0)

@skipIfMTIA("Not supported on MTIA yet.")
def test_host_breakpoint_helion_interpret(self) -> None:
self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=1)

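Taken together, the three files follow one pattern: add a single `from helion._testing import skipIfMTIA` import per module, then stack the decorator beneath any existing skip decorators on each affected test. A hypothetical new test in this suite would take the same shape (the class name, method name, and skip reasons below are placeholders, not code from the diff):

```python
# Hypothetical example of the pattern these changes establish; the test name
# and body are placeholders. Imports mirror those added in the diff above.
import torch

from helion._testing import DEVICE
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import skipIfMTIA
from helion._testing import skipIfRefEager


class TestExamplePattern(RefEagerTestBase, TestCase):
    @skipIfRefEager("feature not supported by ref eager mode yet")
    @skipIfMTIA("Not supported on MTIA yet.")
    def test_new_variant(self) -> None:
        x = torch.randn(8, device=DEVICE)
        # ... exercise the kernel under test and assert on its output ...
```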