Commit d7bc3bc

[ExecuTorch] XNNPACK: prefer qc over qb when gs == k for non-int4
Pull Request resolved: #14173

* Prefer channelwise over groupwise quantization when possible, both for performance and for int8, which has no groupwise support
* Fix a bug / improve behavior for affine q/dq with gs == k for the per_channel case
* Refactor the is_per_channel_group state variable
* Add QuantParams.__str__()

TODO: improve affine quant primitives - T237476295

ghstack-source-id: 309161539
@exported-using-ghexport
Differential Revision: [D82060758](https://our.internmc.facebook.com/intern/diff/D82060758/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D82060758/)!
1 parent 6b9c0a6 commit d7bc3bc
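
The preference makes sense because when the group size equals the number of input channels (gs == k), the groupwise scale tensor has exactly one column, i.e. one scale per output channel, which is what channelwise quantization already expresses. A minimal sketch of that equivalence with made-up shapes (illustrative only, not code from this commit):

```python
import torch

# Hypothetical linear weight: [out_channels, in_channels]
out_channels, in_channels = 8, 64
group_size = 64  # gs == k: a single group spans the whole input dimension

# Groupwise scales have shape [out_channels, in_channels // group_size].
# With gs == k that is [out_channels, 1] -- one scale per output channel.
groupwise_scale = torch.rand(out_channels, in_channels // group_size)

# The same information expressed as a per-channel (axis 0) scale vector.
per_channel_scale = groupwise_scale.squeeze(1)

assert per_channel_scale.shape == (out_channels,)
```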

File tree: 2 files changed (+67 / -9 lines)

backends/xnnpack/operators/node_visitor.py

Lines changed: 4 additions & 4 deletions
@@ -232,7 +232,7 @@ def get_per_channel_dtype(
         if quant_params.dtype == torch.int32:
             return XNNDatatype.xnn_datatype_qcint32
         elif quant_params.dtype == torch.int8:
-            if quant_params.is_per_channel_group:
+            if quant_params.per_channel_group:
                 # 4-bit per channel group quantized weights
                 # No 8-bit support yet
                 assert (
@@ -282,7 +282,7 @@ def get_quant_params(
         buffer_idx = len(xnn_graph.constant_data)
         num_scales = scale.numel()

-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             scale = scale.to(torch.bfloat16)

             num_bytes = scale.untyped_storage().nbytes()
@@ -300,7 +300,7 @@ def get_quant_params(
             scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
         )

-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             return PerChannelGroupQuant(
                 scale=[],
                 channel_dim=quant_params.axis,
@@ -335,7 +335,7 @@ def _check_per_channel_group_params(
     ) -> None:
         # Make sure things are lining up for per_channel_group quantization case
         # Has to be done this late because we don't have clean access to the actual tensor
-        assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
+        assert quant_params.per_channel_group, "Not per_channel_group quantization"
         # linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0
         num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
         assert (
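
For the groupwise path that remains, the context lines above show the scales being cast to bfloat16 and written into the graph's constant data as raw bytes. A rough standalone sketch of that packing step, assuming a hypothetical helper name (serialize_groupwise_scales) and made-up shapes:

```python
import torch

def serialize_groupwise_scales(scale: torch.Tensor) -> bytes:
    """Pack groupwise scales as raw bfloat16 bytes, mirroring scale.to(torch.bfloat16)
    and untyped_storage().nbytes() in the hunk above (sketch, not the ExecuTorch code)."""
    scale_bf16 = scale.to(torch.bfloat16).contiguous()
    # Reinterpret the bfloat16 data as bytes (2 bytes per scale value).
    raw = scale_bf16.view(torch.uint8)
    return raw.numpy().tobytes()

# Example: 8 output channels, 4 groups per channel.
scales = torch.rand(8, 4)
blob = serialize_groupwise_scales(scales)
assert len(blob) == scales.numel() * 2
```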

backends/xnnpack/operators/quant_params.py

Lines changed: 63 additions & 5 deletions
@@ -89,19 +89,44 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
+            assert (
+                self.axis == 0,
+                "Only axis 0 is supported for per channel groupwise quant",
+            )
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
-            assert group_size > 0, "Group size must be greater than 0"
-        self.is_per_channel_group = self.per_channel and self.group_size > 0
+            # Assumed scale shape - [out_channels, in_channels/group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2d weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels,
+                "Invalid input channels for groupwise quant",
+            )
+            # Prefer per_channel over per_channel_group when group_size == input_channels for non int4 cases only
+            # int4 case need more fixes to map qb4w to qc4w. Incorrect scales being passed down to xnnpack.
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )

-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+        if not self.per_channel_group:
+            if cast(torch.Tensor, scale).ndim == 2:
+                # TODO: don't reshape scale for per_channel cases
+                assert (
+                    cast(torch.Tensor, scale).shape[1] == 1
+                ), "Invalid scale shape for per channel quantization"
+                scale = cast(torch.Tensor, scale).squeeze(1)
+
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -110,6 +135,39 @@ def __init__(
                 tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
             ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"

+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
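
The added __str__ is aimed at debugging and logging; a hedged sketch of how it might be used at a hypothetical call site (the example output in the comment uses made-up field values):

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical call site inside an op visitor; `quant_params` is a QuantParams instance.
def log_weight_quant(quant_params) -> None:
    # %s formatting picks up the new __str__, producing something like (illustrative values):
    #   QuantParams(per_channel=True, per_channel_group=False, scale=tensor(8,), zp=tensor(8,),
    #   axis=0, dtype=torch.int8, qmin=-127, qmax=127, is_dynamic=False, is_input=False,
    #   is_output=False, group_size=64, is_qc4w=False)
    logger.debug("Lowering weight with %s", quant_params)
```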
