4
4
5
5
import torch
6
6
import torch .nn .functional as F
7
- from executorch .backends .xnnpack .utils .utils import is_depthwise_conv
7
+ from executorch .backends .xnnpack .utils .utils import (
8
+ get_groups_from_conv ,
9
+ is_depthwise_conv ,
10
+ )
8
11
from torch ._subclasses import FakeTensor
9
12
from torch .fx import Node
10
13
from torch .fx .passes .utils .matcher_with_name_node_map_utils import (
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
65
68
return decorator
66
69
67
70
71
+ def change_quantization_config (
72
+ original_qspec ,
73
+ dtype = None ,
74
+ quant_min = None ,
75
+ quant_max = None ,
76
+ qscheme = None ,
77
+ ch_axis = None ,
78
+ is_dynamic = None ,
79
+ observer_or_fake_quant_ctr = None ,
80
+ ):
81
+ return QuantizationSpec (
82
+ dtype = dtype or original_qspec .dtype ,
83
+ quant_min = quant_min or original_qspec .quant_min ,
84
+ quant_max = quant_max or original_qspec .quant_max ,
85
+ qscheme = qscheme or original_qspec .qscheme ,
86
+ ch_axis = ch_axis or original_qspec .ch_axis ,
87
+ is_dynamic = is_dynamic or original_qspec .is_dynamic ,
88
+ observer_or_fake_quant_ctr = observer_or_fake_quant_ctr
89
+ or original_qspec .observer_or_fake_quant_ctr ,
90
+ )
91
+
92
+
68
93
def is_relu_node (node : Node ) -> bool :
69
94
"""
70
95
Check if a given node is a relu node
@@ -231,31 +256,44 @@ def _do_annotate_conv(
231
256
if is_relu_node (user ):
232
257
continue
233
258
259
+ # Tracks conditions for whether or not to skip
260
+ skip = False
261
+
234
262
input_qspec_map = {}
235
263
input_act = conv_node .args [0 ]
236
264
assert isinstance (input_act , Node )
237
265
input_qspec_map [input_act ] = get_input_act_qspec (quantization_config )
238
266
239
267
weight = conv_node .args [1 ]
240
268
assert isinstance (weight , Node )
241
- input_qspec_map [weight ] = get_weight_qspec (quantization_config )
269
+ weight_qspec = get_weight_qspec (quantization_config )
270
+ num_groups = get_groups_from_conv (conv_node )
242
271
243
- # Only annotate dynamically quantized conv if it's 2D and not depthwise
244
- if (
272
+ # skip if transposed conv has more than 1 group
273
+ skip = skip or (is_conv_transpose and num_groups != 1 )
274
+ print (f"{ skip } conv transpose and num_groups" )
275
+
276
+ if is_conv_transpose :
277
+ # transposed convs per output channel quantization
278
+ weight_qspec = change_quantization_config (weight_qspec , ch_axis = 1 )
279
+
280
+ input_qspec_map [weight ] = weight_qspec
281
+ is_dynamic = (
245
282
quantization_config
246
283
and quantization_config .input_activation
247
284
and quantization_config .input_activation .is_dynamic
248
- ):
285
+ )
286
+
287
+ # Only annotate dynamically quantized conv if it's 2D and not depthwise
288
+ if is_dynamic :
249
289
weight_val = weight .meta .get ("val" , None )
250
290
weight_shape = getattr (weight_val , "shape" , None )
251
-
252
291
# Skip if not a 4D weight tensor (i.e. not conv2d)
253
- if weight_shape is not None and len (weight_shape ) != 4 :
254
- continue
255
-
292
+ skip = skip or (weight_shape is not None and len (weight_shape ) != 4 )
256
293
# Skip if depthwise (default to groups=1 since it's not an arg)
257
- if is_depthwise_conv (weight_shape , 1 , is_conv_transpose ):
258
- continue
294
+ skip = skip or (
295
+ not is_conv_transpose and is_depthwise_conv (weight_shape , 1 , False )
296
+ )
259
297
260
298
# adding weight node to the partition as well
261
299
partition = [conv_node , conv_node .args [1 ]]
@@ -265,7 +303,7 @@ def _do_annotate_conv(
265
303
input_qspec_map [bias ] = get_bias_qspec (quantization_config )
266
304
partition .append (bias )
267
305
268
- if _is_annotated (partition ):
306
+ if _is_annotated (partition ) or skip :
269
307
continue
270
308
271
309
if filter_fn and any (not filter_fn (n ) for n in partition ):
@@ -311,7 +349,12 @@ def _do_annotate_conv_relu(
311
349
312
350
weight = conv_node .args [1 ]
313
351
assert isinstance (weight , Node )
314
- input_qspec_map [weight ] = get_weight_qspec (quantization_config )
352
+ weight_qspec = get_weight_qspec (quantization_config )
353
+ groups = get_groups_from_conv (conv_node )
354
+ if is_conv_transpose :
355
+ # transposed convs per output channel quantization
356
+ weight_qspec = change_quantization_config (weight_qspec , ch_axis = 1 )
357
+ input_qspec_map [weight ] = weight_qspec
315
358
316
359
# adding weight node to the partition as well
317
360
partition = [relu_node , conv_node , conv_node .args [1 ]]
@@ -323,6 +366,9 @@ def _do_annotate_conv_relu(
323
366
if _is_annotated (partition ):
324
367
continue
325
368
369
+ if is_conv_transpose and groups != 1 :
370
+ continue
371
+
326
372
if filter_fn and any (not filter_fn (n ) for n in partition ):
327
373
continue
328
374
0 commit comments