32 changes: 25 additions & 7 deletions helion/autotuner/base_search.py
@@ -102,6 +102,7 @@ class BaseSearch(BaseAutotuner):
 
     _baseline_output: object
     _kernel_mutates_args: bool
+    _mutated_tensor_args: Sequence[object] | None
     _baseline_post_args: Sequence[object] | None
     _jobs: int
     _precompile_result_counter: count[int]
@@ -127,15 +128,19 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         seed = self.settings.autotune_random_seed
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
-        self._original_args: Sequence[object] = self._clone_args(self.args)
+        self._original_args: Sequence[object] = self._clone_args(
+            self.args, all_tensors=True
+        )
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
         self._precompile_result_counter = count()
         (
             self._baseline_output,
             self._kernel_mutates_args,
+            self._mutated_tensor_args,
             self._baseline_post_args,
         ) = self._compute_baseline()
+        print(self._mutated_tensor_args)
         self._effective_atol, self._effective_rtol = (
             self._compute_effective_tolerances()
         )
@@ -155,17 +160,29 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
-    def _clone_args(self, args: Sequence[object]) -> Sequence[object]:
+    def _clone_args(
+        self, args: Sequence[object], all_tensors: bool = False
+    ) -> Sequence[object]:
+        if (
+            not hasattr(self, "_mutated_tensor_args")
+            or self._mutated_tensor_args is None
+        ):
+            all_tensors = True
+
         def _clone_leaf(leaf: object) -> object:
-            if isinstance(leaf, torch.Tensor):
+            if isinstance(leaf, torch.Tensor) and (
+                all_tensors or leaf.data_ptr() in self._mutated_tensor_args
Contributor:

.data_ptr() would work for NVIDIA / AMD GPUs, but I worry that not all hardware has this attribute well-defined. Using the index of the tensor in the arg list is a better differentiator, I think.

Also, since individual arguments in args can be nested containers of torch.Tensor, we might need to use tree_map, e.g.:

count = [0]
tree_map_only(torch.Tensor, lambda t: count.__setitem__(0, count[0] + 1) or t, args)

It would be great to add unit tests covering these mutation cases as well.
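
For illustration, a minimal sketch of the index-based bookkeeping suggested above, using torch.utils._pytree to handle nested containers of tensors; the helper name find_mutated_indices is hypothetical and not part of this PR:

import torch
from torch.utils._pytree import tree_flatten

def find_mutated_indices(
    original_args: object, post_run_args: object
) -> list[int]:
    # Flatten both pytrees into parallel leaf lists; leaf order is
    # deterministic for identical container structures.
    orig_flat, _ = tree_flatten(original_args)
    post_flat, _ = tree_flatten(post_run_args)
    mutated: list[int] = []
    for i, (old, new) in enumerate(zip(orig_flat, post_flat)):
        if (
            isinstance(old, torch.Tensor)
            and isinstance(new, torch.Tensor)
            and not torch.equal(old, new)
        ):
            # The flattened index is hardware-independent, unlike data_ptr().
            mutated.append(i)
    return mutated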

Contributor Author:

> I worry that not all hardware has this attribute well-defined.

Good hint, I hadn't thought about this.

Regarding indexes: I did not want to use the index because of my experience with the Triton autotuner, where the compiler removes constant arguments at some point because they are bound to the kernel, so the indexes become shifted.
But if I see this correctly, base_search never handles bound kernels, right? So this risk doesn't exist here?

Contributor (@yf225, Dec 12, 2025):

Yeah, I don't think we ever dynamically remove arguments or dynamically modify the number of tensors in an argument of a Helion kernel function, so it should be safe to use indices.

Contributor Author:

OK, sounds like the index of the tensors would be the best option then; I'll do it.

Contributor Author:

Thanks @yf225 for the hint regarding tree_map. I learned something while trying to implement it with this function, but in the end I struggled to compare the old and new args with only the tree map.

And if I use tree_flatten for the initial comparison, I'm not sure I can trust that the indexes created by tree_map are always the same. So I now use tree_flatten and tree_unflatten in the proposed solution.

I also added a unit test to cover it. Looking forward to your feedback!
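
For illustration, a minimal sketch of the tree_flatten / tree_unflatten approach described here: clone only the leaves whose flattened indices were recorded as mutated, then rebuild the original container structure. The name clone_mutated_only and the mutated_indices parameter are hypothetical:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def clone_mutated_only(args: object, mutated_indices: set[int]) -> object:
    leaves, spec = tree_flatten(args)
    cloned = []
    for i, leaf in enumerate(leaves):
        if isinstance(leaf, torch.Tensor) and i in mutated_indices:
            c = leaf.detach().clone()
            c.requires_grad_(leaf.requires_grad)
            cloned.append(c)
        else:
            # Leaves the kernel never mutates can be reused as-is.
            cloned.append(leaf)
    # Reassemble with the same treespec so nesting is preserved.
    return tree_unflatten(cloned, spec)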

+            ):
                 clone = leaf.detach().clone()
                 clone.requires_grad_(leaf.requires_grad)
                 return clone
             return leaf
 
         return tree_map(_clone_leaf, args)
 
-    def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
+    def _compute_baseline(
+        self,
+    ) -> tuple[object, bool, Sequence[int], Sequence[object] | None]:
         """
         Compute baseline output for accuracy validation during autotuning.
         Also detect if the kernel mutates any of its input arguments.
@@ -174,7 +191,7 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         - If settings.autotune_baseline_fn is provided, use that custom function
         - Otherwise, run the kernel with the default config
         """
-        new_args = self._clone_args(self._original_args)
+        new_args = self._clone_args(self._original_args, all_tensors=True)
 
         # Use custom baseline function if provided
         if self.settings.autotune_baseline_fn is not None:
@@ -216,16 +233,17 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         original_args_flat, _ = tree_flatten(self._original_args)
         new_args_flat, _ = tree_flatten(new_args)
         mutated = False
+        mutated_tensors = []
         for old, new in zip(original_args_flat, new_args_flat, strict=False):
             if (
                 isinstance(old, torch.Tensor)
                 and isinstance(new, torch.Tensor)
                 and (not torch.equal(new, old))
             ):
                 mutated = True
-                break
+                mutated_tensors.append(old.data_ptr())
         baseline_post_args = self._clone_args(new_args)
-        return baseline_output, mutated, baseline_post_args
+        return baseline_output, mutated, mutated_tensors, baseline_post_args
 
     def _compute_effective_tolerances(self) -> tuple[float, float]:
         """