diff --git a/test/test_associative_scan.py b/test/test_associative_scan.py index 030def297..74d085ac5 100644 --- a/test/test_associative_scan.py +++ b/test/test_associative_scan.py @@ -9,7 +9,9 @@ from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output +from helion._testing import skipIfMTIA from helion._testing import skipIfRefEager + import helion.language as hl @@ -99,6 +101,7 @@ def jit_add_combine_fn(x, y): class TestAssociativeScan(RefEagerTestBase, TestCase): + @skipIfMTIA("MTIA error: Expected all tensor inputs to be aligned and/or padded according to the MTIA HW requirements") def test_associative_scan_basic_addition(self): """Test basic associative_scan functionality with prefix sum.""" @@ -132,6 +135,7 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor: self.assertIn("param_0 + param_1", code) self.assertIn("tl.associative_scan", code) + @skipIfMTIA("MTIA error: specify_sizes must have a size for each dimension of the input memref") def test_associative_scan_maximum(self): """Test associative_scan with maximum combine function.""" @@ -164,6 +168,7 @@ def test_max_kernel(x: torch.Tensor) -> torch.Tensor: "tl.maximum" in code or "triton_helpers.maximum" in code ) + @skipIfMTIA("MTIA error: In Triton-MTIA, subviews are handled case by case. Only a subset of operations support subviews. Consider extending support for subviews in the compiler or revisit the algorithm.") def test_associative_scan_multiplication(self): """Test associative_scan with multiplication combine function.""" @@ -194,6 +199,7 @@ def test_mul_kernel(x: torch.Tensor) -> torch.Tensor: # Verify the generated code contains multiplication self.assertIn("param_0 * param_1", code) + @skipIfMTIA("MTIA error: MLIR compilation failed: specify_sizes must have a size for each dimension of the input memref") def test_associative_scan_minimum(self): """Test associative_scan with minimum combine function.""" @@ -226,6 +232,7 @@ def test_min_kernel(x: torch.Tensor) -> torch.Tensor: "tl.minimum" in code or "triton_helpers.minimum" in code ) + @skipIfMTIA("MTIA error: Expected all tensor inputs to be aligned and/or padded according to the MTIA HW requirements") def test_associative_scan_multiple_functions(self): """Test using multiple different combine functions in one kernel.""" @@ -286,6 +293,7 @@ def test_type_kernel(x: torch.Tensor) -> torch.Tensor: # Use relaxed tolerance for large tensors due to accumulated floating-point errors torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + @skipIfMTIA("MTIA error: Expected all tensor inputs to be aligned and/or padded according to the MTIA HW requirements") def test_associative_scan_different_dtypes(self): """Test associative_scan with different data types.""" @@ -320,6 +328,7 @@ def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor: expected = expected.to(result.dtype) torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_different_sizes(self): """Test associative_scan with different tensor sizes.""" @@ -356,6 +365,7 @@ def test_size_kernel(x: torch.Tensor) -> torch.Tensor: expected = torch.cumsum(x, dim=1) torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_reverse(self): """Test associative_scan with reverse=True parameter.""" @@ -381,6 +391,7 @@ def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor: # Verify reverse
parameter is in generated code self.assertIn("reverse=True", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_edge_cases(self): """Test associative_scan edge cases.""" @@ -431,6 +442,7 @@ def test_large_kernel(x: torch.Tensor) -> torch.Tensor: self.assertEqual(result.shape, x.shape) self.assertEqual(result.dtype, x.dtype) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_torch_hops_mapping(self): """Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan.""" @@ -466,6 +478,7 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor: self.assertIn("tl.associative_scan", code) self.assertIn("param_0 + param_1", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_code_generation(self): """Test that the generated code structure is correct.""" @@ -498,6 +511,7 @@ def test_codegen_kernel(x: torch.Tensor) -> torch.Tensor: @skipIfRefEager( "torch._higher_order_ops.associative_scan with nested @helion.kernel is not supported by ref eager mode yet" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_jit_decorator_ignored(self): """Test that @helion.kernel decorator on combine functions is ignored.""" @@ -527,6 +541,7 @@ def test_jit_kernel(x: torch.Tensor) -> torch.Tensor: @skipIfRefEager( "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_tuple_args(self): """Test associative_scan with tuple arguments (matching GitHub issue #237 pattern).""" @@ -579,6 +594,7 @@ def test_segmented_kernel( @skipIfRefEager( "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_segmented_reduction(self): """Test associative_scan for segmented reduction use case.""" @@ -636,6 +652,7 @@ def segmented_scan_kernel( @skipIfRefEager( "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet" ) + @skipIfMTIA("MTIA error: specify_sizes must have a size for each dimension of the input memref") def test_associative_scan_cumulative_argmax(self): """Test cumulative argmax using tuple args with (float, int) types.""" @@ -705,6 +722,7 @@ def cumulative_argmax_kernel( self.assertIn("def argmax_combine_fn_", code) self.assertIn("tl.associative_scan", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_in_helper_function(self): """Test calling a function that internally uses hl.associative_scan.""" @@ -739,6 +757,7 @@ def test_helper_kernel(x: torch.Tensor) -> torch.Tensor: self.assertIn("tl.associative_scan", code) self.assertIn("param_0 + param_1", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_cumsum_basic(self): """Test basic cumsum functionality.""" @@ -766,6 +785,7 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor: self.assertIn("param_0 + param_1", code) self.assertIn("tl.associative_scan", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_cumsum_reverse(self): """Test cumsum with reverse=True.""" @@ -789,6 +809,7 @@ def test_cumsum_reverse_kernel(x: torch.Tensor) -> torch.Tensor: # Verify reverse parameter is used self.assertIn("reverse=True", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_cumsum_different_dtypes(self): """Test cumsum with different data types.""" @@ -820,6 +841,7 @@ def test_cumsum_dtype_kernel(x: torch.Tensor) -> torch.Tensor: expected = 
expected.to(result.dtype) torch.testing.assert_close(result, expected) + @skipIfMTIA("Not supported on MTIA yet.") def test_cumprod_basic(self): """Test basic cumprod functionality.""" @@ -847,6 +869,7 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor: self.assertIn("param_0 * param_1", code) self.assertIn("tl.associative_scan", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_cumprod_reverse(self): """Test cumprod with reverse=True.""" @@ -870,6 +893,7 @@ def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor: # Verify reverse parameter is used self.assertIn("reverse=True", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_cumprod_different_dtypes(self): """Test cumprod with different data types.""" @@ -901,6 +925,7 @@ def test_cumprod_dtype_kernel(x: torch.Tensor) -> torch.Tensor: expected = expected.to(result.dtype) torch.testing.assert_close(result, expected) + @skipIfMTIA("Not supported on MTIA yet.") def test_cumsum_cumprod_mixed(self): """Test using both cumsum and cumprod in the same kernel.""" @@ -937,6 +962,7 @@ def test_mixed_kernel(x: torch.Tensor) -> torch.Tensor: @skipIfRefEager( "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_tuple_format(self): """Test associative_scan with tuple format combine function (like reduce format).""" @@ -988,6 +1014,7 @@ def test_segmented_tuple_kernel( self.assertIn("def helion_combine_tuple_fn_", code) self.assertIn("tl.associative_scan", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_associative_scan_argmax_tuple_format(self): """Test cumulative argmax using tuple format combine function.""" diff --git a/test/test_atomic_ops.py b/test/test_atomic_ops.py index bf6339690..dbe5e4f86 100644 --- a/test/test_atomic_ops.py +++ b/test/test_atomic_ops.py @@ -9,6 +9,7 @@ from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output +from helion._testing import skipIfMTIA from helion._testing import skipIfRefEager from helion._testing import skipIfRocm import helion.language as hl @@ -131,6 +132,7 @@ def atomic_cas_kernel( class TestAtomicOperations(RefEagerTestBase, TestCase): + @skipIfMTIA("Not supported on MTIA yet.") def test_basic_atomic_add(self): x = torch.zeros(10, device=DEVICE) y = torch.ones(10, device=DEVICE) @@ -146,6 +148,7 @@ def test_basic_atomic_add(self): torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_add_1d_tensor(self): M, N = 32, 64 x = torch.randn(M, N, device=DEVICE, dtype=torch.float32) @@ -162,6 +165,7 @@ def test_atomic_add_1d_tensor(self): torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_add_returns_prev(self): @helion.kernel() def k(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -178,6 +182,7 @@ def k(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: torch.testing.assert_close(prev, torch.zeros_like(x)) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_overlapping_atomic_add(self): # Test with overlapping indices x = torch.zeros(5, device=DEVICE) @@ -195,6 +200,7 @@ def test_overlapping_atomic_add(self): torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def 
test_2d_atomic_add(self): """Test atomic_add with 2D tensor indexing.""" x = torch.zeros(3, 4, device=DEVICE) @@ -211,6 +217,7 @@ def test_2d_atomic_add(self): torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_add_code_generation(self): """Test that the generated code contains atomic_add.""" x = torch.zeros(10, device=DEVICE) @@ -222,6 +229,7 @@ def test_atomic_add_code_generation(self): torch.testing.assert_close(result, expected) self.assertIn("atomic_add", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_add_float(self): """Test that atomic_add works with float constants.""" x = torch.zeros(5, device=DEVICE, dtype=torch.float32) @@ -263,6 +271,7 @@ def bad_atomic_add_kernel(x: torch.Tensor, y: torch.Tensor): @skipIfRefEager( "Test is block size dependent which is not supported in ref eager mode" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_add_w_tile_attr(self): """Test atomic_add where the index is a symbolic int""" x = torch.randn(20, device=DEVICE) @@ -293,6 +302,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor: kernel(x) # New tests for other atomics (correctness only; no journal asserts) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_and(self): x0 = torch.full((8,), 0b1111, device=DEVICE, dtype=torch.int32) y = torch.tensor([0b1010] * 8, device=DEVICE, dtype=torch.int32) @@ -301,6 +311,7 @@ def test_atomic_and(self): torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_or(self): x0 = torch.zeros(8, device=DEVICE, dtype=torch.int32) y = torch.tensor([0b1010] * 8, device=DEVICE, dtype=torch.int32) @@ -309,6 +320,7 @@ def test_atomic_or(self): torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_xor(self): x0 = torch.tensor([0b1010] * 8, device=DEVICE, dtype=torch.int32) y = torch.tensor([0b1100] * 8, device=DEVICE, dtype=torch.int32) @@ -318,6 +330,7 @@ def test_atomic_xor(self): self.assertExpectedJournal(code) @skipIfRocm("ROCm backend currently lacks support for these atomics") + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_xchg(self): x0 = torch.zeros(8, device=DEVICE, dtype=torch.int32) y = torch.arange(8, device=DEVICE, dtype=torch.int32) @@ -326,6 +339,7 @@ def test_atomic_xchg(self): self.assertExpectedJournal(code) @skipIfRocm("ROCm backend currently lacks support for these atomics") + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_max(self): x = torch.tensor([1, 5, 3, 7], device=DEVICE, dtype=torch.int32) y = torch.tensor([4, 2, 9, 1], device=DEVICE, dtype=torch.int32) @@ -335,6 +349,7 @@ def test_atomic_max(self): self.assertExpectedJournal(code) @skipIfRocm("ROCm backend currently lacks support for these atomics") + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_min(self): x = torch.tensor([1, 5, 3, 7], device=DEVICE, dtype=torch.int32) y = torch.tensor([4, 2, 9, 1], device=DEVICE, dtype=torch.int32) @@ -343,6 +358,7 @@ def test_atomic_min(self): torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_atomic_cas(self): x = torch.tensor([1, 5, 3, 7], device=DEVICE, dtype=torch.int32) expect = torch.tensor([1, 6, 3, 0], device=DEVICE, dtype=torch.int32) diff --git a/test/test_breakpoint.py b/test/test_breakpoint.py index 52796c7fc..a17ef1097 100644 --- 
a/test/test_breakpoint.py +++ b/test/test_breakpoint.py @@ -15,6 +15,7 @@ from helion import exc from helion._testing import DEVICE from helion._testing import RefEagerTestDisabled +from helion._testing import skipIfMTIA from helion._testing import TestCase import helion.language as hl @@ -92,6 +93,7 @@ def _run_device_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) + @skipIfMTIA("Not supported on MTIA yet.") def test_device_breakpoint_no_interpret(self) -> None: self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=0) @@ -99,9 +101,11 @@ def test_device_breakpoint_no_interpret(self) -> None: hasattr(triton_interpreter, "_MISSING"), "https://github.com/triton-lang/triton/pull/8735", ) + @skipIfMTIA("Not supported on MTIA yet.") def test_device_breakpoint_triton_interpret(self) -> None: self._run_device_breakpoint_test(triton_interpret=1, helion_interpret=0) + @skipIfMTIA("Not supported on MTIA yet.") def test_device_breakpoint_helion_interpret(self) -> None: self._run_device_breakpoint_test(triton_interpret=0, helion_interpret=1) @@ -122,6 +126,7 @@ def _run_host_breakpoint_test( out = bound(x) torch.testing.assert_close(out, x) + @skipIfMTIA("Not supported on MTIA yet.") def test_host_breakpoint_no_interpret(self) -> None: self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=0) @@ -129,9 +134,11 @@ def test_host_breakpoint_no_interpret(self) -> None: hasattr(triton_interpreter, "_MISSING"), "https://github.com/triton-lang/triton/pull/8735", ) + @skipIfMTIA("Not supported on MTIA yet.") def test_host_breakpoint_triton_interpret(self) -> None: self._run_host_breakpoint_test(triton_interpret=1, helion_interpret=0) + @skipIfMTIA("Not supported on MTIA yet.") def test_host_breakpoint_helion_interpret(self) -> None: self._run_host_breakpoint_test(triton_interpret=0, helion_interpret=1) diff --git a/test/test_broadcasting.py b/test/test_broadcasting.py index fe73d6670..8603b653a 100644 --- a/test/test_broadcasting.py +++ b/test/test_broadcasting.py @@ -10,6 +10,7 @@ from helion._testing import DEVICE from helion._testing import RefEagerTestBase from helion._testing import TestCase +from helion._testing import skipIfMTIA from helion._testing import code_and_output from helion._testing import skipIfRefEager import helion.language as hl @@ -46,16 +47,19 @@ def test_broadcast_no_flatten(self): args = [torch.randn(512, 512, device=DEVICE), torch.randn(512, device=DEVICE)] assert not broadcast_fn.bind(args).config_spec.flatten_loops + @skipIfMTIA("MTIA error: Tensor-likes are not close!") def test_broadcast1(self): code = _check_broadcast_fn( block_sizes=[16, 8], ) self.assertExpectedJournal(code) + @skipIfMTIA("MTIA error: Tensor-likes are not close!") def test_broadcast2(self): code = _check_broadcast_fn(block_size=[16, 8], loop_order=(1, 0)) self.assertExpectedJournal(code) + @skipIfMTIA("No fp32 support on MTIA.") def test_broadcast3(self): code = _check_broadcast_fn( block_sizes=[64, 1], @@ -76,6 +80,7 @@ def test_broadcast5(self): ) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_constexpr_index(self): @helion.kernel def fn(a, idx1): diff --git a/test/test_control_flow.py b/test/test_control_flow.py index 8c8a5f068..7eab14138 100644 --- a/test/test_control_flow.py +++ b/test/test_control_flow.py @@ -11,6 +11,7 @@ from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output +from helion._testing import skipIfMTIA from helion._testing import 
skipIfRefEager import helion.language as hl @@ -41,6 +42,7 @@ def fn(x, v): self.assertEqual(code0, code1) self.assertExpectedJournal(code0) + @skipIfMTIA("Not supported on MTIA yet.") def test_if_arg_indexed_scalar(self): @helion.kernel def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -68,6 +70,7 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @skipIfRefEager( "Test is block size dependent which is not supported in ref eager mode" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_if_arg_tensor_sum(self): @helion.kernel def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -96,6 +99,7 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected) @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) + @skipIfMTIA("Not supported on MTIA yet.") def test_constant_true(self): @helion.kernel( config={ @@ -143,6 +147,7 @@ def fn(x): torch.testing.assert_close(result, torch.sin(x)) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_error_in_non_taken_branch(self): def mul_relu_block_back_spec(x, y, dz): z = torch.relu(x * y[:, None]) diff --git a/test/test_debug_utils.py b/test/test_debug_utils.py index cb0ca65ec..2beb3979b 100644 --- a/test/test_debug_utils.py +++ b/test/test_debug_utils.py @@ -12,8 +12,9 @@ import helion from helion._testing import DEVICE from helion._testing import RefEagerTestDisabled -from helion._testing import TestCase from helion._testing import skipIfCpu +from helion._testing import skipIfMTIA +from helion._testing import TestCase import helion.language as hl @@ -103,6 +104,7 @@ def _extract_repro_script(self, text: str) -> str: # Extract content including both markers return text[start_idx : end_idx + len(end_marker)].strip() + @skipIfMTIA("Not supported on MTIA yet.") def test_print_repro_env_var(self): """Ensure HELION_PRINT_REPRO=1 emits an executable repro script.""" with self._with_print_repro_enabled(): diff --git a/test/test_dot.py b/test/test_dot.py index 9233e694a..e910d8845 100644 --- a/test/test_dot.py +++ b/test/test_dot.py @@ -17,6 +17,7 @@ from helion._testing import code_and_output from helion._testing import is_cuda from helion._testing import skipIfCpu +from helion._testing import skipIfMTIA from helion._testing import skipIfRefEager from helion._testing import skipIfRocm from helion._testing import skipIfXPU @@ -193,6 +194,7 @@ def run_kernel(): class TestDot(RefEagerTestBase, TestCase): @skipIfRefEager("Codegen inspection not applicable in ref eager mode") + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_codegen_acc_differs_uses_addition(self): # Test case 1: fused accumulation (acc_dtype = float32, common dtype = bfloat16) input_dtype = torch.bfloat16 @@ -264,6 +266,7 @@ def dot_kernel_out_dtype(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2) self.assertIn("out_dtype=tl.float16", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_torch_matmul_3d(self): @helion.kernel(static_shapes=True) def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: @@ -435,6 +438,7 @@ def no_warn_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @skipIfRefEager("Debug dtype codegen checks rely on compiled code") @skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/772") @skipIfCpu("Failed: Timeout (>10.0s) from pytest-timeout.") + @skipIfMTIA("Not supported on MTIA yet.") def test_baddbmm_pipeline_debug_dtype_asserts(self): # 
Reproduces scripts/repro512.py within the test suite and asserts # the kernel compiles and runs with debug dtype asserts enabled. @@ -705,6 +709,7 @@ def mm_reshape_k_2( expected = (x.view(m, 2) @ y.view(2, n)).to(torch.float32) torch.testing.assert_close(result, expected, rtol=rtol, atol=atol) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_small_m_dim(self): """Test hl.dot with M=2 which is smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -714,6 +719,7 @@ def test_hl_dot_small_m_dim(self): mm_func=lambda acc, a, b: hl.dot(a, b, acc=acc), ) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_small_n_dim(self): """Test hl.dot with N=3 which is smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -723,6 +729,7 @@ def test_hl_dot_small_n_dim(self): mm_func=lambda acc, a, b: hl.dot(a, b, acc=acc), ) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_small_k_dim(self): """Test hl.dot with K=4 which is smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -732,6 +739,7 @@ def test_hl_dot_small_k_dim(self): mm_func=lambda acc, a, b: hl.dot(a, b, acc=acc), ) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_multiple_small_dims(self): """Test hl.dot with multiple dims smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -742,40 +750,49 @@ def test_hl_dot_multiple_small_dims(self): check_code=True, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_small_m_dim(self): """Test torch.addmm with M=2 smaller than the minimum of 16 for tl.dot.""" self._test_small_dims(m_dim=2, k_dim=32, n_dim=64, mm_func=torch.addmm) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_small_n_dim(self): """Test torch.addmm with N=3 smaller than the minimum of 16 for tl.dot.""" self._test_small_dims(m_dim=32, k_dim=64, n_dim=3, mm_func=torch.addmm) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_small_k_dim(self): """Test torch.addmm with K=4 smaller than the minimum of 16 for tl.dot.""" self._test_small_dims(m_dim=32, k_dim=4, n_dim=64, mm_func=torch.addmm) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_multiple_small_dims(self): """Test torch.addmm with multiple dims smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( m_dim=5, k_dim=6, n_dim=7, mm_func=torch.addmm, check_code=True ) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_reshape_m_1(self): """Test torch.addmm with M=1 created through reshape.""" self._test_reshape_m_1(torch.addmm, check_code=True) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_reshape_n_1(self): """Test torch.addmm with N=1 created through reshape.""" self._test_reshape_n_1(torch.addmm) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_reshape_k_1(self): """Test torch.addmm with K=1 created through reshape.""" self._test_reshape_k_1(torch.addmm) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_reshape_k_2(self): """Test torch.addmm with K=2 created through reshape.""" self._test_reshape_k_2(torch.addmm, check_code=True) + @skipIfMTIA("Not supported on MTIA yet.") def _test_reshape_m_2(self, mm_func, *, rtol: float = 1e-2, atol: float = 1e-3): """Test matrix multiplication with M=2 created through reshape.""" @@ -852,34 +869,42 @@ def test_addmm_reshape_m_2(self): """Test torch.addmm with M=2 created through reshape.""" self._test_reshape_m_2(torch.addmm) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_reshape_n_2(self): """Test torch.addmm with N=2 created through 
reshape.""" self._test_reshape_n_2(torch.addmm) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_reshape_m_1(self): """Test hl.dot with M=1 created through reshape.""" self._test_reshape_m_1(lambda acc, a, b: hl.dot(a, b, acc=acc)) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_reshape_n_1(self): """Test hl.dot with N=1 created through reshape.""" self._test_reshape_n_1(lambda acc, a, b: hl.dot(a, b, acc=acc)) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_reshape_k_1(self): """Test hl.dot with K=1 created through reshape.""" self._test_reshape_k_1(lambda acc, a, b: hl.dot(a, b, acc=acc)) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_reshape_k_2(self): """Test hl.dot with K=2 created through reshape.""" self._test_reshape_k_2(lambda acc, a, b: hl.dot(a, b, acc=acc)) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_reshape_m_2(self): """Test hl.dot with M=2 created through reshape.""" self._test_reshape_m_2(lambda acc, a, b: hl.dot(a, b, acc=acc)) + @skipIfMTIA("Not supported on MTIA yet.") def test_hl_dot_reshape_n_2(self): """Test hl.dot with N=2 created through reshape.""" self._test_reshape_n_2(lambda acc, a, b: hl.dot(a, b, acc=acc)) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_small_m_dim(self): """Test torch.mm with M=2 smaller than the minimum of 16 for tl.dot.""" # Allow slightly larger absolute error for torch.mm small-dim tiles @@ -892,6 +917,7 @@ def test_mm_small_m_dim(self): rtol=1e-2, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_small_n_dim(self): """Test torch.mm with N=3 smaller than the minimum of 16 for tl.dot.""" # Allow slightly larger absolute error for torch.mm small-dim tiles @@ -904,6 +930,7 @@ def test_mm_small_n_dim(self): rtol=1e-2, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_small_k_dim(self): """Test torch.mm with K=4 smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -913,6 +940,7 @@ def test_mm_small_k_dim(self): mm_func=lambda acc, a, b: acc + torch.mm(a, b), ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_multiple_small_dims(self): """Test torch.mm with multiple dims smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -924,40 +952,47 @@ def test_mm_multiple_small_dims(self): check_matmul_cast_pattern=True, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_reshape_m_1(self): """Test torch.mm with M=1 created through reshape.""" self._test_reshape_m_1( lambda acc, a, b: acc + torch.mm(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_reshape_n_1(self): """Test torch.mm with N=1 created through reshape.""" self._test_reshape_n_1( lambda acc, a, b: acc + torch.mm(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_reshape_k_1(self): """Test torch.mm with K=1 created through reshape.""" self._test_reshape_k_1(lambda acc, a, b: acc + torch.mm(a, b)) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_reshape_k_2(self): """Test torch.mm with K=2 created through reshape.""" self._test_reshape_k_2( lambda acc, a, b: acc + torch.mm(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_reshape_m_2(self): """Test torch.mm with M=2 created through reshape.""" self._test_reshape_m_2( lambda acc, a, b: acc + torch.mm(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_mm_reshape_n_2(self): """Test torch.mm with N=2 created through reshape.""" self._test_reshape_n_2( lambda acc, a, 
b: acc + torch.mm(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_small_m_dim(self): """Test torch.matmul with M=2 smaller than the minimum of 16 for tl.dot.""" # Allow slightly larger absolute error for small-dim tiles @@ -970,6 +1005,7 @@ def test_matmul_small_m_dim(self): rtol=1e-2, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_small_n_dim(self): """Test torch.matmul with N=3 smaller than the minimum of 16 for tl.dot.""" # Allow slightly larger absolute error for small-dim tiles @@ -982,6 +1018,7 @@ def test_matmul_small_n_dim(self): rtol=1e-2, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_small_k_dim(self): """Test torch.matmul with K=4 smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -991,6 +1028,7 @@ def test_matmul_small_k_dim(self): mm_func=lambda acc, a, b: acc + torch.matmul(a, b), ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_multiple_small_dims(self): """Test torch.matmul with multiple dims smaller than the minimum of 16 for tl.dot.""" self._test_small_dims( @@ -1002,34 +1040,40 @@ def test_matmul_multiple_small_dims(self): check_matmul_cast_pattern=True, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_reshape_m_1(self): """Test torch.matmul with M=1 created through reshape.""" self._test_reshape_m_1( lambda acc, a, b: acc + torch.matmul(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_reshape_n_1(self): """Test torch.matmul with N=1 created through reshape.""" self._test_reshape_n_1( lambda acc, a, b: acc + torch.matmul(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_reshape_k_1(self): """Test torch.matmul with K=1 created through reshape.""" self._test_reshape_k_1(lambda acc, a, b: acc + torch.matmul(a, b)) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_reshape_k_2(self): """Test torch.matmul with K=2 created through reshape.""" self._test_reshape_k_2( lambda acc, a, b: acc + torch.matmul(a, b), rtol=1e-2, atol=5e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_reshape_m_2(self): """Test torch.matmul with M=2 created through reshape.""" self._test_reshape_m_2( lambda acc, a, b: acc + torch.matmul(a, b), rtol=1e-2, atol=6.3e-2 ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_reshape_n_2(self): """Test torch.matmul with N=2 created through reshape.""" self._test_reshape_n_2( diff --git a/test/test_examples.py b/test/test_examples.py index 75873b26c..569c3a650 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -17,6 +17,7 @@ from helion._testing import import_path from helion._testing import skipIfA10G from helion._testing import skipIfCpu +from helion._testing import skipIfMTIA from helion._testing import skipIfRefEager from helion._testing import skipIfRocm from helion._testing import skipIfXPU @@ -27,6 +28,7 @@ @skipIfCpu("needs to be debugged") class TestExamples(RefEagerTestBase, TestCase): + @skipIfMTIA("Not supported on MTIA yet.") def test_add(self): args = ( torch.randn([512, 512], device=DEVICE, dtype=torch.float32), @@ -38,6 +40,7 @@ def test_add(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), @@ -53,6 +56,7 @@ def test_matmul(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_bwd(self): """Test backward pass for matmul computation.""" # Create tensors with requires_grad=True like 
rms_norm_bwd test @@ -89,6 +93,7 @@ def test_matmul_bwd(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_addmm_bwd(self): """Test backward pass for addmm computation.""" # Create tensors with requires_grad=True following the matmul_bwd pattern @@ -179,6 +184,7 @@ def test_matmul_layernorm_dynamic_shapes(self): version.parse(torch.__version__.split("+")[0]) < version.parse("2.8"), "Requires torch 2.8+", ) + @skipIfMTIA("Not supported on MTIA yet.") def test_bmm(self): args = ( torch.randn([16, 512, 768], device=DEVICE, dtype=torch.float16), @@ -364,6 +370,7 @@ def test_softmax_two_pass_block_ptr(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_cross_entropy(self): n, v = 128, 1000 args = ( @@ -445,6 +452,7 @@ def test_low_mem_dropout(self): @skipIfRocm("precision differences with bf16xint16 operations on rocm") @skipIfXPU("precision differences with bf16xint16 operations on xpu") + @skipIfMTIA("Not supported on MTIA yet.") def test_bf16xint16(self): from examples.bf16xint16_gemm import reference_bf16xint16_pytorch @@ -524,6 +532,7 @@ def test_swiglu_bwd(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_rms_norm_bwd(self): """Test backward pass for rms norm weight gradient.""" batch_size, dim = 32, 64 @@ -604,6 +613,7 @@ def test_embedding_block_ptr(self): ) @skipIfRocm("failure on rocm") + @skipIfMTIA("Not supported on MTIA yet.") def test_attention_pointer(self): args = ( torch.randn(1, 32, 512, 64, dtype=torch.float32, device=DEVICE), @@ -622,6 +632,7 @@ def test_attention_pointer(self): @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) @skipIfXPU("failure on XPU") + @skipIfMTIA("Not supported on MTIA yet.") def test_attention_block_pointer(self): args = ( torch.randn(2, 32, 1024, 64, dtype=torch.float16, device=DEVICE), @@ -639,6 +650,7 @@ def test_attention_block_pointer(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_attention_dynamic(self): args = ( torch.randn(1, 32, 512, 64, dtype=torch.float32, device=DEVICE), @@ -655,6 +667,7 @@ def test_attention_dynamic(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_concat(self): args = ( torch.randn(512, 500, device=DEVICE), @@ -670,6 +683,7 @@ def test_concat(self): ) @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) + @skipIfMTIA("Not supported on MTIA yet.") def test_concat_block_ptr(self): args = ( torch.randn(222, 100, device=DEVICE), @@ -686,6 +700,7 @@ def test_concat_block_ptr(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_jagged_dense_add(self): mod = import_path(EXAMPLES_DIR / "jagged_dense_add.py") args = ( @@ -701,6 +716,7 @@ def test_jagged_dense_add(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_jagged_dense_bmm(self): mod = import_path(EXAMPLES_DIR / "jagged_dense_bmm.py") seq_offsets, jagged, dense, bias = mod.random_input( @@ -716,6 +732,7 @@ def test_jagged_dense_bmm(self): ) @skipIfRefEager("Test has skip_accuracy=True and doesn't call assert_close") + @skipIfMTIA("Not supported on MTIA yet.") def test_moe_matmul_ogs(self): mod = import_path(EXAMPLES_DIR / "moe_matmul_ogs.py") @@ -742,6 +759,7 @@ def test_moe_matmul_ogs(self): ) @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_split_k(self): args = ( torch.randn(64, 1024, device=DEVICE), @@ -758,6 +776,7 @@ def test_matmul_split_k(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_sum(self): args = (torch.randn([512, 512], device=DEVICE, 
dtype=torch.float32),) self.assertExpectedJournal( @@ -771,6 +790,7 @@ def test_sum(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_jagged_mean(self): num_rows, max_cols = 32, 64 M = 8 # number of features @@ -806,6 +826,7 @@ def test_jagged_mean(self): @skipIfRefEager( "torch._higher_order_ops.associative_scan with tuple arg is not supported by ref eager mode yet" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_segment_reduction(self): num_nodes = 100 num_edges = 1000 @@ -833,6 +854,7 @@ def test_segment_reduction(self): @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) @skipIfXPU("failure on XPU") + @skipIfMTIA("Not supported on MTIA yet.") def test_attention_persistent_interleaved_l2_grouping(self): """Test attention with persistent interleaved execution and L2 grouping for optimal performance.""" args = ( @@ -945,6 +967,7 @@ def test_layernorm_no_bias(self): ) @skipIfA10G("accuracy check fails on A10G GPUs") + @skipIfMTIA("Not supported on MTIA yet.") def test_layernorm_bwd(self): """Test combined backward pass for layer norm with bias, including regression coverage.""" @@ -1063,6 +1086,7 @@ def test_layernorm_without_bias(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_jagged_softmax(self): num_rows, max_cols = 128, 64 M = 8 # number of features @@ -1153,6 +1177,7 @@ def test_jagged_hstu_attn(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_grouped_gemm_jagged(self): # Build small jagged grouped GEMM inputs torch.manual_seed(0) @@ -1187,6 +1212,7 @@ def test_grouped_gemm_jagged(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_grouped_gemm_jagged_persistent(self): # Build small jagged grouped GEMM inputs torch.manual_seed(0) @@ -1282,6 +1308,7 @@ def test_swiglu(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_jsd(self): args = ( torch.randn( @@ -1308,6 +1335,7 @@ def test_jsd(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_kl_div(self): args = ( torch.randn( @@ -1332,6 +1360,7 @@ def test_kl_div(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_gather_gemv(self): args = ( torch.randn([8, 1024, 1024], device=DEVICE, dtype=torch.float32), @@ -1352,6 +1381,7 @@ def expected(w, idx, x): num_stages=1, ) + @skipIfMTIA("Not supported on MTIA yet.") def test_int4_gemm(self): # Matrix dimensions M, K, N = 256, 512, 256 @@ -1387,6 +1417,7 @@ def test_int4_gemm(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_jagged_sum(self): num_rows, max_cols = 128, 64 M = 8 # number of features @@ -1415,6 +1446,7 @@ def test_jagged_sum(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_fused_linear_jsd(self): beta = 0.5 ignore_index = 1 @@ -1452,6 +1484,7 @@ def test_fused_linear_jsd(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_jagged_layer_norm(self): num_rows, max_cols = 128, 64 M = 8 # number of features @@ -1520,6 +1553,7 @@ def test_exp_bwd(self): @skipIfRocm("failure on rocm") @skipIfA10G("failure on a10g") + @skipIfMTIA("Not supported on MTIA yet.") def test_squeeze_and_excitation_net_fwd(self): m, n, k = 1024, 1024, 1024 x = torch.randn([m, n], device=DEVICE, dtype=torch.float16) @@ -1546,6 +1580,7 @@ def test_squeeze_and_excitation_net_fwd(self): @skipIfRocm("failure on rocm") @skipIfA10G("failure on a10g") + @skipIfMTIA("Not supported on MTIA yet.") def test_squeeze_and_excitation_net_bwd_dx(self): m, n, k = 256, 256, 256 x = torch.randn([m, n], device=DEVICE, dtype=torch.float16) @@ -1589,6 +1624,7 @@ def 
test_squeeze_and_excitation_net_bwd_dx(self): @skipIfRocm("failure on rocm") @skipIfA10G("failure on a10g") + @skipIfMTIA("Not supported on MTIA yet.") def test_squeeze_and_excitation_net_bwd_da(self): m, n, k = 256, 256, 256 x = torch.randn([m, n], device=DEVICE, dtype=torch.float16) @@ -1632,6 +1668,7 @@ def test_squeeze_and_excitation_net_bwd_da(self): @skipIfRocm("failure on rocm") @skipIfA10G("failure on a10g") + @skipIfMTIA("Not supported on MTIA yet.") def test_squeeze_and_excitation_net_bwd_db(self): m, n, k = 256, 256, 256 x = torch.randn([m, n], device=DEVICE, dtype=torch.float16) @@ -1674,6 +1711,7 @@ def test_squeeze_and_excitation_net_bwd_db(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_grpo_loss_fwd(self): """Test forward pass for GRPO loss.""" B, L, V = 4, 512, 2048 @@ -1740,6 +1778,7 @@ def test_grpo_loss_fwd(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_grpo_loss_bwd(self): """Test backward pass for GRPO loss.""" B, L, V = 2, 64, 128 @@ -1835,6 +1874,7 @@ def test_grpo_loss_bwd(self): ) ) + @skipIfMTIA("Not supported on MTIA yet.") def test_gdn_fwd_h(self): """Test gated delta net forward h kernel.""" import math diff --git a/test/test_grid.py b/test/test_grid.py index 3006db8fb..ac6c7f7e0 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -12,6 +12,7 @@ from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output +from helion._testing import skipIfMTIA import helion.language as hl @@ -38,6 +39,7 @@ class TestGrid(RefEagerTestBase, TestCase): supports_tensor_descriptor(), "Tensor descriptor support is required" ) @patch.object(_compat, "_min_dot_size", lambda *args: (16, 16, 16)) + @skipIfMTIA("Not supported on MTIA yet.") def test_grid_1d(self): @helion.kernel(static_shapes=True) def grid_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -84,6 +86,7 @@ def grid_1d_pytorch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @unittest.skipUnless( supports_tensor_descriptor(), "Tensor descriptor support is required" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_grid_2d_idx_list(self): @helion.kernel(static_shapes=True) def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -126,6 +129,7 @@ def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, grid_2d_pytorch(args[0], args[1])) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_grid_2d_idx_nested(self): @helion.kernel(static_shapes=True) def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -159,6 +163,7 @@ def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, grid_2d_pytorch(args[0], args[1])) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_grid_begin_end(self): @helion.kernel(autotune_effort="none") def grid_begin_end(x: torch.Tensor) -> torch.Tensor: @@ -180,6 +185,7 @@ def grid_begin_end_pytorch(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, grid_begin_end_pytorch(x)) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_grid_begin_end_step(self): @helion.kernel(autotune_effort="none") def grid_begin_end_step(x: torch.Tensor) -> torch.Tensor: @@ -201,6 +207,7 @@ def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, grid_begin_end_step_pytorch(x)) self.assertExpectedJournal(code) + 
@skipIfMTIA("Not supported on MTIA yet.") def test_grid_end_step_kwarg(self): @helion.kernel(autotune_effort="none") def grid_end_step_kwarg(x: torch.Tensor) -> torch.Tensor: @@ -222,6 +229,7 @@ def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, grid_end_step_kwarg_pytorch(x)) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_grid_multidim_begin_end(self): @helion.kernel(autotune_effort="none") def grid_multidim_begin_end(x: torch.Tensor) -> torch.Tensor: @@ -246,6 +254,7 @@ def grid_multidim_begin_end_pytorch(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, grid_multidim_begin_end_pytorch(x)) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_grid_multidim_begin_end_step(self): @helion.kernel(autotune_effort="none") def grid_multidim_begin_end_step(x: torch.Tensor) -> torch.Tensor: @@ -270,6 +279,7 @@ def grid_multidim_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, grid_multidim_begin_end_step_pytorch(x)) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_tile_begin_end(self): @helion.kernel(autotune_effort="none") def tile_begin_end(x: torch.Tensor) -> torch.Tensor: @@ -290,6 +300,7 @@ def tile_begin_end_pytorch(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, tile_begin_end_pytorch(x)) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_range_as_grid_basic(self): """Test that range() works as an alias for hl.grid() in device code.""" @@ -310,6 +321,7 @@ def range_kernel(x: torch.Tensor) -> torch.Tensor: code, result = code_and_output(range_kernel, (x,)) torch.testing.assert_close(result, expected) + @skipIfMTIA("Not supported on MTIA yet.") def test_range_with_begin_end(self): """Test that range(begin, end) works as alias for hl.grid(begin, end).""" @@ -353,6 +365,7 @@ def range_step_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_range_with_tensor_size(self): """Test that range(tensor.size(dim)) works with dynamic tensor dimensions.""" diff --git a/test/test_indexing.py b/test/test_indexing.py index 88ff4fee8..7d8587486 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -16,6 +16,7 @@ from helion._testing import code_and_output from helion._testing import skipIfCpu from helion._testing import skipIfLowVRAM +from helion._testing import skipIfMTIA from helion._testing import skipIfNormalMode from helion._testing import skipIfRefEager from helion._testing import skipIfRocm @@ -210,6 +211,7 @@ def pairwise_add(x: torch.Tensor) -> torch.Tensor: @unittest.skipUnless( supports_tensor_descriptor(), "Tensor descriptor support is required" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_pairwise_add_commuted_and_multi_offset(self): @helion.kernel() def pairwise_add_variants(x: torch.Tensor) -> torch.Tensor: @@ -915,6 +917,7 @@ def test_broadcasting_block_ptr_indexing(self): get_tensor_descriptor_fn_name() == "tl._experimental_make_tensor_descriptor", "LLVM ERROR: Illegal shared layout", ) + @skipIfMTIA("Not supported on MTIA yet.") def test_broadcasting_tensor_descriptor_indexing(self): x = torch.randn([16, 24, 32], device=DEVICE) bias1 = torch.randn([1, 24, 32], device=DEVICE) @@ -934,6 +937,7 @@ def test_broadcasting_tensor_descriptor_indexing(self): 
get_tensor_descriptor_fn_name() != "tl._experimental_make_tensor_descriptor", "Not using experimental tensor descriptor", ) + @skipIfMTIA("Not supported on MTIA yet.") def test_reduction_tensor_descriptor_indexing_block_size(self): x = torch.randn([64, 64], dtype=torch.float32, device=DEVICE) @@ -955,6 +959,7 @@ def test_reduction_tensor_descriptor_indexing_block_size(self): get_tensor_descriptor_fn_name() != "tl._experimental_make_tensor_descriptor", "Not using experimental tensor descriptor", ) + @skipIfMTIA("Not supported on MTIA yet.") def test_reduction_tensor_descriptor_indexing_reduction_loop(self): x = torch.randn([64, 256], dtype=torch.float16, device=DEVICE) @@ -1141,6 +1146,7 @@ def kernel(buf: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor: @skipIfRocm("failure on rocm") @unittest.skip("takes 5+ minutes to run") + @skipIfMTIA("Not supported on MTIA yet.") def test_1d_indexed_value_from_slice(self): """buf2[i] = buf[:] - Assign slice to indexed value""" @@ -1249,6 +1255,7 @@ def kernel( torch.testing.assert_close(dst2_result, expected_dst2) @skipIfNormalMode("InternalError: Negative indexes") + @skipIfMTIA("Not supported on MTIA yet.") def test_negative_indexing(self): """Test both setter from scalar and getter for [-1]""" @@ -1277,6 +1284,7 @@ def kernel( @skipIfNormalMode( "RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_ellipsis_indexing(self): """Test both setter from scalar and getter for [..., i]""" @@ -1305,6 +1313,7 @@ def kernel( @skipIfNormalMode( "RankMismatch: Cannot assign a tensor of rank 2 to a buffer of rank 3" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_multi_dim_slice(self): """Test both setter from scalar and getter for [:, :, i]""" @@ -1333,6 +1342,7 @@ def kernel( @skipIfNormalMode( "RankMismatch: Expected ndim=2, but got ndim=1 - tensor value assignment shape mismatch" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_tensor_value(self): """Test both setter from tensor value and getter for [i]""" @@ -1403,6 +1413,7 @@ def kernel( torch.testing.assert_close(dst_result, expected_dst) @skipIfNormalMode("InternalError: Unexpected type ") + @skipIfMTIA("Not supported on MTIA yet.") def test_range_slice(self): """Test both setter from scalar and getter for [10:20]""" @@ -1431,6 +1442,7 @@ def kernel( @skipIfNormalMode( "InternalError: AssertionError in type_propagation.py - slice indexing error" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_range_slice_dynamic(self): """Test both [i:i+1] = scalar and [i] = [i:i+1] patterns""" @@ -1504,6 +1516,7 @@ def tile_offset_kernel(x: torch.Tensor) -> torch.Tensor: self.assertExpectedJournal(code) @unittest.skipIf(not supports_tensor_descriptor(), "TensorDescriptor not supported") + @skipIfMTIA("Not supported on MTIA yet.") def test_tile_with_offset_tensor_descriptor(self): """Test Tile+offset with tensor_descriptor indexing for 2D tensors""" diff --git a/test/test_inline_asm_elementwise.py b/test/test_inline_asm_elementwise.py index d21be7afb..28ae43fbc 100644 --- a/test/test_inline_asm_elementwise.py +++ b/test/test_inline_asm_elementwise.py @@ -11,6 +11,7 @@ from helion._testing import TestCase from helion._testing import code_and_output from helion._testing import skipIfCpu +from helion._testing import skipIfMTIA from helion._testing import skipIfRocm import helion.language as hl @@ -20,6 +21,7 @@ class TestInlineAsmElementwise(RefEagerTestDisabled, TestCase): DEVICE.type != "cuda", reason="inline_asm_elementwise is 
only supported on CUDA" ) @skipIfRocm("only works on cuda") + @skipIfMTIA("Not supported on MTIA yet.") def test_inline_asm_simple(self): """Test basic inline_asm_elementwise with simple assembly""" @@ -49,6 +51,7 @@ def kernel_simple_asm(x: torch.Tensor) -> torch.Tensor: DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" ) @skipIfRocm("only works on cuda") + @skipIfMTIA("Not supported on MTIA yet.") def test_inline_asm_shift_operation(self): """Test inline_asm_elementwise with shift operation (similar to Triton test)""" @@ -87,6 +90,7 @@ def kernel_shift_asm(x: torch.Tensor, y: torch.Tensor, n: int) -> torch.Tensor: DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" ) @skipIfRocm("only works on cuda") + @skipIfMTIA("Not supported on MTIA yet.") def test_inline_asm_multiple_outputs(self): """Test inline_asm_elementwise with multiple outputs""" @@ -136,6 +140,7 @@ def kernel_multiple_outputs( DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" ) @skipIfRocm("only works on cuda") + @skipIfMTIA("Not supported on MTIA yet.") def test_inline_asm_packed(self): """Test inline_asm_elementwise with pack > 1""" @@ -193,6 +198,7 @@ def kernel_invalid_asm(x: torch.Tensor) -> torch.Tensor: DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" ) @skipIfRocm("only works on cuda") + @skipIfMTIA("Not supported on MTIA yet.") def test_inline_asm_empty_args(self): """Test inline_asm_elementwise with empty args (should work like Triton)""" @@ -223,6 +229,7 @@ def kernel_empty_args(x: torch.Tensor) -> torch.Tensor: @skipIfRocm("only works on cuda") @skipIfCpu("RuntimeError: failed to translate module to LLVM IR") + @skipIfMTIA("Not supported on MTIA yet.") def test_inline_asm_basic_compilation(self): """Test that inline_asm_elementwise compiles without errors (no CUDA requirement)""" diff --git a/test/test_logging.py b/test/test_logging.py index a8996fb05..b03e41be8 100644 --- a/test/test_logging.py +++ b/test/test_logging.py @@ -13,6 +13,7 @@ from helion._testing import DEVICE from helion._testing import RefEagerTestDisabled from helion._testing import TestCase +from helion._testing import skipIfMTIA import helion.language as hl @@ -70,10 +71,11 @@ def test_log_set(self): logging.INFO, ) self.assertEqual( - helion._logging._internal._LOG_REGISTRY.log_levels["fuzz.baz"], - logging.DEBUG, - ) + helion._logging._internal._LOG_REGISTRY.log_levels["fuzz.baz"], + logging.DEBUG, + ) + @skipIfMTIA("Not supported on MTIA yet.") def test_kernel_log(self): @helion.kernel( config=helion.Config( diff --git a/test/test_masking.py b/test/test_masking.py index 0290c3d61..ba150ce2f 100644 --- a/test/test_masking.py +++ b/test/test_masking.py @@ -12,11 +12,13 @@ from helion._testing import TestCase from helion._testing import code_and_output from helion._testing import skipIfCpu +from helion._testing import skipIfMTIA from helion._testing import skipIfRefEager import helion.language as hl class TestMasking(RefEagerTestBase, TestCase): + @skipIfMTIA("Not supported on MTIA yet.") def test_mask_dot(self): @helion.kernel(config={"block_sizes": [[32, 32], 32]}, dot_precision="ieee") def add1mm(x, y): @@ -44,6 +46,7 @@ def add1mm(x, y): ) @skipIfCpu("AssertionError: Tensor-likes are not close!") + @skipIfMTIA("Not supported on MTIA yet.") def test_no_mask_views0(self): @helion.kernel(config={"block_sizes": [32]}) def fn(x): @@ -62,6 +65,7 @@ def fn(x): self.assertNotIn("tl.where", code) @skipIfCpu("AssertionError: 
Tensor-likes are not close!") + @skipIfMTIA("Not supported on MTIA yet.") def test_no_mask_views1(self): @helion.kernel(config={"block_sizes": [32]}) def fn(x): @@ -79,6 +83,7 @@ def fn(x): torch.testing.assert_close(result, args[0].sum(dim=1)) self.assertNotIn("tl.where", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_no_mask_full0(self): @helion.kernel(config={"block_sizes": [32]}) def fn(x): @@ -116,6 +121,7 @@ def fn(x): torch.testing.assert_close(result, torch.zeros_like(args[0]).sum(dim=1)) self.assertNotIn("tl.where", code) + @skipIfMTIA("Not supported on MTIA yet.") def test_mask_offset(self): @helion.kernel(config={"block_sizes": [32]}) def fn(x): @@ -134,6 +140,7 @@ def fn(x): self.assertIn("tl.where", code) @skipIfCpu("AssertionError: Tensor-likes are not close!") + @skipIfMTIA("Not supported on MTIA yet.") def test_no_mask_inductor_ops(self): @helion.kernel(config={"block_sizes": [32]}) def fn(x): @@ -184,6 +191,7 @@ def fn(x): @skipIfRefEager( "Test is block size dependent which is not supported in ref eager mode" ) + @skipIfMTIA("Not supported on MTIA yet.") def test_tile_index_does_not_mask(self): @helion.kernel(config={"block_sizes": [32, 32], "indexing": "block_ptr"}) def fn(x): diff --git a/test/test_matmul.py b/test/test_matmul.py index fa76013e4..b560bb64d 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -18,6 +18,7 @@ from helion._testing import skipIfCpu from helion._testing import skipIfRefEager from helion._testing import skipIfRocm +from helion._testing import skipIfMTIA import helion.language as hl torch.backends.cuda.matmul.fp32_precision = "tf32" @@ -73,6 +74,7 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: class TestMatmul(RefEagerTestBase, TestCase): + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul0(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), @@ -87,6 +89,7 @@ def test_matmul0(self): torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul1(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), @@ -101,6 +104,7 @@ def test_matmul1(self): torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul3(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), @@ -116,6 +120,7 @@ def test_matmul3(self): self.assertExpectedJournal(code) @patch.object(_compat, "_supports_tensor_descriptor", lambda: False) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_block_ptr(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), @@ -147,6 +152,7 @@ def test_matmul_tensor_descriptor(self): code = examples_matmul.bind(args).to_triton_code(config) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_static_shapes0(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), @@ -162,6 +168,7 @@ def test_matmul_static_shapes0(self): torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) self.assertExpectedJournal(code) + @skipIfMTIA("Not supported on MTIA yet.") def test_matmul_static_shapes1(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), @@ -176,6 +183,7 @@ def test_matmul_static_shapes1(self): torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) 
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matmul_static_shapes2(self):
         args = (
             torch.randn([128, 127], device=DEVICE, dtype=torch.float32),
@@ -190,6 +198,7 @@ def test_matmul_static_shapes2(self):
         torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2)
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matmul_static_shapes3(self):
         args = (
             torch.randn([127, 128], device=DEVICE, dtype=torch.float32),
@@ -205,6 +214,7 @@ def test_matmul_static_shapes3(self):
         self.assertExpectedJournal(code)

     @skipIfCpu("fails on Triton CPU backend")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matmul_packed_int4_block_size_constexpr(self):
         torch.manual_seed(0)
         M = N = K = 32
@@ -253,6 +263,7 @@ def matmul_bf16_packed_int4(
         self.assertTrue(torch.isfinite(C).all())
         self.assertFalse(torch.allclose(C, torch.zeros_like(C)))

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matmul_split_k(self):
         @helion.kernel(dot_precision="ieee")
         def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -280,6 +291,7 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

     @skipIfRocm("ROCm triton error in TritonAMDGPUBlockPingpong")
     @skipIfRefEager("config_spec is not supported in ref eager mode")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matmul_config_reuse_with_unit_dim(self):
         torch.manual_seed(0)
         big_args = (
@@ -305,6 +317,7 @@ def test_matmul_config_reuse_with_unit_dim(self):
         expected = small_args[0] @ small_args[1]
         torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matmul_packed_rhs(self):
         @helion.kernel(static_shapes=False)
         def matmul_with_packed_b(
diff --git a/test/test_misc.py b/test/test_misc.py
index e7c533c59..69b97b90f 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -28,6 +28,7 @@
 from helion._testing import code_and_output
 from helion._testing import import_path
 from helion._testing import skipIfCpu
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfPyTorchBaseVerLessThan
 from helion._testing import skipIfRefEager
 import helion.language as hl
@@ -54,6 +55,7 @@ def kernel_with_duplicate_refs(x: torch.Tensor) -> torch.Tensor:
         code, result = code_and_output(kernel_with_duplicate_refs, (x,))
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_min_hoist(self):
         """Test case to reproduce issue #1155: offsets are hoisted out of loops"""
@@ -307,6 +309,7 @@ def kernel(
         torch.testing.assert_close(out, expected, atol=1e-2, rtol=1e-2)

     @skipIfRefEager("Config tests not applicable in ref eager mode")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_config_flatten_issue(self):
         @helion.kernel(autotune_effort="none")
         def test_tile_begin(x: torch.Tensor) -> torch.Tensor:
@@ -449,6 +452,7 @@ def kernel_with_scalar_item(
         torch.testing.assert_close(result2, x + 10)

     @patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_tuple_literal_subscript(self):
         @helion.kernel
         def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
@@ -506,6 +510,7 @@ def tuple_literal_index_kernel(inp_tuple) -> torch.Tensor:
         torch.testing.assert_close(result, (inp_tuple[0] + inp_tuple[1][:, :30]) * 3)
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_tuple_unpack(self):
         @helion.kernel
         def tuple_unpack_kernel(inp_tuple) -> torch.Tensor:
@@ -525,6 +530,7 @@ def tuple_unpack_kernel(inp_tuple) -> torch.Tensor:
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_propagate_tile(self):
         @helion.kernel
         def copy_kernel(a: torch.Tensor) -> torch.Tensor:
@@ -541,6 +547,7 @@ def copy_kernel(a: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, args[0])
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     @parametrize("static_shapes", (True, False))
     def test_sequence_assert(self, static_shapes):
         @helion.kernel(static_shapes=static_shapes)
@@ -558,6 +565,7 @@ def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         self.assertExpectedJournal(code)

     @skipIfRefEager("no code execution")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_triton_repro_add(self):
         mod = import_path(EXAMPLES_DIR / "add.py")
         a = torch.randn(16, 1, device=DEVICE)
@@ -581,6 +589,7 @@ def test_triton_repro_add(self):
         self.assertEqual(result.returncode, 0, msg=f"stderr:\n{result.stderr}")
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     @skipIfRefEager("no code execution")
     @parametrize("static_shapes", (True, False))
     def test_triton_repro_custom(self, static_shapes):
@@ -618,6 +627,7 @@ def kernel(
         self.assertExpectedJournal(code)

     @skipIfRefEager("no code execution")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_repro_parseable(self):
         @helion.kernel
         def kernel(fn, t: torch.Tensor):
@@ -633,6 +643,7 @@ def kernel(fn, t: torch.Tensor):
         ast.parse(code)

     @skipIfPyTorchBaseVerLessThan("2.10")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_builtin_min(self) -> None:
         @helion.kernel(autotune_effort="none")
         def helion_min_kernel(x_c):
@@ -666,6 +677,7 @@ def ref_min(x):
         torch.testing.assert_close(helion_out, ref_out, rtol=1e-3, atol=1e-3)
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_builtin_max(self) -> None:
         @helion.kernel(autotune_effort="none")
         def helion_max_kernel(x_c):
diff --git a/test/test_persistent_kernels.py b/test/test_persistent_kernels.py
index 3c9be9577..e7914ce8b 100644
--- a/test/test_persistent_kernels.py
+++ b/test/test_persistent_kernels.py
@@ -12,6 +12,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfCpu
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfRefEager
 import helion.language as hl
@@ -59,6 +60,7 @@ def add1_kernel(x: torch.Tensor) -> torch.Tensor:

 class TestPersistentKernels(RefEagerTestBase, TestCase):
     """Test persistent kernel codegen with different PID strategies."""

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_blocked_simple_add(self):
         """Test persistent blocked kernel with simple addition."""
@@ -132,6 +134,7 @@ def test_persistent_blocked_matmul(self):
         self.assertIn("for virtual_pid in tl.range", code_persistent)
         self.assertIn("virtual_pid", code_persistent)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_interleaved_matmul(self):
         """Test persistent interleaved kernel with matrix multiplication."""
@@ -169,6 +172,7 @@ def test_persistent_interleaved_matmul(self):
         self.assertIn("for virtual_pid in tl.range", code_persistent)
         self.assertIn("virtual_pid", code_persistent)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_blocked_3d(self):
         """Test persistent blocked kernel with 3D tensor."""
@@ -199,6 +203,7 @@ def test_persistent_blocked_3d(self):
         self.assertIn("num_blocks_0", code_persistent)
         self.assertIn("num_blocks_1", code_persistent)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_interleaved_3d(self):
         """Test persistent interleaved kernel with 3D tensor."""
@@ -235,6 +240,7 @@ def test_persistent_interleaved_3d(self):
         self.assertIn("num_blocks_0", code_persistent)
         self.assertIn("num_blocks_1", code_persistent)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_flat_vs_persistent_blocked_equivalence(self):
         """Test that flat and persistent_blocked produce same results."""
@@ -254,6 +260,7 @@ def test_flat_vs_persistent_blocked_equivalence(self):
         # Should produce identical results
         torch.testing.assert_close(result_flat, result_persistent)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_xyz_vs_persistent_interleaved_equivalence(self):
         """Test that xyz and persistent_interleaved produce same results."""
@@ -273,6 +280,7 @@ def test_xyz_vs_persistent_interleaved_equivalence(self):
         # Should produce identical results
         torch.testing.assert_close(result_xyz, result_persistent)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_kernels_with_shared_program_id(self):
         """Test persistent kernels with multiple top-level for loops to trigger ForEachProgramID.
@@ -345,6 +353,7 @@ def multi_loop_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         self.assertIn("pid_shared", code_interleaved)
         self.assertIn("if pid_shared <", code_interleaved)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_shared_vs_flat_shared_equivalence(self):
         """Test that persistent+ForEachProgramID produces same results as flat+ForEachProgramID."""
@@ -397,6 +406,7 @@ def shared_loops_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(results_flat[0], expected1)
         torch.testing.assert_close(results_flat[1], expected2)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_kernels_complex_shared_scenario(self):
         """Test persistent kernels with a more complex ForEachProgramID scenario."""
@@ -459,6 +469,7 @@ def complex_shared_kernel(
         self.assertIn("pid_shared", code_interleaved)
         self.assertIn("if pid_shared <", code_interleaved)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_persistent_blocked_with_l2_grouping(self):
         """Test persistent blocked kernels work with L2 grouping."""
@@ -503,6 +514,7 @@ def test_persistent_blocked_with_l2_grouping(self):
         self.assertIn("_NUM_SM: tl.constexpr", code_persistent_l2)
         self.assertIn("helion.runtime.get_num_sm(", code_persistent_l2)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_shared_program_id_with_persistent_basic_functionality(self):
         """Test that ForEachProgramID + persistent kernels generate correct code structure."""
@@ -562,6 +574,7 @@ def multi_add_kernel(
         self.assertIn("_NUM_SM: tl.constexpr", code_persistent_shared)
         self.assertIn("helion.runtime.get_num_sm(", code_persistent_shared)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_simple_persistent_kernels_work(self):
         """Test that simple persistent kernels compile and run correctly."""
@@ -625,6 +638,7 @@ def reserved_kernel(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result_reserved, x)
         self.assertIn("reserved_sms=3", code_reserved)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_multi_loop_persistent_with_shared_program_id(self):
         """Test that multi-loop persistent kernels with ForEachProgramID work correctly.
@@ -1208,6 +1222,7 @@ def data_dependent_grid_kernel(

     @skipIfCpu("Persistent kernels not supported on CPU")
     @skipIfRefEager("Code pattern checking not applicable in ref eager mode")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_data_dependent_tile_bounds_codegen(self):
         """Test that data-dependent tile bounds work with persistent kernels.
diff --git a/test/test_print.py b/test/test_print.py
index 52a36cdea..970d8f613 100644
--- a/test/test_print.py
+++ b/test/test_print.py
@@ -15,6 +15,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfCpu
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfRocm
 import helion.language as hl
@@ -112,6 +113,7 @@ def run_test_with_and_without_triton_interpret_envvar(self, test_func):
             os.environ["TRITON_INTERPRET"] = original_env

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_basic_print(self):
         """Test basic print with prefix and tensor values"""
@@ -149,6 +151,7 @@ def print_kernel(x: torch.Tensor) -> torch.Tensor:
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_multiple_tensors(self):
         """Test print with multiple tensor arguments"""
@@ -256,6 +259,7 @@ def print_shape_kernel(x: torch.Tensor) -> torch.Tensor:
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_prefix_only(self):
         def run_test(interpret_mode):
             @helion.kernel
@@ -289,6 +293,7 @@ def print_message_kernel(x: torch.Tensor) -> torch.Tensor:
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_in_nested_loops(self):
         def run_test(interpret_mode):
             @helion.kernel
@@ -352,6 +357,7 @@ def print_nested_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_outside_tile_loops(self):
         """Test print statements outside tile loops"""
@@ -382,6 +388,7 @@ def print_outside_kernel(x: torch.Tensor) -> torch.Tensor:
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_with_conditional(self):
         """Test print with conditional statements"""
@@ -442,6 +449,7 @@ def print_conditional_kernel(x: torch.Tensor) -> torch.Tensor:
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_computed_values(self):
         """Test print with computed/derived values"""
@@ -535,6 +543,7 @@ def print_reduction_kernel(x: torch.Tensor) -> torch.Tensor:
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_multiple_data_types(self):
         """Test print with different tensor data types"""
@@ -593,6 +602,7 @@ def print_dtypes_kernel(
         self.run_test_with_and_without_triton_interpret_envvar(run_test)

     @skipIfRocm("failure on rocm")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_print_with_starred_args(self):
         """Test print with starred/unpacked arguments"""
diff --git a/test/test_reduce.py b/test/test_reduce.py
index 224d1cf6f..4101a495f 100644
--- a/test/test_reduce.py
+++ b/test/test_reduce.py
@@ -9,6 +9,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfMTIA
 import helion.language as hl
@@ -80,6 +81,7 @@ def jit_add_combine_fn(x, y):

 class TestReduce(RefEagerTestBase, TestCase):
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_basic_sum(self):
         """Test basic reduce functionality with sum reduction along a dimension."""
@@ -109,6 +111,7 @@ def test_reduce_kernel(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("tl.reduce", code)
         self.assertIn("add_combine_fn_", code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_max(self):
         """Test reduce with maximum operation."""
@@ -134,6 +137,7 @@ def test_reduce_max_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.tensor([4.0, 8.0, 12.0], device=DEVICE)
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_with_keep_dims(self):
         """Test reduce with keep_dims=True."""
@@ -164,12 +168,14 @@ def test_reduce_keep_dims_kernel(x: torch.Tensor) -> torch.Tensor:
         # Check that keep_dims=True is in the generated code
         self.assertIn("keep_dims=True", code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_all_dims(self):
         """Test reduce with dim=None (reduce all dimensions) - SKIP for now."""
         # Skip this test for now - dim=None has complex implementation issues
         # with symbolic shapes that require more work to fix properly
         self.skipTest("dim=None reduction requires more complex implementation")

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_min(self):
         """Test reduce with minimum operation."""
@@ -195,6 +201,7 @@ def test_reduce_min_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.tensor([1.0, 5.0, 9.0], device=DEVICE)
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_product(self):
         """Test reduce with multiplication operation using other=1."""
@@ -220,6 +227,7 @@ def test_reduce_product_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.tensor([6.0, 24.0, 5.0], device=DEVICE)
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_jit_combine_fn(self):
         """Test reduce with @helion.kernel decorated combine function."""
@@ -245,6 +253,7 @@ def test_reduce_jit_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.tensor([10.0, 26.0], device=DEVICE)
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_tuple_input(self):
         """Test reduce with tuple input."""
@@ -310,6 +319,7 @@ def test_reduce_int_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.tensor([10, 26], device=DEVICE, dtype=torch.int64)
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_tuple_unpacking_oneline(self):
         """Test tuple unpacking in one line: a, b = hl.reduce(...)"""
@@ -372,6 +382,7 @@ def test_tuple_oneline_kernel(
         self.assertIn("tl.reduce", code)
         self.assertIn("argmax_combine_fn_", code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_tuple_unpacking_twoline(self):
         """Test tuple unpacking in two lines: result = hl.reduce(...); a, b = result"""
@@ -435,6 +446,7 @@ def test_tuple_twoline_kernel(
         self.assertIn("tl.reduce", code)
         self.assertIn("argmax_combine_fn_", code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_argmax_negative_values(self):
         """Test argmax with all negative values using other=(-inf, 0)."""
@@ -500,6 +512,7 @@ def test_argmax_negative_kernel(
         self.assertIn("tl.reduce", code)
         self.assertIn("argmax_combine_fn_", code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_code_generation(self):
         """Test that reduce generates correct Triton code."""
@@ -527,6 +540,7 @@ def test_reduce_codegen_kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.tensor([6.0], device=DEVICE)
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_tuple_unpacked_format(self):
         """Test reduce with tuple input using unpacked format combine function."""
@@ -570,6 +584,7 @@ def test_reduce_tuple_unpacked_kernel(
         torch.testing.assert_close(result_x, expected_x)
         torch.testing.assert_close(result_y, expected_y)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reduce_argmax_unpacked_format(self):
         """Test argmax with unpacked format combine function."""
diff --git a/test/test_ref_eager.py b/test/test_ref_eager.py
index 693b43d91..ddcbe3a29 100644
--- a/test/test_ref_eager.py
+++ b/test/test_ref_eager.py
@@ -11,6 +11,7 @@
 from helion._testing import DEVICE
 from helion._testing import TestCase
 from helion._testing import assert_ref_eager_mode
+from helion._testing import skipIfMTIA
 import helion.language as hl
@@ -124,6 +125,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
         expected = torch.arange(8, device=DEVICE, dtype=torch.float32)
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_store_with_duplicate_indices_raises_error(self):
         """Test that hl.store with duplicate indices raises an error in ref mode."""
@@ -143,6 +145,7 @@ def kernel_with_dup_store(
         with self.assertRaises(helion.exc.DuplicateStoreIndicesError):
             kernel_with_dup_store(out, idx, val)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_store_dtype_conversion(self):
         """Test that hl.store properly converts dtype in ref eager mode."""
@@ -174,6 +177,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
             result.to(torch.float32), x.to(torch.float32), atol=1e-2, rtol=1e-2
         )

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_load_2d_indexing_without_extra_mask(self):
         """Test that hl.load with two 1D tensor indices produces 2D output in ref eager mode."""
@@ -192,6 +196,7 @@ def kernel(mask: torch.Tensor) -> torch.Tensor:
         result = kernel(mask)
         torch.testing.assert_close(result, mask)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_load_3d_indexing_without_extra_mask(self):
         """Test that hl.load with three 1D tensor indices produces 3D output in ref eager mode."""
diff --git a/test/test_register_tunable.py b/test/test_register_tunable.py
index b6ea5f959..0bb69b8ee 100644
--- a/test/test_register_tunable.py
+++ b/test/test_register_tunable.py
@@ -12,6 +12,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfCpu
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfRocm
 from helion.autotuner import EnumFragment
 from helion.autotuner import IntegerFragment
@@ -92,6 +93,7 @@ def kernel_with_enum(x: torch.Tensor) -> torch.Tensor:
         expected = x * 2.0
         torch.testing.assert_close(result, expected)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_tensor_allocated_with_block_size(self):
         @helion.kernel()
         def fn(x: torch.Tensor):
@@ -111,6 +113,7 @@ def fn(x: torch.Tensor):

     @patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
     @skipIfRocm("failure on rocm")
     @skipIfCpu("Failed: Timeout (>10.0s) from pytest-timeout.")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matmul_split_k(self):
         """Test matmul_split_k kernel with register_tunable"""
diff --git a/test/test_rng.py b/test/test_rng.py
index 822f5e2c2..19647e9c2 100644
--- a/test/test_rng.py
+++ b/test/test_rng.py
@@ -11,6 +11,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfCpu
+from helion._testing import skipIfMTIA
 import helion.language as hl
@@ -135,6 +136,7 @@ def rand_kernel_3d(x: torch.Tensor) -> torch.Tensor:
         _code2, output2 = code_and_output(rand_kernel_3d, (x,))
         self.assertFalse(torch.allclose(output, output2))

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_multiple_rng_ops(self):
         """Test multiple RNG operations: independence, reproducibility, mixed rand/randn."""
@@ -280,6 +282,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
         # Different seeds should produce different outputs
         self.assertFalse(torch.allclose(output1, output2))

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_randn_normal_distribution(self):
         """Test that torch.randn_like produces normal distribution (mean≈0, std≈1)."""
@@ -315,6 +318,7 @@ def randn_kernel_tiled_2d(x: torch.Tensor) -> torch.Tensor:
             0.63 < within_1_std < 0.73, f"Values within 1 std: {within_1_std}"
         )

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_randn_3d_tensor(self):
         """Test 3D randn with tiled operations."""
diff --git a/test/test_specialize.py b/test/test_specialize.py
index b4224f4b2..82502c101 100644
--- a/test/test_specialize.py
+++ b/test/test_specialize.py
@@ -11,6 +11,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfCpu
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfRefEager
 from helion.exc import ShapeSpecializingAllocation
 import helion.language as hl
@@ -20,6 +21,7 @@ class TestSpecialize(RefEagerTestBase, TestCase):
     maxDiff = 163842

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_sqrt_does_not_specialize(self):
         @helion.kernel()
         def fn(
@@ -36,6 +38,7 @@ def fn(
         torch.testing.assert_close(result, x / math.sqrt(x.size(-1)))
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_specialize_host(self):
         @helion.kernel()
         def fn(
@@ -89,6 +92,7 @@ def fn(
         self.assertEqual(len(fn.bind((x,)).config_spec.reduction_loops), 0)
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_dynamic_size_block_non_power_of_two(self):
         @helion.kernel()
         def fn(
@@ -114,6 +118,7 @@ def fn(
         )
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_dynamic_size_block_non_power_of_two_outplace(self):
         @helion.kernel()
         def fn(
@@ -139,6 +144,7 @@ def fn(
         )
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_dynamic_size_block_non_power_of_two_swap_order(self):
         @helion.kernel()
         def fn(
@@ -164,6 +170,7 @@ def fn(
         )
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_dynamic_size_block_non_power_of_two_double_acc(self):
         @helion.kernel()
         def fn(
@@ -191,6 +198,7 @@ def fn(
         )
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_dynamic_size_block_non_power_of_two_matmul(self):
         @helion.kernel()
         def fn(
@@ -286,6 +294,7 @@ def reduce_kernel(
         _test_with_factory(lambda x, s, **kw: hl.zeros([s], **kw), test_host=False)
         _test_with_factory(lambda x, s, **kw: hl.full([s], 1.0, **kw), test_host=False)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_specialize_reduce(self):
         @helion.kernel()
         def fn(
diff --git a/test/test_tensor_descriptor.py b/test/test_tensor_descriptor.py
index 1cd372412..f9f1a056e 100644
--- a/test/test_tensor_descriptor.py
+++ b/test/test_tensor_descriptor.py
@@ -13,6 +13,7 @@
 from helion._testing import TestCase
 from helion._testing import check_example
 from helion._testing import code_and_output
+from helion._testing import skipIfMTIA
 import helion.language as hl
@@ -20,6 +21,7 @@ class TestTensorDescriptor(RefEagerTestBase, TestCase):
     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_permutation_when_stride_one_not_last(self):
         """Test that permutation is applied when stride==1 is not the last dimension."""
@@ -59,6 +61,7 @@ def kernel_with_permutation(x: torch.Tensor) -> torch.Tensor:

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_no_permutation_when_stride_one_already_last(self):
         """Test that no permutation is applied when stride==1 is already last."""
@@ -95,6 +98,7 @@ def kernel_no_permutation(x: torch.Tensor) -> torch.Tensor:

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_3d_tensor_permutation(self):
         """Test permutation with 3D tensor where stride==1 is in middle."""
@@ -130,6 +134,7 @@ def kernel_3d_permutation(x: torch.Tensor) -> torch.Tensor:

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_matrix_transpose_case(self):
         """Test a common case: transposed matrix operations."""
@@ -166,6 +171,7 @@ def kernel_transpose_case(x: torch.Tensor) -> torch.Tensor:

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_permutation_with_different_block_sizes(self):
         """Test that permutation works correctly with different block sizes."""
@@ -202,6 +208,7 @@ def kernel_different_blocks(x: torch.Tensor) -> torch.Tensor:

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_multistage_range_tensor_descriptor(self):
         @helion.kernel(
             config=helion.Config(
@@ -319,6 +326,7 @@ def jsd_forward_kernel(

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_tiny_matmul_tile_fallback(self) -> None:
         """Tensor descriptor indexing should be rejected when the tile is too small."""
@@ -384,6 +392,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_store_operation_permutation(self):
         """Test that store operations also handle permutation correctly."""
@@ -421,6 +430,7 @@ def kernel_store_permutation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_attention_tensor_descriptor(self):
         args = (
             torch.randn(2, 32, 1024, 64, dtype=torch.float16, device=DEVICE),
@@ -440,6 +450,7 @@ def test_attention_tensor_descriptor(self):

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_attention_td_dynamic(self):
         args = (
             torch.randn(1, 32, 512, 64, dtype=torch.float32, device=DEVICE),
@@ -460,6 +471,7 @@ def test_attention_td_dynamic(self):

     @unittest.skipUnless(
         supports_tensor_descriptor(), "Tensor descriptor support is required"
     )
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_minimum_16_byte_block_size_fallback(self):
         """Test that tensor descriptor falls back when block size is too small."""
diff --git a/test/test_views.py b/test/test_views.py
index 9b797c54b..9cd0ead91 100644
--- a/test/test_views.py
+++ b/test/test_views.py
@@ -11,6 +11,7 @@
 from helion._testing import TestCase
 from helion._testing import code_and_output
 from helion._testing import skipIfCpu
+from helion._testing import skipIfMTIA
 from helion._testing import skipIfPy314
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
@@ -211,6 +212,7 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         _code, result = code_and_output(fn, args)
         torch.testing.assert_close(result, args[0] + args[1])

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_split_join_roundtrip(self):
         @helion.kernel(config={"block_size": 64})
         def fn(x: torch.Tensor) -> torch.Tensor:
@@ -228,6 +230,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         self.assertIn("tl.split", code)
         self.assertIn("tl.join", code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_join_broadcast_scalar(self):
         @helion.kernel(config={"block_size": 64})
         def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -246,6 +249,7 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected)
         self.assertIn("tl.join", code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_scalar_broadcast_2d(self):
         """Test that scalars broadcast correctly with 2D tensors."""
@@ -306,6 +310,7 @@ def reshape_reduction_dim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         expected = torch.matmul(x, y)
         torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_reshape_sum(self):
         @helion.kernel(static_shapes=True)
         def fn(x: torch.Tensor) -> torch.Tensor:
@@ -323,6 +328,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         torch.testing.assert_close(result, expected)
         self.assertExpectedJournal(code)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_stack_power_of_2(self):
         @helion.kernel(autotune_effort="none", static_shapes=True)
         def test_stack_power_of_2_kernel(
@@ -361,6 +367,7 @@ def test_stack_power_of_2_kernel(
         expected[1::2] = b  # Every 2nd row starting from 1
         torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_stack_non_power_of_2(self):
         @helion.kernel(autotune_effort="none", static_shapes=True)
         def test_stack_non_power_of_2_kernel(
@@ -415,6 +422,7 @@ def foo(x: torch.Tensor) -> torch.Tensor:
         self.assertExpectedJournal(code)

     @skipIfPy314("torch.compile not yet supported on Python 3.14")
+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_stack_dim0(self):
         @helion.kernel(autotune_effort="none", static_shapes=True)
         def test_stack_dim0_kernel(
@@ -466,6 +474,7 @@ def capture_graph(graph):
         )
         assert "aten.cat" in self._graph and "aten.stack" not in self._graph

+    @skipIfMTIA("Not supported on MTIA yet.")
     def test_view_dtype_reinterpret(self):
         """Test viewing a tensor with a different dtype (bitcast/reinterpret)."""
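
Note on the helper this patch relies on: every hunk above only stacks a skipIfMTIA(...) decorator, imported from helion._testing, next to the existing skipIfCpu / skipIfRocm / skipIfRefEager skips. The snippet below is a minimal sketch of what such a decorator factory could look like; it is an illustration only, not the actual helion._testing implementation, and it assumes the running PyTorch build exposes a torch.mtia availability check.

# Hypothetical sketch of an MTIA skip helper (illustration only; the real
# helion._testing.skipIfMTIA may detect the backend differently).
import unittest

import torch


def skipIfMTIA(reason: str):
    """Return a decorator that skips a test when MTIA is the active backend."""
    mtia = getattr(torch, "mtia", None)  # attribute exists only on MTIA-enabled builds
    on_mtia = mtia is not None and mtia.is_available()
    return unittest.skipIf(on_mtia, reason)

Used as in the hunks above, decorators such as @skipIfRefEager(...) and @skipIfMTIA("Not supported on MTIA yet.") stack on the same test method; each skip condition is evaluated independently, so whichever backend condition matches first skips the test.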