[autotuner] filter and clone only mutated tensors #1252
base: main
Conversation
Signed-off-by: Burkhard Ringlein <[email protected]>
Hi @bringlein! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
yf225 left a comment:
Thanks for the PR! Agreed we should do this. Some inline comments:
helion/autotuner/base_search.py
Outdated
```diff
 def _clone_leaf(leaf: object) -> object:
-    if isinstance(leaf, torch.Tensor):
+    if isinstance(leaf, torch.Tensor) and (
+        all_tensors or leaf.data_ptr() in self._mutated_tensor_args
```
.data_ptr() would work for Nvidia / AMD GPUs, but I worry that not all hardware has this attribute well-defined. Using the "index of the tensor in the arg list" is a better differentiator, I think.
Also, since individual arguments in args can themselves be nested containers of torch.Tensor, we might need to use tree_map, like:
```python
from torch.utils._pytree import tree_map_only

count = [0]
tree_map_only(torch.Tensor, lambda t: count.__setitem__(0, count[0] + 1) or t, args)
```
Would be great to add unit tests to cover these mutation cases as well.
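For context, a minimal self-contained sketch of the counting idiom suggested above (editorial addition, not code from this PR; the nested `args` are made up) showing how each tensor leaf in a nested argument structure can be assigned a stable index:

```python
import torch
from torch.utils._pytree import tree_map_only

# Made-up nested kernel arguments: a plain tensor, a dict holding a tensor, and a scalar.
args = (torch.zeros(4), {"cache": torch.ones(2, 2)}, 3.14)

tensor_index = {}  # id(tensor) -> position among the tensor leaves
counter = [0]

def _assign_index(t: torch.Tensor) -> torch.Tensor:
    tensor_index[id(t)] = counter[0]
    counter[0] += 1
    return t

# tree_map_only walks the nested structure and applies the function to tensor leaves only,
# so the scalar 3.14 is skipped and the two tensors get indices 0 and 1.
tree_map_only(torch.Tensor, _assign_index, args)
```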
> I worry that not all hardware has this attribute well-defined.

Good hint, I hadn't thought about this.
Regarding indexes: I did not want to use the index because of my experience with the Triton autotuner, where the compiler at some point removes constant arguments (because they are bound to the kernel), so the indexes become shifted.
But if I see this correctly, base_search never handles bound kernels, correct? So this risk doesn't exist here?
Yeah I don't think we ever dynamically remove arguments or dynamically modify the # of tensors in an argument of a Helion kernel function, so it should be safe to use indices.
OK, sounds like the index of the tensors would then be the best option; I'll do it.
Thanks @yf225 for the hint regarding tree_map. I learned something while trying to implement it with this function, but in the end I struggled to compare the old and new args using only the tree map.
And if I use tree_flatten for the initial comparison, I'm not sure I can trust that the indexes created by tree_map are always the same. So I now use tree_flatten and tree_unflatten in the proposed solution.
I also added a unit test to cover it. Looking forward to your feedback!
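For readers following along, a minimal sketch of the flatten/clone/unflatten pattern described above (editorial addition, not the PR's actual code; the helper name `clone_mutated_args` and the way `mutated_indices` is obtained are assumptions):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def clone_mutated_args(args, mutated_indices):
    # Flatten the nested argument structure into a list of leaves plus a spec,
    # clone only the tensor leaves at mutated positions, then rebuild the structure.
    leaves, spec = tree_flatten(args)
    cloned = [
        leaf.clone()
        if i in mutated_indices and isinstance(leaf, torch.Tensor)
        else leaf
        for i, leaf in enumerate(leaves)
    ]
    return tree_unflatten(cloned, spec)

# Hypothetical usage: only the cache tensor (flattened index 1) is known to be mutated.
args = (torch.zeros(4), {"cache": torch.ones(2, 2)})
fresh_args = clone_mutated_args(args, mutated_indices={1})
```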
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Context
The autotuner detects whether a kernel mutates any of its arguments and, if so, clones all of them for each autotuning run.
However, some kernels may modify only one tensor, and still all tensors are cloned. When autotuning inside applications with larger tensors (e.g. a KV cache inside an inference server), this repeatedly leads to OOMs.
Hence I propose to clone only the tensors that are actually mutated (and use tensor.data_ptr() as the best available identifier?).
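As an illustration of the idea only (hypothetical helper, not the code in this PR; how `mutated_ptrs` is collected is an assumption), the filtering could look like:

```python
import torch

def clone_only_mutated(tensors, mutated_ptrs):
    # Clone just the tensors whose data_ptr() was recorded as mutated;
    # pass everything else through unchanged.
    return [
        t.clone() if isinstance(t, torch.Tensor) and t.data_ptr() in mutated_ptrs else t
        for t in tensors
    ]

# Hypothetical usage: only the large KV-cache-like tensor is cloned per autotuning run.
x = torch.randn(8)
kv_cache = torch.zeros(1024, 1024)
fresh = clone_only_mutated([x, kv_cache], mutated_ptrs={kv_cache.data_ptr()})
```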
Test
Please advise if I should run any other test.