120 changes: 115 additions & 5 deletions helion/language/ref_tile.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import itertools
import traceback
from typing import TYPE_CHECKING
from typing import TypeVar

@@ -16,6 +18,22 @@

_T = TypeVar("_T")

# Counter for generating unique block_ids in ref mode
_ref_mode_block_id_counter = itertools.count()

# Dict to map tensor id -> block_ids for tracking (cleared at kernel start)
_tensor_block_ids: dict[int, tuple[int | None, ...]] = {}

# Patterns indicating library/framework code (not user code)
_LIBRARY_PATH_PATTERNS = (
"/helion/helion/",
"/torch/",
"/unittest/",
"/pytest/",
"/site-packages/",
"<frozen",
)


_ADD_OPS: set[object] = {
torch.add,
@@ -42,15 +60,21 @@
class RefTile(TileInterface, torch.Tensor):
_slice: slice
_block_size: int
_block_id: int

def __init__(self, begin: int, end: int, block_size: int) -> None:
def __init__(
self, begin: int, end: int, block_size: int, block_id: int | None = None
) -> None:
super().__init__()

from ..runtime.ref_mode import is_in_ref_mode_context

assert is_in_ref_mode_context()
self._slice = slice(begin, end, None)
self._block_size = block_size
        self._block_id = (
            block_id if block_id is not None else next(_ref_mode_block_id_counter)
        )

@classmethod
def __torch_function__(
@@ -150,13 +174,30 @@ def _handle_getitem(
args: tuple[object, ...],
kwargs: dict[str, object] | None,
) -> object:
"""Handle tensor[index] operations."""
"""Handle tensor[index] operations with tile indices."""
tensor, index = args
assert isinstance(tensor, torch.Tensor)

# Extract block_ids from RefTile indices
indices = index if isinstance(index, tuple) else (index,)
block_ids: list[int | None] = []
for idx in indices:
if isinstance(idx, RefTile):
block_ids.append(idx._block_id)
            elif not isinstance(idx, int):  # slice/None keeps a result dim with no tile
block_ids.append(None)
# int indices reduce dims, so don't append

slice_index = convert_tile_indices_to_slices(index)
# pyrefly: ignore [bad-index]
return tensor[slice_index]
result = tensor[slice_index]

# Register result with block_ids for tracking
if block_ids and isinstance(result, torch.Tensor) and result.ndim > 0:
if len(block_ids) == result.ndim:
_tensor_block_ids[id(result)] = tuple(block_ids)

return result

@classmethod
def _handle_setitem(
@@ -174,7 +215,6 @@ def _handle_setitem(
# pyrefly: ignore [bad-index]
target_shape = tensor[slice_index].shape

# Slice value tensor to match target shape if needed
if (
isinstance(value, torch.Tensor)
and value.shape != target_shape
@@ -199,6 +239,76 @@ def index(self) -> torch.Tensor: # pyrefly: ignore [bad-override]
from .._compiler.compile_environment import CompileEnvironment

env = CompileEnvironment.current()
return torch.arange(
data = torch.arange(
self._slice.start, self._slice.stop, dtype=torch.int32, device=env.device
)
_tensor_block_ids[id(data)] = (self._block_id,)
return data


def reset_ref_mode_block_id_counter() -> None:
"""Reset the block_id counter and tracking dict. Called at the start of each ref mode kernel execution."""
global _ref_mode_block_id_counter
_ref_mode_block_id_counter = itertools.count()
_tensor_block_ids.clear()


def get_block_ids(tensor: torch.Tensor) -> tuple[int | None, ...] | None:
"""Get block_ids for a tensor if tracked."""
return _tensor_block_ids.get(id(tensor))


def maybe_set_block_ids(
    tensor: object, block_ids: tuple[int | None, ...] | None
) -> None:
    """Set block_ids for a tensor if block_ids is non-empty and matches tensor ndim."""
    if not block_ids or not isinstance(tensor, torch.Tensor):
        return
    if len(block_ids) == tensor.ndim:
        _tensor_block_ids[id(tensor)] = block_ids


def check_broadcast_and_get_result_block_ids(
tensors: list[torch.Tensor],
) -> tuple[int | None, ...] | None:
"""Check broadcast compatibility and return result block_ids."""
# Get tracked tensors (those with block_ids)
tracked: list[tuple[torch.Tensor, tuple[int | None, ...]]] = []
for t in tensors:
bids = _tensor_block_ids.get(id(t))
if bids is not None:
tracked.append((t, bids))

if not tracked:
return None

shapes = [[*t.shape] for t, _ in tracked]
bids = [[*b] for _, b in tracked]
max_rank = max(len(s) for s in shapes)

# Right-align with padding
for i in range(len(shapes)):
pad = max_rank - len(shapes[i])
shapes[i] = [1] * pad + shapes[i]
bids[i] = [None] * pad + bids[i]

result: list[int | None] = []
for d in range(max_rank):
        ids_in_dim = {
            bids[i][d]
            for i in range(len(tracked))
            if shapes[i][d] != 1 and bids[i][d] is not None
        }
if len(ids_in_dim) >= 2:
_raise_mismatch(d, shapes, bids, ids_in_dim)
result.append(next(iter(ids_in_dim)) if ids_in_dim else None)
return tuple(result)


def _raise_mismatch(
    dim: int,
    shapes: list[list[int]],
    bids: list[list[int | None]],
    ids_in_dim: set[int],
) -> None:
    """Raise ShapeMismatch with location info."""

    def fmt(shape: list[int], ids: list[int | None]) -> str:
        parts = [
            f"u{i}" if i is not None else str(s)
            for s, i in zip(shape, ids, strict=False)
        ]
        return "[" + ", ".join(parts) + "]"

    descs = [
        f"tensor with shape {fmt(s, b)}"
        for s, b in zip(shapes, bids, strict=False)
        if s[dim] != 1 and b[dim] in ids_in_dim
    ][:2]

    # Report the innermost stack frame that is not library/framework code
    loc = ""
    for f in reversed(traceback.extract_stack()):
        if not any(p in f.filename for p in _LIBRARY_PATH_PATTERNS):
            loc = f"\n at {f.filename}:{f.lineno}: {f.line}"
            break

    raise exc.ShapeMismatch(
        descs[0] if descs else "unknown",
        (descs[1] if len(descs) > 1 else "unknown") + loc,
    )
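To make the intent of these helpers concrete, here is a minimal standalone sketch (not part of the diff) of how the broadcast check behaves. It assumes the module-level names above (check_broadcast_and_get_result_block_ids, _tensor_block_ids) are in scope, with exc as imported in this module; the block_id values are made up for illustration, and _tensor_block_ids is touched directly only for this sketch.

import torch

a = torch.zeros(16)
b = torch.zeros(16)
_tensor_block_ids[id(a)] = (0,)  # pretend `a` was indexed by tile u0
_tensor_block_ids[id(b)] = (1,)  # pretend `b` was indexed by tile u1

# Only one tracked operand: the result simply inherits its block_ids.
assert check_broadcast_and_get_result_block_ids([a, torch.zeros(16)]) == (0,)

# Two different block_ids meet in the same non-size-1 dim: ShapeMismatch.
try:
    check_broadcast_and_get_result_block_ids([a, b])
except exc.ShapeMismatch:
    pass  # reported as "tensor with shape [u0]" vs "tensor with shape [u1]"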
102 changes: 94 additions & 8 deletions helion/runtime/ref_mode.py
@@ -18,6 +18,10 @@
from .._compiler.compile_environment import tls as ce_tls
from .._utils import convert_size_arg
from .._utils import create_shape_matching_slices
from ..language.ref_tile import check_broadcast_and_get_result_block_ids
from ..language.ref_tile import get_block_ids
from ..language.ref_tile import maybe_set_block_ids
from ..language.ref_tile import reset_ref_mode_block_id_counter

if TYPE_CHECKING:
from typing_extensions import Self
@@ -73,6 +77,7 @@ def __enter__(self) -> Self:
assert getattr(ref_mode_tls, "context", None) is None, (
"RefModeContext already active"
)
reset_ref_mode_block_id_counter()
ce_tls.env = self.env
ref_mode_tls.context = self
self.func_mode.__enter__()
@@ -190,7 +195,8 @@ def __torch_function__(
if func in self._binary_ops:
return self._handle_binary_op(func, args, kwargs)

return super().__torch_function__(func, types, args, kwargs)
# For all other ops, run and propagate block_ids
return self._run_with_block_id_tracking(func, types, args, kwargs)

def _handle_mm_with_bias(
self,
@@ -295,10 +301,19 @@ def _handle_binary_op(

# Skip if either operand is not a tensor (e.g., scalar operations)
if not (isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor)):
return cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
result = cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
            # Propagate block_ids when only one operand is a tensor (tensor-scalar ops)
            if isinstance(lhs, torch.Tensor):
                maybe_set_block_ids(result, get_block_ids(lhs))
            elif isinstance(rhs, torch.Tensor):
                maybe_set_block_ids(result, get_block_ids(rhs))
return result

# Check broadcast compatibility (may raise ShapeMismatch)
result_bids = check_broadcast_and_get_result_block_ids([lhs, rhs])

if not self._should_handle_binary_op(lhs, rhs):
return cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
result = cast("Callable[..., torch.Tensor]", func)(*args, **kwargs)
maybe_set_block_ids(result, result_bids)
return result

# Check if this is an in-place operation
func_name = getattr(func, "__name__", "")
Expand All @@ -315,9 +330,10 @@ def _handle_binary_op(
lhs[slices], rhs[slices], *args[2:], **kwargs
)

# For in-place ops, the operation already modified lhs, so just return it
# For out-of-place ops, return the computed result
return lhs if is_inplace else result
# For in-place ops, return lhs; for out-of-place ops, return result
final_result = lhs if is_inplace else result
maybe_set_block_ids(final_result, result_bids)
return final_result

def _should_handle_binary_op(self, lhs: object, rhs: object) -> bool:
"""Check if binary operation needs special handling.
@@ -349,17 +365,47 @@ def _handle_getitem(
args: tuple[object, ...],
kwargs: dict[str, object],
) -> torch.Tensor:
"""Handle tensor indexing with out-of-bounds index clamping."""
"""Handle tensor indexing with out-of-bounds clamping and block_id tracking."""
tensor = cast("torch.Tensor", args[0])
indices: Any = args[1]

# First check if the tensor has block_ids that need to be propagated
tensor_bids = get_block_ids(tensor)

is_tuple = isinstance(indices, tuple)
indices_list = list(indices) if is_tuple else [indices]

for dim, idx in enumerate(indices_list):
if self._is_int_tensor(idx):
indices_list[dim] = torch.clamp(idx, min=0, max=tensor.size(dim) - 1)

return tensor[tuple(indices_list) if is_tuple else indices_list[0]]
result = tensor[tuple(indices_list) if is_tuple else indices_list[0]]

# Propagate block_ids through indexing
if tensor_bids is not None:
bids = list(tensor_bids)
if not is_tuple:
if indices is None:
new_bids = [None, *bids]
elif isinstance(indices, int):
new_bids = bids[1:]
else:
new_bids = bids
else:
new_bids = []
dim = 0
for idx in indices:
if idx is None:
new_bids.append(None)
elif isinstance(idx, int):
dim += 1
else:
if dim < len(bids):
new_bids.append(bids[dim])
dim += 1
maybe_set_block_ids(result, tuple(new_bids))

return result

def _handle_setitem(
self,
@@ -393,6 +439,46 @@ def _handle_setitem(

tensor[final_indices] = value

def _run_with_block_id_tracking(
self,
func: Callable[..., object],
types: list[type[object]],
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
"""Run operation and propagate block_ids through the result."""
# Collect all input tensors
        input_tensors = [
            x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)
        ]

# Check for reductions
func_name = getattr(func, "__name__", "")
if func_name in ("sum", "mean", "prod", "max", "min", "std", "var", "any", "all"):
if args and isinstance(args[0], torch.Tensor):
tensor = args[0]
tensor_bids = get_block_ids(tensor)
if tensor_bids is not None:
dim = args[1] if len(args) > 1 else kwargs.get("dim")
result = super().__torch_function__(func, types, args, kwargs)
if dim is not None:
bids = list(tensor_bids)
                        if isinstance(dim, int):
                            dims = {dim}
                        elif isinstance(dim, (list, tuple)):
                            dims = set(dim)
                        else:
                            dims = set()
dims = {d if d >= 0 else len(bids) + d for d in dims}
keepdim = kwargs.get("keepdim", False)
if keepdim:
new_bids = [None if i in dims else b for i, b in enumerate(bids)]
else:
new_bids = [b for i, b in enumerate(bids) if i not in dims]
maybe_set_block_ids(result, tuple(new_bids))
return result

# Check broadcast compatibility (may raise ShapeMismatch)
result_bids = check_broadcast_and_get_result_block_ids(input_tensors)

# Run the operation
result = super().__torch_function__(func, types, args, kwargs)
maybe_set_block_ids(result, result_bids)
return result

def _setup_binary_ops_handling(self) -> None:
"""Initialize binary operation tracking sets and mappings."""
# Define binary operations and their variants
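For reference, the block_id bookkeeping performed by the reduction branch of _run_with_block_id_tracking (for ops like torch.sum or torch.mean over a tiled dim) amounts to the following standalone sketch. The helper name reduce_block_ids is hypothetical and not part of this PR: reduced dims drop their block_id, or become None when keepdim=True.

def reduce_block_ids(
    bids: tuple[int | None, ...], dim: int | tuple[int, ...], keepdim: bool = False
) -> tuple[int | None, ...]:
    # Normalize `dim` to a set of non-negative dimension indices.
    dims = {dim} if isinstance(dim, int) else set(dim)
    dims = {d if d >= 0 else len(bids) + d for d in dims}
    if keepdim:
        return tuple(None if i in dims else b for i, b in enumerate(bids))
    return tuple(b for i, b in enumerate(bids) if i not in dims)


assert reduce_block_ids((7, 8, 9), dim=-1) == (7, 8)
assert reduce_block_ids((7, 8, 9), dim=1, keepdim=True) == (7, None, 9)
assert reduce_block_ids((7, 8, 9), dim=(0, 2)) == (8,)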