@@ -33,6 +33,7 @@ def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = Non
33
33
34
34
def build_model (
35
35
self ,
36
+ op_type : str ,
36
37
input_shape : ir .Shape ,
37
38
weight_shape : typing .Sequence [int ],
38
39
pad_inputs : typing .Sequence [ir .TensorProtocol | ir .Value | None ],
@@ -57,14 +58,17 @@ def build_model(
57
58
raise ValueError (f"Unsupported type for pad input ({ x } ): { type (x )} ." )
58
59
59
60
# Register operations in the tape
60
- x = ir .Input ("X" , shape = input_shape , type = ir .TensorType (ir .DataType .FLOAT ))
61
+ idtype = ir .DataType .UINT8 if op_type == "ConvInteger" else ir .DataType .FLOAT
62
+ x = ir .Input ("X" , shape = input_shape , type = ir .TensorType (idtype ))
61
63
y = tape .op ("Pad" , inputs = [x , * pad_inputs ], attributes = pad_attributes )
62
64
y = tape .op (
63
- "Conv" ,
65
+ op_type ,
64
66
inputs = [y , self .get_conv_weights (weight_shape , tape )],
65
67
attributes = conv_attributes ,
66
68
output = ir .Input ("Y" , shape = output_shape , type = ir .TensorType (x .dtype )),
67
69
)
70
+ if op_type == "ConvInteger" :
71
+ y .dtype = ir .DataType .INT32
68
72
69
73
# Build the model
70
74
ir_model = ir .Model (
@@ -101,6 +105,7 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
101
105
if axes is not None :
102
106
pad_inputs .append (axes )
103
107
base_model = self .build_model (
108
+ op_type = "Conv" ,
104
109
input_shape = ir .Shape (("N" , 32 , 14 , 16 )),
105
110
weight_shape = (10 , 32 , 3 , 3 ),
106
111
pad_inputs = pad_inputs ,
@@ -190,6 +195,7 @@ def test_unsupported_fuse_pad_into_conv(
190
195
self , mode , pads , const_value , axes , auto_pad , err_msg
191
196
):
192
197
base_model = self .build_model (
198
+ op_type = "Conv" ,
193
199
input_shape = ir .Shape (("N" , 32 , 14 , 16 , 12 )),
194
200
weight_shape = (10 , 32 , 3 , 4 , 5 ),
195
201
pad_inputs = [pads , const_value , axes ],
@@ -208,5 +214,51 @@ def test_unsupported_fuse_pad_into_conv(
208
214
self .assertRegex (tracer_match .match_result .reason , err_msg )
209
215
210
216
217
+ class FusePadConvIntegerTest (FusePadConvBaseTest ):
218
+ def get_conv_weights (self , shape : typing .Sequence [int ], tape : ir .tape .Tape = None ):
219
+ w = ir .tensor (self .rng .integers (0 , 256 , shape ).astype ("uint8" ), name = "W" )
220
+ if tape is not None :
221
+ w = tape .initializer (w )
222
+ return w
223
+
224
+ @parameterized .parameterized .expand (
225
+ [
226
+ (pad_pads , const_value , axes , conv_pads )
227
+ for pad_pads , axes , conv_pads in [
228
+ ([0 , 0 , 3 , 2 , 0 , 0 , 1 , 4 ], None , [1 , 1 , 1 , 1 ]),
229
+ ([2 , 2 , 0 , 2 , 2 , 0 ], ir .tensor ([- 2 , - 1 , 1 ], name = "axes" ), None ),
230
+ ([1 , 2 , 2 , 1 ], ir .tensor ([- 1 , 2 ], name = "axes" ), [0 , 1 , 0 , 1 ]),
231
+ ]
232
+ for const_value in [None , ir .tensor (np .array ([0 ], "uint8" ), name = "const_value" )]
233
+ ]
234
+ )
235
+ def test_fuse_pad_into_conv_integer (self , pad_pads , const_value , axes , conv_pads ):
236
+ pad_inputs = [ir .tensor (pad_pads , name = "pads" )]
237
+ if const_value is not None or axes is not None :
238
+ pad_inputs .append (const_value )
239
+ if axes is not None :
240
+ pad_inputs .append (axes )
241
+ base_model = self .build_model (
242
+ op_type = "ConvInteger" ,
243
+ input_shape = ir .Shape (("N" , 24 , 19 , 23 )),
244
+ weight_shape = (8 , 24 , 3 , 3 ),
245
+ pad_inputs = pad_inputs ,
246
+ conv_attributes = {"pads" : conv_pads },
247
+ )
248
+ updated_model = _clone_model (base_model )
249
+
250
+ # Apply rule
251
+ count = fuse_pad_into_conv_rule_set ().apply_to_model (updated_model )
252
+
253
+ # Check that Pad was fused
254
+ self .assertEqual (count , 1 )
255
+ self .assertEqual (updated_model .graph .num_nodes (), 1 )
256
+ onnx_checker .CheckerPass (True )(updated_model )
257
+
258
+ # Check inference
259
+ inputs = self .rng .integers (0 , 255 , (1 , 24 , 19 , 23 ), dtype = "uint8" )
260
+ testing .assert_numerically_equal (base_model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
261
+
262
+
211
263
if __name__ == "__main__" :
212
264
unittest .main ()
0 commit comments