@@ -759,6 +759,75 @@ def masked_store(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_indexing.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestIndexing.test_mixed_scalar_block_store_size1_dim)
+from __future__ import annotations
+
+import torch
+import helion.language as hl
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel_with_mixed_store(x_data, out, scales, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    # src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
+    num_blocks_0 = 1
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_1 * _BLOCK_SIZE_0
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    tile_end = offset_0 + _BLOCK_SIZE_0
+    # src[test_indexing.py:N]: for n_tile_local in hl.tile(
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    # src[test_indexing.py:N]: ):
+    # src[test_indexing.py:N-N]: ...
+    for offset_2 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < tile_end
+        # src[test_indexing.py:N]: x_block = x_data[m_tile, n_tile_local]
+        x_block = tl.load(x_data + indices_2[None, :] * 1, mask_2[None, :], other=0)
+        # src[test_indexing.py:N]: row_max = x_block.abs().amax(dim=1)
+        v_0 = tl_math.abs(x_block)
+        _mask_to = tl.where(tl.broadcast_to(mask_2[None, :], [1, _BLOCK_SIZE_2]), v_0, tl.full([], float('-inf'), tl.float32))
+        row_max = tl.cast(tl.max(_mask_to, 1), tl.float32)
+        # src[test_indexing.py:N]: row_value = row_max.to(torch.uint8)
+        v_1 = tl.cast(row_max, tl.uint8)
+        # src[test_indexing.py:N]: out[m_tile, n_tile_local] = x_block * 2.0
+        v_2 = 2.0
+        v_3 = x_block * v_2
+        tl.store(out + indices_2[None, :] * 1, v_3, mask_2[None, :])
+        # src[test_indexing.py:N]: scale_col_idx = n_tile_local.begin // BLOCK_SIZE # scalar
+        floordiv = triton_helpers.div_floor_integer(offset_2, 32)
+        # src[test_indexing.py:N]: scales[m_tile, scale_col_idx] = row_value # row_value is block
+        tl.store(scales + floordiv * 1, tl.reshape(v_1, []), None)
+
+def kernel_with_mixed_store(x_data: torch.Tensor, BLOCK_SIZE: hl.constexpr, *, _launcher=_default_launcher):
+    # src[test_indexing.py:N]: m, n = x_data.shape
+    m, n = x_data.shape
+    # src[test_indexing.py:N]: n = hl.specialize(n)
+    n = 64
+    # src[test_indexing.py:N]: n_scale_cols = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
+    n_scale_cols = (n + 32 - 1) // 32
+    # src[test_indexing.py:N]: scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
+    scales = x_data.new_empty((m, n_scale_cols), dtype=torch.uint8)
+    # src[test_indexing.py:N]: out = x_data.new_empty(x_data.shape, dtype=torch.float32)
+    out = x_data.new_empty(x_data.shape, dtype=torch.float32)
+    # src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
+    _BLOCK_SIZE_0 = 32
+    # src[test_indexing.py:N]: for n_tile_local in hl.tile(
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    # src[test_indexing.py:N]: ):
+    # src[test_indexing.py:N-N]: ...
+    _BLOCK_SIZE_2 = 32
+    # src[test_indexing.py:N]: for m_tile, n_tile in hl.tile([m, n], block_size=[None, n_block]):
+    # src[test_indexing.py:N]: for n_tile_local in hl.tile(
+    # src[test_indexing.py:N]: n_tile.begin, n_tile.end, block_size=BLOCK_SIZE
+    # src[test_indexing.py:N-N]: ...
+    _launcher(_helion_kernel_with_mixed_store, (1 * triton.cdiv(64, _BLOCK_SIZE_0),), x_data, out, scales, _BLOCK_SIZE_0, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
+    # src[test_indexing.py:N]: return out, scales
+    return (out, scales)
+
 --- assertExpectedJournal(TestIndexing.test_non_consecutive_tensor_indexers_no_broadcast)
 from __future__ import annotations
 