diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 9bdb3871a2..c94794f69f 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1553,8 +1553,12 @@ def _choose_qparams_affine( ) input = input.view(shape_for_reduction) - min_val = torch.amin(input, dim=reduction_dims, keepdim=keepdim) - max_val = torch.amax(input, dim=reduction_dims, keepdim=keepdim) + if reduction_dims: + min_val = torch.amin(input, dim=reduction_dims, keepdim=keepdim) + max_val = torch.amax(input, dim=reduction_dims, keepdim=keepdim) + else: + min_val = input + max_val = input min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val))