diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 9bdb3871a2..87bb2ba1f6 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -1231,6 +1231,7 @@ def choose_qparams_affine(
         eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
         scale_dtype (torch.dtype): dtype for scale Tensor
         zero_point_dtype (torch.dtype): dtype for zero_point Tensor, defaults to torch.int32
+        keepdim (bool): if True, scale/zero_point keep the same rank as the input, with one entry per block along each dimension (aligned with _choose_scale_float8)
     Now removed params:
         zero_point_domain (ZeroPointDomain): the domain that zero_point is in, defaults to Integer or None
         preserve_zero (bool): whether to preserve zero in the quantized Tensor, defaults to True
@@ -1532,6 +1533,10 @@ def _choose_qparams_affine(
    2. find min_val/max_val based on the dimension for reduction
    3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
       and `zero_point_domain`
+
+    Note:
+        keepdim defaults to True to align with _choose_scale_float8 behavior. This ensures
+        scale/zero_point keep the same rank as the input, which simplifies downstream handling.
    """
    quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
    assert mapping_type in [
@@ -1548,6 +1553,8 @@ def _choose_qparams_affine(
    assert len(block_size) == input.dim(), (
        f"Got input dim:{input.dim()}, block_size: {block_size}"
    )
+    # Save the original input size before reshaping; used below to compute the output shape
+    original_input_size = input.size()
    shape_for_reduction, reduction_dims = _get_reduction_params(
        block_size, input.size()
    )
@@ -1591,6 +1598,15 @@ def _choose_qparams_affine(
    if zero_point_dtype is None:
        zero_point_dtype = torch.int32

+    # Reshape scale and zero_point so each dimension holds the number of blocks
+    # along that dimension; this aligns with _choose_scale_float8 behavior
+    if keepdim:
+        output_shape = [
+            original_input_size[i] // block_size[i] for i in range(len(block_size))
+        ]
+        scale = scale.reshape(output_shape)
+        zero_point = zero_point.reshape(output_shape)
+
    return scale.to(dtype=scale_dtype, device=input.device), zero_point.to(
        dtype=zero_point_dtype
    )
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index 88ad165ecf..05672b406b 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -229,6 +229,7 @@ def from_hp(
            quant_min=qmin,
            quant_max=qmax,
            zero_point_dtype=torch.int8,
+            keepdim=True,  # scale/zero_point come back already reshaped to the block structure
        )
        qdata = quantize_affine(
            hp_tensor,
@@ -244,14 +245,8 @@ def from_hp(
            f"Unsupported IntxChooseQParamsAlgorithm: {intx_choose_qparams_algorithm}"
        )

-    # Reshape scale and zero_point to be compatible with block_size
-    # This is asserted in IntxUnpackedToInt8Tensor's __init__
-    n_blocks = []
-    for i in range(len(block_size)):
-        assert qdata.shape[i] % block_size[i] == 0
-        n_blocks.append(qdata.shape[i] // block_size[i])
-    scale = scale.reshape(*n_blocks)
-    zero_point = zero_point.reshape(*n_blocks)
+    # scale and zero_point already have the correct shape: choose_qparams_affine
+    # with keepdim=True reshapes them to match the block_size expectations

    return IntxUnpackedToInt8Tensor(
        qdata=qdata,
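
For reviewers, here is a minimal sketch of what the `keepdim=True` behavior produces. The tensor shape and group size below are illustrative only; the `keepdim` parameter is the one added by this patch, and the rest of the call follows the existing `choose_qparams_affine` signature:

```python
import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
)

# Illustrative 2D weight, quantized per-group with group size 32 on the last dim.
weight = torch.randn(64, 128)
block_size = (1, 32)  # one block per row, 128 / 32 = 4 blocks along the columns

scale, zero_point = choose_qparams_affine(
    weight,
    MappingType.ASYMMETRIC,
    block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    zero_point_dtype=torch.int8,
    keepdim=True,  # added by this patch
)

# With keepdim=True, scale/zero_point keep the input's rank and hold one
# entry per block: (64 // 1, 128 // 32) -> (64, 4). Callers such as
# IntxUnpackedToInt8Tensor.from_hp no longer need to reshape them manually.
assert scale.shape == (64, 4)
assert zero_point.shape == (64, 4)
```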