diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 041ea6782..d87967529 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -39,6 +39,7 @@
 from torch.utils._pytree import tree_flatten
 from torch.utils._pytree import tree_map
 from torch.utils._pytree import tree_map_only
+from torch.utils._pytree import tree_unflatten
 from triton.testing import do_bench
 
 from .. import exc
@@ -87,6 +88,31 @@ class BenchmarkResult(NamedTuple):
     compile_time: float | None
 
 
+def _clone_args(
+    args: Sequence[object],
+    idx_to_clone: Sequence[int] | None = None,
+) -> Sequence[object]:
+    """
+    Clone the given arguments, cloning only the tensors whose flattened
+    index appears in idx_to_clone. If idx_to_clone is None, clone all tensors.
+    """
+
+    args_flat, tree_spec = tree_flatten(args)
+    tensor_idx = 0
+    for i, arg in enumerate(args_flat):
+        if not isinstance(arg, torch.Tensor):
+            continue
+        if isinstance(arg, torch.Tensor) and (
+            idx_to_clone is None or tensor_idx in idx_to_clone
+        ):
+            clone = arg.detach().clone()
+            clone.requires_grad_(arg.requires_grad)
+            args_flat[i] = clone
+        tensor_idx += 1
+
+    return tree_unflatten(args_flat, tree_spec)
+
+
 class BaseSearch(BaseAutotuner):
     """
     Base class for search algorithms. This class defines the interface and utilities for all
@@ -101,7 +127,7 @@ class BaseSearch(BaseAutotuner):
     """
 
     _baseline_output: object
-    _kernel_mutates_args: bool
+    _mutated_arg_indicies: Sequence[int] | None
     _baseline_post_args: Sequence[object] | None
     _jobs: int
     _precompile_result_counter: count[int]
@@ -127,13 +153,13 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         seed = self.settings.autotune_random_seed
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
-        self._original_args: Sequence[object] = self._clone_args(self.args)
+        self._original_args: Sequence[object] = _clone_args(self.args)
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
         self._precompile_result_counter = count()
         (
             self._baseline_output,
-            self._kernel_mutates_args,
+            self._mutated_arg_indicies,
             self._baseline_post_args,
         ) = self._compute_baseline()
         self._effective_atol, self._effective_rtol = (
@@ -156,17 +182,9 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()
 
-    def _clone_args(self, args: Sequence[object]) -> Sequence[object]:
-        def _clone_leaf(leaf: object) -> object:
-            if isinstance(leaf, torch.Tensor):
-                clone = leaf.detach().clone()
-                clone.requires_grad_(leaf.requires_grad)
-                return clone
-            return leaf
-
-        return tree_map(_clone_leaf, args)
-
-    def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
+    def _compute_baseline(
+        self,
+    ) -> tuple[object, Sequence[int] | None, Sequence[object] | None]:
         """
         Compute baseline output for accuracy validation during autotuning.
         Also detect if the kernel mutates any of its input arguments.
@@ -175,7 +193,7 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         - If settings.autotune_baseline_fn is provided, use that custom function
         - Otherwise, run the kernel with the default config
         """
-        new_args = self._clone_args(self._original_args)
+        new_args = _clone_args(self._original_args)
 
         # Use custom baseline function if provided
         if self.settings.autotune_baseline_fn is not None:
@@ -217,16 +235,19 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         original_args_flat, _ = tree_flatten(self._original_args)
         new_args_flat, _ = tree_flatten(new_args)
         mutated = False
+        mutated_tensors = []
+        # we should only count tensors, since they won't be bound or removed
+        tensor_idx = 0
         for old, new in zip(original_args_flat, new_args_flat, strict=False):
-            if (
-                isinstance(old, torch.Tensor)
-                and isinstance(new, torch.Tensor)
-                and (not torch.equal(new, old))
-            ):
+            if not (isinstance(old, torch.Tensor) and isinstance(new, torch.Tensor)):
+                continue
+            if not torch.equal(new, old):
                 mutated = True
-                break
-        baseline_post_args = self._clone_args(new_args)
-        return baseline_output, mutated, baseline_post_args
+                mutated_tensors.append(tensor_idx)
+            tensor_idx += 1
+        baseline_post_args = _clone_args(new_args, idx_to_clone=mutated_tensors)
+        mutated_tensors = None if not mutated else mutated_tensors
+        return baseline_output, mutated_tensors, baseline_post_args
 
     def _compute_effective_tolerances(self) -> tuple[float, float]:
         """
@@ -255,7 +276,10 @@ def collect_dtypes(obj: object) -> object:
             return obj
 
         tree_map_only(torch.Tensor, collect_dtypes, self._baseline_output)
-        if self._kernel_mutates_args and self._baseline_post_args is not None:
+        if (
+            hasattr(self, "_mutated_arg_indicies")
+            and self._mutated_arg_indicies is not None
+        ) and self._baseline_post_args is not None:
             tree_map_only(torch.Tensor, collect_dtypes, self._baseline_post_args)
 
         # Check for fp8 dtypes - these require exact bitwise comparison
@@ -347,7 +371,10 @@ def _validate_against_baseline(
             atol=self._effective_atol,
             rtol=self._effective_rtol,
         )
-        if self._kernel_mutates_args:
+        if (
+            hasattr(self, "_mutated_arg_indicies")
+            and self._mutated_arg_indicies is not None
+        ):
             torch.testing.assert_close(
                 args,
                 self._baseline_post_args,
@@ -398,8 +425,13 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         # TODO(jansel): early exit with fewer trials if early runs are slow
         self.log.debug(lambda: f"Running {config} at {datetime.datetime.now()}")
         t0 = time.perf_counter()
-        if self._kernel_mutates_args:
-            self.args = self._clone_args(self._original_args)
+        if (
+            hasattr(self, "_mutated_arg_indicies")
+            and self._mutated_arg_indicies is not None
+        ):
+            self.args = _clone_args(
+                self._original_args, idx_to_clone=self._mutated_arg_indicies
+            )
         torch.accelerator.synchronize()
         output = fn(*self.args)  # make sure the kernel is compiled
         torch.accelerator.synchronize()
@@ -498,8 +530,13 @@ def start_precompile_and_check_for_hangs(
         mode = self.settings.autotune_precompile
         if mode not in {"fork", "spawn"}:
            raise exc.InvalidAPIUsage("autotune_precompile must be 'fork' or 'spawn'")
-        if self._kernel_mutates_args:
-            device_args = self._clone_args(self._original_args)
+        if (
+            hasattr(self, "_mutated_arg_indicies")
+            and self._mutated_arg_indicies is not None
+        ):
+            device_args = _clone_args(
+                self._original_args, idx_to_clone=self._mutated_arg_indicies
+            )
         else:
             device_args = self.args
 
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index 6c68c4785..c7b5f9272 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -1472,6 +1472,43 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         # Should have been called with 2 functions
         self.assertEqual(benchmark_calls[0][0], 2)
 
+    def test_autotune_configuration_cloning(self) -> None:
+        """Tests the base_search._clone_args function."""
+
+        config1 = helion.Config(block_sizes=[32, 32], num_warps=4)
+        config2 = helion.Config(block_sizes=[64, 64], num_warps=8)
+
+        @helion.kernel(
+            configs=[config1, config2],
+            autotune_log_level=0,
+        )
+        def nested_in_place_add(
+            a: Sequence[torch.Tensor],
+            b: Sequence[torch.Tensor],
+            out: Sequence[torch.Tensor],
+        ):
+            for tile in hl.tile(out[0].size()):
+                out[0][tile] += a[0][tile] + b[0][tile]
+            for tile in hl.tile(out[1].size()):
+                out[1][tile] += a[1][tile] + b[1][tile]
+
+        args = (
+            [torch.ones([128], device=DEVICE), torch.ones([128], device=DEVICE)],
+            [torch.ones([128], device=DEVICE), torch.ones([128], device=DEVICE)],
+            [torch.zeros([128], device=DEVICE), torch.zeros([128], device=DEVICE)],
+        )
+
+        # Run autotuning
+        nested_in_place_add(*args)
+
+        # Test that out is overwritten only once and that the arguments are
+        # correctly cloned for each autotune run
+        ref_out = [
+            torch.full([128], 2.0, device=DEVICE),
+            torch.full([128], 2.0, device=DEVICE),
+        ]
+        torch.testing.assert_close(args[2], ref_out)
+
 
 class TestAutotuneRandomSeed(RefEagerTestDisabled, TestCase):
     def _autotune_and_record(self, **settings: object) -> float:
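For context, the selective-cloning idea introduced by `_clone_args` is easiest to see in isolation. Below is a minimal, self-contained sketch (not part of the patch): the helper name `clone_selected_tensors`, the toy pytree, and the assertions are illustrative assumptions rather than repository code, but the behavior mirrors the new helper, where only tensors whose flattened index is listed get copied.

```python
# Illustrative sketch only -- mirrors the selective-clone behavior of
# base_search._clone_args from the diff above; names and inputs are assumptions.
from __future__ import annotations

from collections.abc import Sequence

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def clone_selected_tensors(
    args: Sequence[object],
    idx_to_clone: Sequence[int] | None = None,
) -> Sequence[object]:
    """Clone tensors whose flattened tensor index is in idx_to_clone.

    Non-tensor leaves pass through unchanged; if idx_to_clone is None,
    every tensor leaf is cloned.
    """
    args_flat, spec = tree_flatten(args)
    tensor_idx = 0
    for i, leaf in enumerate(args_flat):
        if not isinstance(leaf, torch.Tensor):
            continue
        if idx_to_clone is None or tensor_idx in idx_to_clone:
            clone = leaf.detach().clone()
            clone.requires_grad_(leaf.requires_grad)
            args_flat[i] = clone
        tensor_idx += 1
    return tree_unflatten(args_flat, spec)


if __name__ == "__main__":
    a, b, out = torch.ones(4), torch.ones(4), torch.zeros(4)
    fresh = clone_selected_tensors(([a, b], out), idx_to_clone=[2])

    # Only the tensor at flattened index 2 (`out`) was cloned; the other
    # leaves are the same objects, so no unnecessary copies are made.
    assert fresh[0][0] is a and fresh[0][1] is b
    assert fresh[1] is not out and torch.equal(fresh[1], out)
```

This is the payoff of passing `idx_to_clone` in the diff above: between benchmarked configs, only the tensors the kernel actually mutates are re-copied, so read-only inputs are not duplicated on every run.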