From 50994a520ee1381d4193bb43f63e6db3c46fb93c Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Wed, 10 Sep 2025 12:55:06 -0700
Subject: [PATCH 1/2] [ExecuTorch] XNNPACK: prefer qc over qb when gs == k for
 non-int4

* Prefer channelwise over groupwise quantization when possible, both for perf and for int8, which doesn't have groupwise support
* Fix bug / improve behavior for affine q/dq with gs == k for per_channel
* Refactor the is_per_channel_group state variable (now per_channel_group)
* Add QuantParams.__str__()

TODO - improve affine quant primitives - T237476295

Differential Revision: [D82060758](https://our.internmc.facebook.com/intern/diff/D82060758/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D82060758/)!

[ghstack-poisoned]
---
 backends/xnnpack/operators/node_visitor.py |  8 ++--
 backends/xnnpack/operators/quant_params.py | 54 +++++++++++++++++++---
 2 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index 6a055c9413f..68226644859 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -232,7 +232,7 @@ def get_per_channel_dtype(
     if quant_params.dtype == torch.int32:
         return XNNDatatype.xnn_datatype_qcint32
     elif quant_params.dtype == torch.int8:
-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             # 4-bit per channel group quantized weights
             # No 8-bit support yet
             assert (
@@ -282,7 +282,7 @@ def get_quant_params(
         buffer_idx = len(xnn_graph.constant_data)
         num_scales = scale.numel()

-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             scale = scale.to(torch.bfloat16)

         num_bytes = scale.untyped_storage().nbytes()
@@ -300,7 +300,7 @@ def get_quant_params(
             scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
         )

-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             return PerChannelGroupQuant(
                 scale=[],
                 channel_dim=quant_params.axis,
@@ -335,7 +335,7 @@ def _check_per_channel_group_params(
 ) -> None:
     # Make sure things are lining up for per_channel_group quantization case
     # Has to be done this late because we don't have clean access to the actual tensor
-    assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
+    assert quant_params.per_channel_group, "Not per_channel_group quantization"
     # linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0
     num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
     assert (
diff --git a/backends/xnnpack/operators/quant_params.py b/backends/xnnpack/operators/quant_params.py
index 88a1f660f0e..ab7141e3ddd 100644
--- a/backends/xnnpack/operators/quant_params.py
+++ b/backends/xnnpack/operators/quant_params.py
@@ -89,19 +89,36 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
+            assert (
+                self.axis == 0, "Only axis 0 is supported for per channel groupwise quant"
+            )
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
-            assert group_size > 0, "Group size must be greater than 0"
-        self.is_per_channel_group = self.per_channel and self.group_size > 0
-
-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+            # Assumed scale shape - [out_channels, in_channels/group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2d weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels, "Invalid input channels for groupwise quant"
+            )
+            # Prefer per_channel over per_channel_group when group_size == input_channels, for non-int4 cases only
+            # The int4 case needs more fixes to map qb4w to qc4w; incorrect scales are being passed down to XNNPACK.
+            self.per_channel_group = self.group_size <= input_channels if self.is_qc4w else self.group_size < input_channels
+
+        if not self.per_channel_group:
+            if cast(torch.Tensor, scale).ndim == 2:
+                # TODO: don't reshape scale for per_channel cases
+                assert (cast(torch.Tensor, scale).shape[1] == 1), "Invalid scale shape for per channel quantization"
+                scale = cast(torch.Tensor, scale).squeeze(1)
+
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -110,6 +127,31 @@ def __init__(
             tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
         ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"

+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = f"{self.scale}" if isinstance(self.scale, float) else f"tensor{tuple(self.scale.shape)}"
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = f"{self.zp}" if isinstance(self.zp, float) else f"tensor{tuple(self.zp.shape)}"
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
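To make the selection rule above concrete: with a 2D weight of shape [oc, ic] and 2D scales of shape [oc, ic/group_size], a group size equal to ic means exactly one group per output channel, so groupwise quantization degenerates to channelwise. The minimal standalone sketch below (not part of the patch; `choose_layout` is a hypothetical helper) mirrors the `per_channel_group` condition introduced in quant_params.py:

```python
import torch

def choose_layout(weight: torch.Tensor, scale: torch.Tensor,
                  group_size: int, is_qc4w: bool) -> str:
    # weight: [out_channels, in_channels]
    # scale:  [out_channels, in_channels // group_size]
    out_channels, in_channels = weight.shape
    assert scale.shape == (out_channels, in_channels // group_size)
    if is_qc4w:
        # int4 (qb4w) stays groupwise even when group_size == in_channels,
        # until the qb4w -> qc4w scale mapping is fixed
        return "per_channel_group" if group_size <= in_channels else "per_channel"
    # non-int4: gs == k collapses to the cheaper channelwise path
    return "per_channel_group" if group_size < in_channels else "per_channel"

w = torch.randn(8, 64)  # [oc, ic]
s = torch.ones(8, 1)    # gs == ic -> one group per output channel
print(choose_layout(w, s, group_size=64, is_qc4w=False))  # per_channel
print(choose_layout(w, s, group_size=64, is_qc4w=True))   # per_channel_group
```

For int8 this is not only a perf win: per the commit message, XNNPACK has no 8-bit groupwise support, so collapsing gs == k to channelwise is what makes that configuration representable at all.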
From e6f6047353699618eae94195ad75ca4a044d4a85 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Wed, 10 Sep 2025 13:00:55 -0700
Subject: [PATCH 2/2] Update on "[ExecuTorch] XNNPACK: prefer qc over qb when
 gs == k for non-int4"

* Prefer channelwise over groupwise quantization when possible, both for perf and for int8, which doesn't have groupwise support
* Fix bug / improve behavior for affine q/dq with gs == k for per_channel
* Refactor the is_per_channel_group state variable (now per_channel_group)
* Add QuantParams.__str__()

TODO - improve affine quant primitives - T237476295

Differential Revision: [D82060758](https://our.internmc.facebook.com/intern/diff/D82060758/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D82060758/)!

[ghstack-poisoned]
---
 backends/xnnpack/operators/quant_params.py | 36 +++++++++++++++-------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/backends/xnnpack/operators/quant_params.py b/backends/xnnpack/operators/quant_params.py
index ab7141e3ddd..a2b7c555faa 100644
--- a/backends/xnnpack/operators/quant_params.py
+++ b/backends/xnnpack/operators/quant_params.py
@@ -89,15 +89,15 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
-        
+
         tensor = q_input.meta["val"]
-        
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
             assert (
-                self.axis == 0, "Only axis 0 is supported for per channel groupwise quant"
-            )
+                self.axis == 0
+            ), "Only axis 0 is supported for per channel groupwise quant"
             assert (
                 cast(torch.Tensor, scale).ndim == 2
@@ -106,16 +106,22 @@ def __init__(
             input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
             # 2d weight tensor shape - [out_channels, in_channels]
             assert (
-                tensor.shape[1] == input_channels, "Invalid input channels for groupwise quant"
-            )
+                tensor.shape[1] == input_channels
+            ), "Invalid input channels for groupwise quant"
             # Prefer per_channel over per_channel_group when group_size == input_channels, for non-int4 cases only
             # The int4 case needs more fixes to map qb4w to qc4w; incorrect scales are being passed down to XNNPACK.
-            self.per_channel_group = self.group_size <= input_channels if self.is_qc4w else self.group_size < input_channels
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )

         if not self.per_channel_group:
             if cast(torch.Tensor, scale).ndim == 2:
                 # TODO: don't reshape scale for per_channel cases
-                assert (cast(torch.Tensor, scale).shape[1] == 1), "Invalid scale shape for per channel quantization"
+                assert (
+                    cast(torch.Tensor, scale).shape[1] == 1
+                ), "Invalid scale shape for per channel quantization"
                 scale = cast(torch.Tensor, scale).squeeze(1)

         if per_channel and not self.per_channel_group:
@@ -130,10 +136,18 @@ def __init__(
     def __str__(self) -> str:
         """String representation of QuantParams for debugging and logging."""
         assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
-        scale_str = f"{self.scale}" if isinstance(self.scale, float) else f"tensor{tuple(self.scale.shape)}"
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
         assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
-        zp_str = f"{self.zp}" if isinstance(self.zp, float) else f"tensor{tuple(self.zp.shape)}"
-        
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
         return (
             f"QuantParams("
             f"per_channel={self.per_channel}, "
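The fallback path both patches touch is easiest to see in isolation. When affine groupwise q/dq is configured with gs == k, the scales arrive as a 2D tensor of shape [oc, ic/gs] == [oc, 1], while the channelwise path expects 1D [oc] scales; the squeeze in `__init__` bridges the two. A standalone sketch of just that normalization (hypothetical helper name, assuming the shapes above):

```python
import torch

def normalize_per_channel_scale(scale: torch.Tensor) -> torch.Tensor:
    # Affine groupwise q/dq with group_size == in_channels produces scales of
    # shape [out_channels, 1]; the channelwise path wants a 1D [out_channels].
    if scale.ndim == 2:
        assert scale.shape[1] == 1, "Invalid scale shape for per channel quantization"
        scale = scale.squeeze(1)
    return scale

s = normalize_per_channel_scale(torch.rand(8, 1))
assert s.shape == (8,)
print(s.shape)  # torch.Size([8])
```

The in-code TODO suggests the longer-term fix is to stop reshaping scales for per_channel cases altogether rather than normalizing shapes at this point.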