
Conversation

@bringlein

Context

The autotuner detects if a kernel mutates any of the arguments and if so, clones all of them for each autotuner run.
However, some kernels may modify only one tensor, yet all tensors are cloned. When autotuning inside applications with large tensors (e.g. a KV cache inside an inference server), this repeatedly leads to OOMs.

Hence I propose to clone only the tensors that are actually mutated (and use tensor.data_ptr() as the best available identifier?).
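
To make the intent concrete, here is a minimal sketch of what restricted cloning could look like (illustrative only, not the actual patch; the helper name and the data_ptr()-based identification are assumptions from this proposal):

```python
# Illustrative sketch only (not the actual patch): clone just the tensor
# arguments that were detected as mutated, identified here by data_ptr().
import torch


def clone_only_mutated(args: list, mutated_ptrs: set) -> list:
    return [
        a.clone()
        if isinstance(a, torch.Tensor) and a.data_ptr() in mutated_ptrs
        else a
        for a in args
    ]
```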

Test

```
python3 ./vllm-triton-backend/helion/test/test_autotuner.py
....
...
----------------------------------------------------------------------
Ran 46 tests in 17.525s

OK (skipped=4)
```

Please advise if I should run any other test.

Signed-off-by: Burkhard Ringlein <[email protected]>
@meta-cla

meta-cla bot commented Dec 11, 2025

Hi @bringlein!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

Contributor

@yf225 yf225 left a comment


Thanks for the PR! Agreed, we should do this. Some inline comments:

```diff
 def _clone_leaf(leaf: object) -> object:
-    if isinstance(leaf, torch.Tensor):
+    if isinstance(leaf, torch.Tensor) and (
+        all_tensors or leaf.data_ptr() in self._mutated_tensor_args
```
Contributor

.data_ptr() would work for Nvidia / AMD GPUs, but I worry that not all hardware has this attribute well-defined. Using the "index of the tensor in the arg list" is a better differentiator, I think.

Also, since individual arguments in args can themselves be nested containers of torch.Tensor, we might need to use tree_map, like:

```python
from torch.utils._pytree import tree_map_only

# Count the tensor leaves in the (possibly nested) args, returning each leaf unchanged.
count = [0]
tree_map_only(torch.Tensor, lambda t: count.__setitem__(0, count[0] + 1) or t, args)
```

Would be great to add unit tests to cover these mutation cases as well.
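
For illustration, cloning by flat index over the nested args could look roughly like this (hypothetical helper name, not actual Helion internals):

```python
# Hypothetical helper (not Helion's actual code): number the tensor leaves of
# the nested args in traversal order and clone only the ones marked as mutated.
import torch
from torch.utils._pytree import tree_map_only


def clone_mutated_leaves(args, mutated_indices: set):
    count = [0]

    def maybe_clone(t: torch.Tensor) -> torch.Tensor:
        i = count[0]
        count[0] += 1
        return t.clone() if i in mutated_indices else t

    return tree_map_only(torch.Tensor, maybe_clone, args)
```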

Author

> I worry that not all hardware has this attribute well-defined.

Good hint, I hadn't thought about this.

Regarding indexes: I did not want to use the index because of my experience with the Triton autotuner, where the compiler at some point removes constant arguments since they are bound to the kernel, so the indexes become shifted.
But if I see this correctly, base_search never handles bound kernels, correct? So this risk doesn't exist here?

Contributor

@yf225 yf225 Dec 12, 2025

> I worry that not all hardware has this attribute well-defined.
>
> Good hint, I hadn't thought about this.
>
> Regarding indexes: I did not want to use the index because of my experience with the Triton autotuner, where the compiler at some point removes constant arguments since they are bound to the kernel, so the indexes become shifted. But if I see this correctly, base_search never handles bound kernels, correct? So this risk doesn't exist here?

Yeah I don't think we ever dynamically remove arguments or dynamically modify the # of tensors in an argument of a Helion kernel function, so it should be safe to use indices.

Author

OK, sounds like the index of the tensors would be the best option then; I'll do it.

Author

Thanks @yf225 for the hint regarding tree_map. I learned something while trying to implement it using this function, but in the end I struggled to compare the old and new args using only tree_map.

And if I use tree_flatten for the initial comparison, I'm not sure I can trust that the indexes produced by tree_map are always the same. So I now use tree_flatten and tree_unflatten for the proposed solution.
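
Roughly, the flatten / compare / unflatten flow looks like this (simplified sketch with hypothetical helper names, not the exact diff):

```python
# Simplified sketch of the flatten / compare / unflatten flow (not the exact diff).
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def mutated_leaf_indices(pre_run_snapshot, post_run_args) -> set:
    """Flat indices of tensor leaves that differ between the pre-run snapshot
    (clones taken before the kernel ran) and the args after the run."""
    old_leaves, _ = tree_flatten(pre_run_snapshot)
    new_leaves, _ = tree_flatten(post_run_args)
    return {
        i
        for i, (old, new) in enumerate(zip(old_leaves, new_leaves))
        if isinstance(old, torch.Tensor) and not torch.equal(old, new)
    }


def clone_for_benchmark(args, mutated: set):
    """Clone only the mutated tensor leaves, then rebuild the original structure."""
    leaves, spec = tree_flatten(args)
    leaves = [
        leaf.clone() if isinstance(leaf, torch.Tensor) and i in mutated else leaf
        for i, leaf in enumerate(leaves)
    ]
    return tree_unflatten(leaves, spec)
```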

I also added a unit test to cover it. Looking forward to your feedback!

@meta-cla meta-cla bot added the CLA Signed label on Dec 15, 2025
@meta-cla

meta-cla bot commented Dec 15, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
