Commit e17c285

ethansfng authored and facebook-github-bot committed
Generalize quantized input wrapper (#16202)
Summary: Adds dequant nodes to support int8 inputs, when specified.

Differential Revision: D88810482
1 parent 5033840 commit e17c285

File tree

2 files changed: +116 -5 lines changed

backends/cadence/aot/compiler.py

Lines changed: 5 additions & 3 deletions
@@ -14,6 +14,7 @@
 import torch
 from executorch.backends.cadence.aot.compiler_funcs import (
     prepare as prepare_fn,
+    QuantizedInputWrapper,
     trace as trace_fn,
 )
 from executorch.backends.cadence.aot.memory_planning import (
@@ -39,12 +40,10 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.program._program import _transform, to_edge
-
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
 
 from .passes import apply_exir_ops_passes, apply_torch_ops_passes
-
 from .utils import print_ops_info
 
 default_quantizer = CadenceDefaultQuantizer()
@@ -204,6 +203,7 @@ def quantize_pt2(
     quantizer: Optional[CadenceQuantizer] = None,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
+    quant_input_args: Optional[list[str]] = None,
 ) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
@@ -226,9 +226,11 @@ def quantize_pt2(
         calibration_data=calibration_data,
         dump_graphs=dump_graphs,
     )
+    # Wrap the model to handle quantized inputs
+    wrapped_module = QuantizedInputWrapper(converted_gm, quant_input_args).module
 
     # Apply quant fusion to the exported program
-    program = torch.export.export(converted_gm, inputs, strict=True)
+    program = torch.export.export(wrapped_module, inputs, strict=True)
     fused_program = apply_pre_edge_transform_passes(program, quantizer)
 
     if dump_graphs:

backends/cadence/aot/compiler_funcs.py

Lines changed: 111 additions & 2 deletions
@@ -6,14 +6,18 @@
 
 # pyre-strict
 
-
-from typing import Optional
+import logging
+from typing import Any, Optional, Union
 
 import torch
 from torch._inductor.decomposition import remove_decompositions
+from torch.fx import GraphModule
 from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e
 from torchao.quantization.pt2e.quantizer import Quantizer
 
+logger: logging.Logger = logging.getLogger(__name__)
+QuantArgs = tuple[float, int, int, int, torch.dtype]
+
 
 @torch.no_grad()
 def trace(
@@ -52,3 +56,108 @@ def prepare(
     prepared_model = prepare_pt2e(traced_model, quantizer)
 
     return prepared_model
+
+
+def extract_input_quant_params_from_graph(
+    module: GraphModule,
+    input_names: list[str],
+) -> dict[int, QuantArgs]:
+    """
+    Extract quantization parameters from the FX graph for model inputs.
+    """
+    quant_args: dict[int, QuantArgs] = {}
+    found_names: set[str] = set()
+
+    if not input_names:
+        return quant_args
+
+    for idx, name in enumerate(input_names):
+        for node in module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if (
+                node.args
+                and isinstance(node.args[0], torch.fx.Node)
+                and node.args[0].name == name
+                and not node.name.startswith("_assert_tensor_metadata")
+                and "quantize_per_tensor" in str(node.target)
+            ):
+                args = node.args[1:]
+                if len(args) >= 5:
+                    quant_args[idx] = (
+                        float(args[0]),  # scale
+                        int(args[1]),  # zero_point
+                        int(args[2]),  # qmin
+                        int(args[3]),  # qmax
+                        args[4],  # dtype
+                    )
+                    found_names.add(name)
+                    break
+
+    missing_names = set(input_names) - found_names
+    if missing_names:
+        raise ValueError(
+            f"Could not find quantization parameters for input(s): {sorted(missing_names)}. "
+            "Make sure these input names exist in the graph and have quantization parameters."
+        )
+
+    return quant_args
+
+
+class QuantizedInputWrapper(torch.nn.Module):
+    """
+    Wrapper that allows a quantized model to accept quantized inputs.
+
+    If input_args is not provided, the wrapper passes inputs through
+    unchanged (no dequantization).
+
+    Args:
+        module: The quantized GraphModule to wrap.
+        input_args: Either a list of input placeholder names in the graph,
+            in which case quant params are extracted from the graph, or a
+            dict mapping input index to (scale, zero_point, qmin, qmax,
+            dtype), which is used directly.
+
+    Example:
+        # Extract from graph
+        wrapper = QuantizedInputWrapper(quantized_module, input_args=["x"])
+
+        # Explicit quant args
+        wrapper = QuantizedInputWrapper(
+            quantized_module,
+            input_args={0: (1 / 255, 0, 0, 255, torch.uint8)},
+        )
+    """
+
+    def __init__(
+        self,
+        module: GraphModule,
+        input_args: Optional[Union[list[str], dict[int, QuantArgs]]] = None,
+    ) -> None:
+        super().__init__()
+        self.module: GraphModule = module
+        self.quant_args: dict[int, QuantArgs] = {}
+
+        if input_args is not None:
+            logger.warning(
+                "Using pre-quantized inputs. This should only be done when calibration has been confirmed. "
+                "Incorrect quantization parameters can lead to significant accuracy degradation."
+            )
+            if isinstance(input_args, list):
+                self.quant_args = extract_input_quant_params_from_graph(module, input_args)
+            elif isinstance(input_args, dict):
+                self.quant_args = input_args
+
+    def forward(self, *args: torch.Tensor) -> Any:
+        """Run inference, dequantizing configured inputs."""
+        dequantized_args = []
+        for index, tensor in enumerate(args):
+            if index in self.quant_args:
+                scale, zp, qmin, qmax, dtype = self.quant_args[index]
+                tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                    tensor, scale, zp, qmin, qmax, dtype
+                )
+            dequantized_args.append(tensor)
+
+        return self.module(*dequantized_args)
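A minimal standalone sketch of the wrapper with explicit quant args; the traced module and the uint8 scale/zero-point values below are illustrative assumptions, not part of this commit:

    import torch

    from executorch.backends.cadence.aot.compiler_funcs import QuantizedInputWrapper


    class Doubler(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2.0


    # Any GraphModule works here; symbolic_trace stands in for a converted model.
    gm = torch.fx.symbolic_trace(Doubler())

    # Input 0 arrives as uint8 with scale 1/255 and zero point 0 (illustrative).
    wrapper = QuantizedInputWrapper(gm, input_args={0: (1 / 255, 0, 0, 255, torch.uint8)})

    # Quantize a float tensor with matching params, then feed the raw uint8
    # tensor to the wrapper; forward() dequantizes it before running the module.
    x_q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        torch.rand(1, 16), 1 / 255, 0, 0, 255, torch.uint8
    )
    out = wrapper(x_q)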
