Skip to content

Conversation

@shiyang-weng
Copy link
Contributor

Fuse dq + embeddingbag => _scaled_embedding_bag
Blocked by #3406

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3463

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 8, 2025
@shiyang-weng shiyang-weng marked this pull request as draft December 8, 2025 01:19
Copy link
Collaborator

@Xia-Weiwen Xia-Weiwen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

)
return res

def _q(self, x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _q(self, x):
def _quantize(self, x):

def test_fp8_scaled_embedding_bag(self):
dtype = torch.float8_e4m3fn

def _test_scaled_embedding_bag_helper(self, dtype, with_quant=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _test_scaled_embedding_bag_helper(self, dtype, with_quant=False):
def _test_scaled_embedding_bag_helper(self, dtype, with_output_quant=False):

counters["inductor"]["scaled_embedding_bag_matcher_nodes"] += len(match.nodes)
counter_name = "scaled_embedding_bag"
if "o_dtype" in kwargs:
counter_name += "_with_quant"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
counter_name += "_with_quant"
counter_name += "_with_output_quant"

def matcher_check_fn():
counter_name = "scaled_embedding_bag"
if with_quant:
counter_name += "_with_quant"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
counter_name += "_with_quant"
counter_name += "_with_output_quant"

self.weight_scale = 2.0
self.output_scale = 3.0

def _dq(self, weight):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _dq(self, weight):
def _dequantize(self, weight):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants