 
 # pyre-strict
 
-
-from typing import Optional
+import logging
+from typing import Any, Optional, Union
 
 import torch
 from torch._inductor.decomposition import remove_decompositions
+from torch.fx import GraphModule
 from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e
 from torchao.quantization.pt2e.quantizer import Quantizer
 
+logger: logging.Logger = logging.getLogger(__name__)
+QuantArgs = tuple[float, int, int, int, torch.dtype]
+
 
 @torch.no_grad()
 def trace(
@@ -52,3 +56,108 @@ def prepare( |
     prepared_model = prepare_pt2e(traced_model, quantizer)
 
     return prepared_model
+
+
+def extract_input_quant_params_from_graph(
+    module: GraphModule,
+    input_names: list[str],
+) -> dict[int, QuantArgs]:
+    """
+    Extract quantization parameters from the FX graph for model inputs.
+    """
+    quant_args: dict[int, QuantArgs] = {}
+    found_names: set[str] = set()
+
+    if not input_names:
+        return quant_args
+
+    for idx, name in enumerate(input_names):
+        for node in module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if (
+                node.args
+                and isinstance(node.args[0], torch.fx.Node)
+                and node.args[0].name == name
+                and not node.name.startswith("_assert_tensor_metadata")
+                and "quantize_per_tensor" in str(node.target)
+            ):
+                args = node.args[1:]
+                if len(args) >= 5:
+                    quant_args[idx] = (
+                        float(args[0]),  # scale
+                        int(args[1]),  # zero_point
+                        int(args[2]),  # qmin
+                        int(args[3]),  # qmax
+                        args[4],  # dtype
+                    )
+                    found_names.add(name)
+                break
+    missing_names = set(input_names) - found_names
+    if missing_names:
+        raise ValueError(
+            f"Could not find quantization parameters for input(s): {sorted(missing_names)}. "
+            "Make sure these input names exist in the graph and are consumed by "
+            "quantize_per_tensor nodes."
+        )
+
+    return quant_args
+
+
+class QuantizedInputWrapper(torch.nn.Module):
+    """
+    Wrapper that allows a quantized model to accept quantized inputs.
+
+    If no input_args are provided, the wrapper passes inputs through
+    unchanged (no dequantization).
+
+    Args:
+        module: The quantized GraphModule to wrap.
+        input_args: Either a list of input placeholder names in the graph,
+            in which case quant params are extracted from the graph, or a
+            dict mapping input index to (scale, zero_point, qmin, qmax, dtype),
+            which is used directly.
+
+    Example:
+        # Extract from graph
+        wrapper = QuantizedInputWrapper(quantized_module, input_args=["x"])
+
+        # Explicit quant args
+        wrapper = QuantizedInputWrapper(
+            quantized_module,
+            input_args={0: (1 / 255, 0, 0, 255, torch.uint8)},
+        )
+    """
+
+    def __init__(
+        self,
+        module: GraphModule,
+        input_args: Optional[Union[list[str], dict[int, QuantArgs]]] = None,
+    ) -> None:
+        super().__init__()
+        self.module: GraphModule = module
+        self.quant_args: dict[int, QuantArgs] = {}
+
+        if input_args is not None:
+            logger.warning(
+                "Using pre-quantized inputs. This should only be done when "
+                "calibration has been confirmed. Incorrect quantization "
+                "parameters can lead to significant accuracy degradation."
+            )
+            if isinstance(input_args, list):
+                self.quant_args = extract_input_quant_params_from_graph(
+                    module, input_args
+                )
+            elif isinstance(input_args, dict):
+                self.quant_args = input_args
+
+    def forward(self, *args: torch.Tensor) -> Any:
+        """Run inference, dequantizing configured inputs."""
+        dequantized_args = []
+        for index, tensor in enumerate(args):
+            if index in self.quant_args:
+                scale, zp, qmin, qmax, dtype = self.quant_args[index]
+                tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                    tensor, scale, zp, qmin, qmax, dtype
+                )
+            dequantized_args.append(tensor)
+
+        return self.module(*dequantized_args)
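
A minimal usage sketch of the new wrapper, for context. Names here are illustrative assumptions, not part of the change: `quantized_module` stands in for a GraphModule produced by this file's `prepare()` flow after conversion, and `"x"` for its input placeholder name.

```python
import torch

# Option 1: extract (scale, zero_point, qmin, qmax, dtype) from the graph
# by input placeholder name. "x" is an assumed placeholder name.
wrapper = QuantizedInputWrapper(quantized_module, input_args=["x"])

# Option 2: supply explicit per-input quant params keyed by positional index.
wrapper = QuantizedInputWrapper(
    quantized_module,
    input_args={0: (1.0 / 255, 0, 0, 255, torch.uint8)},
)

# The wrapper dequantizes configured inputs before calling the module, so
# callers may pass already-quantized uint8 tensors directly; inputs without
# an entry in quant_args pass through unchanged.
x_q = torch.randint(0, 256, (1, 3, 224, 224), dtype=torch.uint8)
out = wrapper(x_q)
```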