diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index d19abf222b8..5770b05ad1e 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -14,6 +14,7 @@
 import torch
 from executorch.backends.cadence.aot.compiler_funcs import (
     prepare as prepare_fn,
+    QuantizedInputWrapper,
     trace as trace_fn,
 )
 from executorch.backends.cadence.aot.memory_planning import (
@@ -39,12 +40,10 @@
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.program._program import _transform, to_edge
-
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
 
 from .passes import apply_exir_ops_passes, apply_torch_ops_passes
-
 from .utils import print_ops_info
 
 default_quantizer = CadenceDefaultQuantizer()
@@ -204,6 +203,7 @@ def quantize_pt2(
     quantizer: Optional[CadenceQuantizer] = None,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
+    quant_input_args: Optional[list[str]] = None,
 ) -> ExportedProgram:
     """
     Trace, prepare, convert and fuse the model using the given quantizer.
@@ -226,9 +226,11 @@ def quantize_pt2(
         calibration_data=calibration_data,
         dump_graphs=dump_graphs,
    )
+    # Wrap the model to handle quantized inputs; export the wrapper itself so
+    # its dequantizing forward() is captured in the exported program.
+    wrapped_module = QuantizedInputWrapper(converted_gm, quant_input_args)
 
     # Apply quant fusion to the exported program
-    program = torch.export.export(converted_gm, inputs, strict=True)
+    program = torch.export.export(wrapped_module, inputs, strict=True)
     fused_program = apply_pre_edge_transform_passes(program, quantizer)
 
     if dump_graphs:
diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py
index 6ff6057255c..9756602ad2d 100644
--- a/backends/cadence/aot/compiler_funcs.py
+++ b/backends/cadence/aot/compiler_funcs.py
@@ -6,14 +6,18 @@
 
 # pyre-strict
 
-
-from typing import Optional
+import logging
+from typing import Any, Optional, Union
 
 import torch
 from torch._inductor.decomposition import remove_decompositions
+from torch.fx import GraphModule
 from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e
 from torchao.quantization.pt2e.quantizer import Quantizer
 
+logger: logging.Logger = logging.getLogger(__name__)
+QuantArgs = tuple[float, int, int, int, torch.dtype]
+
 
 @torch.no_grad()
 def trace(
@@ -52,3 +56,108 @@ def prepare(
     prepared_model = prepare_pt2e(traced_model, quantizer)
 
     return prepared_model
+
+
+def extract_input_quant_params_from_graph(
+    module: GraphModule,
+    input_names: list[str],
+) -> dict[int, QuantArgs]:
+    """
+    Extract quantization parameters from the FX graph for model inputs.
+ """ + quant_args: dict[int, QuantArgs] = {} + found_names: set[str] = set() + + if not input_names: + return quant_args + + for idx, name in enumerate(input_names): + for node in module.graph.nodes: + if node.op != "call_function": + continue + + if ( + node.args + and isinstance(node.args[0], torch.fx.Node) + and node.args[0].name == name + and not node.name.startswith("_assert_tensor_metadata") + and "quantize_per_tensor" in str(node.target) + ): + args = node.args[1:] + if len(args) >= 5: + quant_args[idx] = ( + float(args[0]), # scale + int(args[1]), # zero_point + int(args[2]), # qmin + int(args[3]), # qmax + args[4], # dtype + ) + found_names.add(name) + break + + missing_names = set(input_names) - found_names + if missing_names: + raise ValueError( + f"Could not find quantization parameters for input(s): {sorted(missing_names)}. " + f"Make sure these input names exist in the graph and quantization parameters." + ) + + return quant_args + + +class QuantizedInputWrapper(torch.nn.Module): + """ + Wrapper that allows a quantized model to accept quantized inputs. + + If no input_names or quant_args are provided, the wrapper passes inputs + through unchanged (no dequantization). + + Args: + module: The quantized GraphModule to wrap. + input_names: Optional list of input placeholder names in the graph. + If provided, extracts quant params from graph. + quant_args: Optional dict mapping input index to (scale, zero_point, qmin, qmax, dtype). + If provided, uses these directly instead of extracting from graph. + + Example: + # Extract from graph + wrapper = QuantizedInputWrapper(quantized_module, input_names=["x"]) + + # Explicit quant args + wrapper = QuantizedInputWrapper( + quantized_module, + quant_args={0: (1/255, 0, 0, 255, torch.uint8)}, + ) + """ + + def __init__( + self, + module: GraphModule, + input_args: Optional[Union[list[str], dict[int, QuantArgs]]] = None, + ) -> None: + super().__init__() + self.module: GraphModule = module + self.quant_args: dict[int, QuantArgs] = {} + + if input_args is not None: + logger.warning( + "Warning: Using pre-quantized inputs. This should only be done when calibration has been confirmed." + "Incorrect quantization parameters can lead to significant accuracy degradation." + ) + if isinstance(input_args, list): + self.quant_args = extract_input_quant_params_from_graph(module, input_args) + elif isinstance(input_args, dict): + self.quant_args = input_args + + def forward(self, *args: torch.Tensor) -> Any: + """Run inference, dequantizing configured inputs.""" + dequantized_args = [] + for index, node in enumerate(args): + if index in self.quant_args: + scale, zp, qmin, qmax, dtype = self.quant_args[index] + node = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + node, scale, zp, qmin, qmax, dtype + ) + dequantized_args.append(node) + + return self.module(*dequantized_args)