[autotuner] filter and clone only mutated tensors #1252
base: main
Changes from 1 commit. Commits in this pull request: cd28d67, 3202c32, 13d30c6, 5cab1be, f75f3cc, 13cc22a, f982b45, aa7525d, 1ab53a7.
@@ -102,6 +102,7 @@ class BaseSearch(BaseAutotuner):
     _baseline_output: object
     _kernel_mutates_args: bool
+    _mutated_tensor_args: Sequence[object] | None
     _baseline_post_args: Sequence[object] | None
     _jobs: int
     _precompile_result_counter: count[int]
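The added `_mutated_tensor_args` annotation declares the field that will hold the storage pointers of the input tensors the kernel mutates; until `_compute_baseline()` populates it, `_clone_args` below treats it as unset and falls back to cloning everything. Keying on `Tensor.data_ptr()` makes membership a storage-identity check: aliases of a mutated buffer match, independent copies do not. A minimal standalone illustration (not part of the PR):

```python
import torch

# data_ptr() identifies a tensor's underlying storage, so it works as a cheap
# membership key for "was this exact buffer mutated?" lookups.
a = torch.zeros(4)
b = a.view(2, 2)   # alias: same storage, same data_ptr
c = a.clone()      # copy: new storage, different data_ptr

mutated_ptrs = {a.data_ptr()}
print(b.data_ptr() in mutated_ptrs)  # True  (alias of a)
print(c.data_ptr() in mutated_ptrs)  # False (independent copy)
```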
@@ -127,15 +128,19 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         seed = self.settings.autotune_random_seed
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
-        self._original_args: Sequence[object] = self._clone_args(self.args)
+        self._original_args: Sequence[object] = self._clone_args(
+            self.args, all_tensors=True
+        )
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
         self._precompile_result_counter = count()
         (
             self._baseline_output,
             self._kernel_mutates_args,
+            self._mutated_tensor_args,
             self._baseline_post_args,
         ) = self._compute_baseline()
+        print(self._mutated_tensor_args)
         self._effective_atol, self._effective_rtol = (
             self._compute_effective_tolerances()
         )
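Because `_mutated_tensor_args` is only filled in by `_compute_baseline()`, the first clone in `__init__` has to pass `all_tensors=True`; every later `_clone_args` call can then restrict cloning to the tensors recorded as mutated during the baseline run.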
@@ -155,17 +160,29 @@ def cleanup(self) -> None:
         self._precompile_args_path = None
         self._precompile_result_counter = count()

-    def _clone_args(self, args: Sequence[object]) -> Sequence[object]:
+    def _clone_args(
+        self, args: Sequence[object], all_tensors: bool = False
+    ) -> Sequence[object]:
+        if (
+            not hasattr(self, "_mutated_tensor_args")
+            or self._mutated_tensor_args is None
+        ):
+            all_tensors = True
+
         def _clone_leaf(leaf: object) -> object:
-            if isinstance(leaf, torch.Tensor):
+            if isinstance(leaf, torch.Tensor) and (
+                all_tensors or leaf.data_ptr() in self._mutated_tensor_args
+            ):
                 clone = leaf.detach().clone()
                 clone.requires_grad_(leaf.requires_grad)
                 return clone
             return leaf

         return tree_map(_clone_leaf, args)

-    def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
+    def _compute_baseline(
+        self,
+    ) -> tuple[object, bool, Sequence[int], Sequence[object] | None]:
         """
         Compute baseline output for accuracy validation during autotuning.
         Also detect if the kernel mutates any of its input arguments.
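This hunk rewrites `_clone_args` so that, once the mutated set is known, only tensors whose storage pointer appears in `_mutated_tensor_args` are deep-copied; all other leaves (read-only tensors, scalars, non-tensor objects) pass through by reference, which is the "clone only mutated tensors" behaviour the PR title describes. A standalone sketch of the same pattern, with illustrative names and example values that are not from the PR:

```python
import torch
from torch.utils._pytree import tree_map

def clone_only(args, mutated_ptrs):
    """Clone just the tensors whose storage appears in mutated_ptrs;
    leave every other leaf (tensors included) untouched."""
    def _clone_leaf(leaf):
        if isinstance(leaf, torch.Tensor) and leaf.data_ptr() in mutated_ptrs:
            clone = leaf.detach().clone()
            clone.requires_grad_(leaf.requires_grad)
            return clone
        return leaf
    return tree_map(_clone_leaf, args)

x = torch.randn(8)      # pretend the kernel writes to this one
w = torch.randn(8, 8)   # read-only weight
fresh = clone_only((x, w, 3.14), {x.data_ptr()})
assert fresh[0].data_ptr() != x.data_ptr()  # x was copied
assert fresh[1] is w                        # w passed through untouched
```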
@@ -174,7 +191,7 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         - If settings.autotune_baseline_fn is provided, use that custom function
         - Otherwise, run the kernel with the default config
         """
-        new_args = self._clone_args(self._original_args)
+        new_args = self._clone_args(self._original_args, all_tensors=True)

         # Use custom baseline function if provided
         if self.settings.autotune_baseline_fn is not None:
@@ -216,16 +233,17 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         original_args_flat, _ = tree_flatten(self._original_args)
         new_args_flat, _ = tree_flatten(new_args)
         mutated = False
+        mutated_tensors = []
         for old, new in zip(original_args_flat, new_args_flat, strict=False):
             if (
                 isinstance(old, torch.Tensor)
                 and isinstance(new, torch.Tensor)
                 and (not torch.equal(new, old))
             ):
                 mutated = True
-                break
+                mutated_tensors.append(old.data_ptr())
         baseline_post_args = self._clone_args(new_args)
-        return baseline_output, mutated, baseline_post_args
+        return baseline_output, mutated, mutated_tensors, baseline_post_args

     def _compute_effective_tolerances(self) -> tuple[float, float]:
         """
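Together, the last two hunks have `_compute_baseline` clone every tensor up front (so the pre-run copies remain a pristine reference) and then, instead of stopping at the first mismatch, walk all argument pairs and record the `data_ptr()` of every original tensor whose post-run counterpart changed. A standalone sketch of that detection step, assuming an in-place "kernel" invented for the example:

```python
import torch
from torch.utils._pytree import tree_flatten

def find_mutated_ptrs(original_args, post_run_args):
    """Compare flattened pre-/post-run argument trees and collect the data
    pointers of the original tensors whose counterparts changed."""
    orig_flat, _ = tree_flatten(original_args)
    post_flat, _ = tree_flatten(post_run_args)
    mutated = []
    for old, new in zip(orig_flat, post_flat, strict=False):
        if (
            isinstance(old, torch.Tensor)
            and isinstance(new, torch.Tensor)
            and not torch.equal(new, old)
        ):
            mutated.append(old.data_ptr())
    return mutated

x, w = torch.ones(4), torch.ones(4)
x_run, w_run = x.clone(), w.clone()
x_run.mul_(2)  # stands in for the kernel's in-place side effect on x
print(find_mutated_ptrs((x, w), (x_run, w_run)))  # [x.data_ptr()]
```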