diff --git a/backends/xnnpack/test/models/deeplab_v3.py b/backends/xnnpack/test/models/deeplab_v3.py index ccaccb898d2..c5f6bfe17bc 100644 --- a/backends/xnnpack/test/models/deeplab_v3.py +++ b/backends/xnnpack/test/models/deeplab_v3.py @@ -36,6 +36,5 @@ def test_fp32_dl3(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/edsr.py b/backends/xnnpack/test/models/edsr.py index d748e35bb74..ca080b20b49 100644 --- a/backends/xnnpack/test/models/edsr.py +++ b/backends/xnnpack/test/models/edsr.py @@ -25,8 +25,7 @@ def test_fp32_edsr(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_edsr(self): @@ -38,6 +37,5 @@ def test_qs8_edsr(self): .partition() .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/emformer_rnnt.py b/backends/xnnpack/test/models/emformer_rnnt.py index 3728c9b07c9..3992c828964 100644 --- a/backends/xnnpack/test/models/emformer_rnnt.py +++ b/backends/xnnpack/test/models/emformer_rnnt.py @@ -21,8 +21,8 @@ def __init__(self): self.rnnt = decoder.model class Joiner(EmformerRnnt): - def forward(self, predict_inputs): - return self.rnnt.join(*predict_inputs) + def forward(self, a, b, c, d): + return self.rnnt.join(a, b, c, d) def get_example_inputs(self): join_inputs = ( @@ -31,7 +31,7 @@ def get_example_inputs(self): torch.rand([1, 128, 1024]), torch.tensor([128]), ) - return (join_inputs,) + return join_inputs def test_fp32_emformer_joiner(self): joiner = self.Joiner() @@ -43,21 +43,19 @@ def test_fp32_emformer_joiner(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class Predictor(EmformerRnnt): - def forward(self, predict_inputs): - return self.rnnt.predict(*predict_inputs) + def forward(self, a, b): + return self.rnnt.predict(a, b, None) def get_example_inputs(self): predict_inputs = ( torch.zeros([1, 128], dtype=int), torch.tensor([128], dtype=int), - None, ) - return (predict_inputs,) + return predict_inputs @unittest.skip("T183426271") def test_fp32_emformer_predictor(self): @@ -70,20 +68,19 @@ def test_fp32_emformer_predictor(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class Transcriber(EmformerRnnt): - def forward(self, predict_inputs): - return self.rnnt.transcribe(*predict_inputs) + def forward(self, a, b): + return self.rnnt.transcribe(a, b) def get_example_inputs(self): transcribe_inputs = ( torch.randn(1, 128, 80), torch.tensor([128]), ) - return (transcribe_inputs,) + return transcribe_inputs def test_fp32_emformer_transcriber(self): transcriber = self.Transcriber() @@ -95,6 +92,5 @@ def test_fp32_emformer_transcriber(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/inception_v3.py b/backends/xnnpack/test/models/inception_v3.py index 58839014557..b861afc5cd5 100644 --- a/backends/xnnpack/test/models/inception_v3.py +++ b/backends/xnnpack/test/models/inception_v3.py @@ -42,8 +42,7 @@ def test_fp32_ic3(self): .check_not(list(self.all_operators)) .to_executorch() 
.serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_ic3(self): @@ -63,6 +62,5 @@ def test_qs8_ic3(self): .check_not(list(ops_after_quantization)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/inception_v4.py b/backends/xnnpack/test/models/inception_v4.py index 534fb90ad6c..528512c82f2 100644 --- a/backends/xnnpack/test/models/inception_v4.py +++ b/backends/xnnpack/test/models/inception_v4.py @@ -39,8 +39,7 @@ def test_fp32_ic4(self): .check_not(list(self.all_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_ic4(self): @@ -60,6 +59,5 @@ def test_qs8_ic4(self): .check_not(list(ops_after_quantization)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/llama2_et_example.py b/backends/xnnpack/test/models/llama2_et_example.py index 46dae356cd8..4716f2d6a95 100644 --- a/backends/xnnpack/test/models/llama2_et_example.py +++ b/backends/xnnpack/test/models/llama2_et_example.py @@ -45,6 +45,5 @@ def _test(self, dtype: torch.dtype = torch.float): .dump_artifact() .to_executorch() .serialize() - .run_method() - .compare_outputs(atol=5e-2) + .run_method_and_compare_outputs(atol=5e-2) ) diff --git a/backends/xnnpack/test/models/mobilebert.py b/backends/xnnpack/test/models/mobilebert.py index bf6b2dfc408..df66ffd4507 100644 --- a/backends/xnnpack/test/models/mobilebert.py +++ b/backends/xnnpack/test/models/mobilebert.py @@ -38,6 +38,5 @@ def test_fp32_mobilebert(self): .check_not(list(self.supported_ops)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/mobilenet_v2.py b/backends/xnnpack/test/models/mobilenet_v2.py index dbd9bc744b4..53bcedd0a90 100644 --- a/backends/xnnpack/test/models/mobilenet_v2.py +++ b/backends/xnnpack/test/models/mobilenet_v2.py @@ -29,9 +29,15 @@ class TestMobileNetV2(unittest.TestCase): } def test_fp32_mv2(self): + dynamic_shapes = ( + { + 2: torch.export.Dim("height", min=224, max=455), + 3: torch.export.Dim("width", min=224, max=455), + }, + ) ( - Tester(self.mv2, self.model_inputs) + Tester(self.mv2, self.model_inputs, dynamic_shapes=dynamic_shapes) .export() .to_edge() .check(list(self.all_operators)) @@ -40,8 +46,7 @@ def test_fp32_mv2(self): .check_not(list(self.all_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs(num_runs=10) ) def test_qs8_mv2(self): @@ -50,8 +55,15 @@ def test_qs8_mv2(self): "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default", } + dynamic_shapes = ( + { + 2: torch.export.Dim("height", min=224, max=455), + 3: torch.export.Dim("width", min=224, max=455), + }, + ) + ( - Tester(self.mv2, self.model_inputs) + Tester(self.mv2, self.model_inputs, dynamic_shapes=dynamic_shapes) .quantize(Quantize(calibrate=False)) .export() .to_edge() @@ -61,6 +73,5 @@ def test_qs8_mv2(self): .check_not(list(ops_after_quantization)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs(num_runs=10) ) diff --git a/backends/xnnpack/test/models/mobilenet_v3.py b/backends/xnnpack/test/models/mobilenet_v3.py index 20d04b119e1..3da2e3bf42c 100644 --- a/backends/xnnpack/test/models/mobilenet_v3.py +++ 
b/backends/xnnpack/test/models/mobilenet_v3.py @@ -16,6 +16,12 @@ class TestMobileNetV3(unittest.TestCase): mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True) mv3 = mv3.eval() model_inputs = (torch.ones(1, 3, 224, 224),) + dynamic_shapes = ( + { + 2: torch.export.Dim("height", min=224, max=455), + 3: torch.export.Dim("width", min=224, max=455), + }, + ) all_operators = { "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default", @@ -33,7 +39,7 @@ class TestMobileNetV3(unittest.TestCase): def test_fp32_mv3(self): ( - Tester(self.mv3, self.model_inputs) + Tester(self.mv3, self.model_inputs, dynamic_shapes=self.dynamic_shapes) .export() .to_edge() .check(list(self.all_operators)) @@ -42,8 +48,7 @@ def test_fp32_mv3(self): .check_not(list(self.all_operators)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs(num_runs=5) ) def test_qs8_mv3(self): @@ -53,7 +58,7 @@ def test_qs8_mv3(self): ops_after_lowering = self.all_operators ( - Tester(self.mv3, self.model_inputs) + Tester(self.mv3, self.model_inputs, dynamic_shapes=self.dynamic_shapes) .quantize(Quantize(calibrate=False)) .export() .to_edge() @@ -63,6 +68,5 @@ def test_qs8_mv3(self): .check_not(list(ops_after_lowering)) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs(num_runs=5) ) diff --git a/backends/xnnpack/test/models/resnet.py b/backends/xnnpack/test/models/resnet.py index 73e68c855e9..06c889fc179 100644 --- a/backends/xnnpack/test/models/resnet.py +++ b/backends/xnnpack/test/models/resnet.py @@ -14,29 +14,63 @@ class TestResNet18(unittest.TestCase): - def test_fp32_resnet18(self): - inputs = (torch.ones(1, 3, 224, 224),) + inputs = (torch.ones(1, 3, 224, 224),) + dynamic_shapes = ( + { + 2: torch.export.Dim("height", min=224, max=455), + 3: torch.export.Dim("width", min=224, max=455), + }, + ) + + class DynamicResNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torchvision.models.resnet18() + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, + size=(224, 224), + mode="bilinear", + align_corners=True, + antialias=False, + ) + return self.model(x) + + def _test_exported_resnet(self, tester): ( - Tester(torchvision.models.resnet18(), inputs) - .export() + tester.export() .to_edge() .partition() + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_convolution_default", + "executorch_exir_dialects_edge__ops_aten_mean_dim", + ] + ) + .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) + def test_fp32_resnet18(self): + self._test_exported_resnet(Tester(torchvision.models.resnet18(), self.inputs)) + def test_qs8_resnet18(self): - inputs = (torch.ones(1, 3, 224, 224),) - ( - Tester(torchvision.models.resnet18(), inputs) - .quantize(Quantize(calibrate=False)) - .export() - .to_edge() - .partition() - .to_executorch() - .serialize() - .run_method() - .compare_outputs() + quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize( + Quantize(calibrate=False) + ) + self._test_exported_resnet(quantized_tester) + + def test_fp32_resnet18_dynamic(self): + self._test_exported_resnet( + Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes) + ) + + def test_qs8_resnet18_dynamic(self): + self._test_exported_resnet( + Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize( + Quantize(calibrate=False) + ) 
) diff --git a/backends/xnnpack/test/models/torchvision_vit.py b/backends/xnnpack/test/models/torchvision_vit.py index 226cc73f401..e4b387e0f79 100644 --- a/backends/xnnpack/test/models/torchvision_vit.py +++ b/backends/xnnpack/test/models/torchvision_vit.py @@ -15,6 +15,29 @@ class TestViT(unittest.TestCase): vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1") vit = vit.eval() model_inputs = (torch.ones(1, 3, 224, 224),) + dynamic_shapes = ( + { + 2: torch.export.Dim("height", min=224, max=455), + 3: torch.export.Dim("width", min=224, max=455), + }, + ) + + class DynamicViT(torch.nn.Module): + def __init__(self): + super().__init__() + self.vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1") + self.vit = self.vit.eval() + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, + size=(224, 224), + mode="bilinear", + align_corners=True, + antialias=False, + ) + return self.vit(x) + all_operators = { "executorch_exir_dialects_edge__ops_aten_expand_copy_default", "executorch_exir_dialects_edge__ops_aten_cat_default", @@ -34,7 +57,8 @@ class TestViT(unittest.TestCase): "executorch_exir_dialects_edge__ops_aten_bmm_default", } - def test_fp32_vit(self): + def _test_exported_vit(self, tester, check_nots=None): + check_nots = check_nots or [] lowerable_xnn_operators = self.all_operators - { "executorch_exir_dialects_edge__ops_aten_expand_copy_default", "executorch_exir_dialects_edge__ops_aten_gelu_default", @@ -48,15 +72,33 @@ def test_fp32_vit(self): "executorch_exir_dialects_edge__ops_aten_bmm_default", } ( - Tester(self.vit, self.model_inputs) - .export() + tester.export() .to_edge() .check(list(self.all_operators)) .partition() .check(["torch.ops.higher_order.executorch_call_delegate"]) .check_not(list(lowerable_xnn_operators)) + .check_not(check_nots) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() + ) + + def test_fp32_vit(self): + self._test_exported_vit(Tester(self.vit, self.model_inputs)) + + def test_dynamic_vit(self): + bilinear_ops = { + "executorch_exir_dialects_edge__ops_aten_sub_Tensor", + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_index_Tensor", + "executorch_exir_dialects_edge__ops_aten_arange_start_step", + "executorch_exir_dialects_edge__ops_aten__to_copy_default", + "executorch_exir_dialects_edge__ops_aten_add_Tensor", + "executorch_exir_dialects_edge__ops_aten_clamp_default", + } + + self._test_exported_vit( + Tester(self.DynamicViT(), self.model_inputs, self.dynamic_shapes), + bilinear_ops, ) diff --git a/backends/xnnpack/test/models/very_big_model.py b/backends/xnnpack/test/models/very_big_model.py index 2200b50a6b2..f3f06380414 100644 --- a/backends/xnnpack/test/models/very_big_model.py +++ b/backends/xnnpack/test/models/very_big_model.py @@ -39,6 +39,5 @@ def test_very_big_model(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/models/w2l.py b/backends/xnnpack/test/models/w2l.py index 10d7ca15b08..c95fc29d8cc 100644 --- a/backends/xnnpack/test/models/w2l.py +++ b/backends/xnnpack/test/models/w2l.py @@ -34,8 +34,7 @@ def test_fp32_w2l(self): .check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_w2l(self): @@ -54,6 +53,5 @@ def test_qs8_w2l(self): 
.check(["torch.ops.higher_order.executorch_call_delegate"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/abs.py b/backends/xnnpack/test/ops/abs.py index c71fe5ab4e0..2906654dfb7 100644 --- a/backends/xnnpack/test/ops/abs.py +++ b/backends/xnnpack/test/ops/abs.py @@ -31,8 +31,7 @@ def _test_abs(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_abs_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_abs(self): diff --git a/backends/xnnpack/test/ops/add.py b/backends/xnnpack/test/ops/add.py index 3a56e0f4c6a..8b0d0c6234d 100644 --- a/backends/xnnpack/test/ops/add.py +++ b/backends/xnnpack/test/ops/add.py @@ -54,8 +54,7 @@ def _test_add(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_add(self): @@ -79,8 +78,7 @@ def test_fp32_add_constant(self): .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add_constant(self): @@ -121,8 +119,7 @@ def test_qs8_add(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add2(self): @@ -145,8 +142,7 @@ def test_qs8_add2(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add3(self): @@ -169,8 +165,7 @@ def test_qs8_add3(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class AddRelu(torch.nn.Module): @@ -194,8 +189,7 @@ def test_fp32_add_relu(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add_relu(self): @@ -214,8 +208,7 @@ def test_qs8_add_relu(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_add_relu_seq(self): @@ -261,6 +254,5 @@ def forward(self, x, z): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/avgpool2d.py b/backends/xnnpack/test/ops/avgpool2d.py index 2dd46932988..edb92d09a35 100644 --- a/backends/xnnpack/test/ops/avgpool2d.py +++ b/backends/xnnpack/test/ops/avgpool2d.py @@ -42,8 +42,7 @@ def _test_argpool2d(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_avgpool2d(self): diff --git a/backends/xnnpack/test/ops/bilinear2d.py b/backends/xnnpack/test/ops/bilinear2d.py index 2e80eaf2bc5..ab9d3d3c11d 100644 --- a/backends/xnnpack/test/ops/bilinear2d.py +++ b/backends/xnnpack/test/ops/bilinear2d.py @@ -87,8 +87,7 @@ def test_fp32_static_resize_bilinear2d(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def 
test_fp32_static_resize_bilinear2d_with_align_cornesr(self): @@ -103,8 +102,7 @@ def test_fp32_static_resize_bilinear2d_with_align_cornesr(self): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp32_static_resize_bilinear2d_antialiased(self): diff --git a/backends/xnnpack/test/ops/cat.py b/backends/xnnpack/test/ops/cat.py index 8cb9b760b0d..85c5b51a2c7 100644 --- a/backends/xnnpack/test/ops/cat.py +++ b/backends/xnnpack/test/ops/cat.py @@ -11,16 +11,31 @@ class TestCat(unittest.TestCase): - class Cat(torch.nn.Module): - def forward(self, xs): + class Cat2(torch.nn.Module): + def forward(self, arg1, arg2): + xs = [arg1, arg2] x = torch.cat(xs) return x + x # Quantize by propagation. - class Cat2(torch.nn.Module): - def forward(self, xs): - return torch.cat(xs) + class Cat3(torch.nn.Module): + def forward(self, arg1, arg2, arg3): + xs = [arg1, arg2, arg3] + x = torch.cat(xs) + return x + x # Quantize by propagation. + + class Cat4(torch.nn.Module): + def forward(self, arg1, arg2, arg3, arg4): + xs = [arg1, arg2, arg3, arg4] + x = torch.cat(xs) + return x + x # Quantize by propagation. - def _test_cat(self, module, inputs, quant=False, quant_ops=2): + class Cat5(torch.nn.Module): + def forward(self, arg1, arg2, arg3, arg4, arg5): + xs = [arg1, arg2, arg3, arg4, arg5] + x = torch.cat(xs) + return x + x # Quantize by propagation. + + def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2): tester = Tester(module, inputs) if quant: @@ -36,7 +51,7 @@ def _test_cat(self, module, inputs, quant=False, quant_ops=2): # Q/DQ pair for each input and quantized op. For most tests, there are # two quantized ops - cat and add. torch.ops.quantized_decomposed.quantize_per_tensor.default: ( - len(inputs[0]) + quant_ops + cat_num + quant_ops ) } ) @@ -55,8 +70,7 @@ def _test_cat(self, module, inputs, quant=False, quant_ops=2): .check_not(["executorch_exir_dialects_edge__ops_aten_cat"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_cat2(self): @@ -64,10 +78,8 @@ def test_fp16_cat2(self): Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. """ inputs = ( - ( - torch.ones(1, 2, 3).to(torch.float16), - torch.ones(3, 2, 3).to(torch.float16), - ), + torch.ones(1, 2, 3).to(torch.float16), + torch.ones(3, 2, 3).to(torch.float16), ) self._test_cat(self.Cat2(), inputs) @@ -76,81 +88,71 @@ def test_fp16_cat3(self): Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. """ inputs = ( - ( - torch.ones(1, 2, 3).to(torch.float16), - torch.ones(3, 2, 3).to(torch.float16), - torch.ones(2, 2, 3).to(torch.float16), - ), + torch.ones(1, 2, 3).to(torch.float16), + torch.ones(3, 2, 3).to(torch.float16), + torch.ones(2, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat3(), inputs) def test_fp16_cat4(self): """ Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. 
""" inputs = ( - ( - torch.ones(1, 2, 3).to(torch.float16), - torch.ones(3, 2, 3).to(torch.float16), - torch.ones(2, 2, 3).to(torch.float16), - torch.ones(5, 2, 3).to(torch.float16), - ), + torch.ones(1, 2, 3).to(torch.float16), + torch.ones(3, 2, 3).to(torch.float16), + torch.ones(2, 2, 3).to(torch.float16), + torch.ones(5, 2, 3).to(torch.float16), ) - self._test_cat(self.Cat2(), inputs) + self._test_cat(self.Cat4(), inputs) def test_fp32_cat2(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3)),) - self._test_cat(self.Cat(), inputs) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3)) + self._test_cat(self.Cat2(), inputs) def test_fp32_cat3(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)),) - self._test_cat(self.Cat(), inputs) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)) + self._test_cat(self.Cat3(), inputs) def test_fp32_cat4(self): inputs = ( - ( - torch.ones(1, 2, 3), - torch.ones(3, 2, 3), - torch.ones(2, 2, 3), - torch.ones(5, 2, 3), - ), + torch.ones(1, 2, 3), + torch.ones(3, 2, 3), + torch.ones(2, 2, 3), + torch.ones(5, 2, 3), ) - self._test_cat(self.Cat(), inputs) + self._test_cat(self.Cat4(), inputs) def test_qs8_cat2(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3)),) - self._test_cat(self.Cat(), inputs, quant=True) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3)) + self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True) def test_qs8_cat3(self): - inputs = ((torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)),) - self._test_cat(self.Cat(), inputs, quant=True) + inputs = (torch.ones(1, 2, 3), torch.ones(3, 2, 3), torch.ones(2, 2, 3)) + self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True) def test_qs8_cat4(self): inputs = ( - ( - torch.ones(1, 2, 3), - torch.ones(3, 2, 3), - torch.ones(2, 2, 3), - torch.ones(5, 2, 3), - ), + torch.ones(1, 2, 3), + torch.ones(3, 2, 3), + torch.ones(2, 2, 3), + torch.ones(5, 2, 3), ) - self._test_cat(self.Cat(), inputs, quant=True) + self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True) def test_fp32_cat_unsupported(self): """ XNNPACK only supports concatenating up to 4 values, so it should not delegate here. 
""" inputs = ( - ( - torch.ones(1, 2, 3), - torch.ones(3, 2, 3), - torch.ones(2, 2, 3), - torch.ones(5, 2, 3), - torch.ones(1, 2, 3), - ), + torch.ones(1, 2, 3), + torch.ones(3, 2, 3), + torch.ones(2, 2, 3), + torch.ones(5, 2, 3), + torch.ones(1, 2, 3), ) ( - Tester(self.Cat(), inputs) + Tester(self.Cat5(), inputs) .export() .check_count({"torch.ops.aten.cat": 1}) .to_edge() diff --git a/backends/xnnpack/test/ops/ceil.py b/backends/xnnpack/test/ops/ceil.py index 853de03ff1d..8d59f3b35d7 100644 --- a/backends/xnnpack/test/ops/ceil.py +++ b/backends/xnnpack/test/ops/ceil.py @@ -31,8 +31,7 @@ def _test_ceil(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_ceil_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_ceil(self): diff --git a/backends/xnnpack/test/ops/clamp.py b/backends/xnnpack/test/ops/clamp.py index 6ffaed3fe1b..c52fd011f8b 100644 --- a/backends/xnnpack/test/ops/clamp.py +++ b/backends/xnnpack/test/ops/clamp.py @@ -33,8 +33,7 @@ def _test_clamp(self, module, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_clamp(self): @@ -77,6 +76,5 @@ def test_qs8_clamp(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/conv1d.py b/backends/xnnpack/test/ops/conv1d.py index 604e37c724c..50f9aa3a996 100644 --- a/backends/xnnpack/test/ops/conv1d.py +++ b/backends/xnnpack/test/ops/conv1d.py @@ -97,8 +97,7 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False): .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_conv1d(self): diff --git a/backends/xnnpack/test/ops/conv2d.py b/backends/xnnpack/test/ops/conv2d.py index 3eb80072a68..9a2bb25dc8d 100644 --- a/backends/xnnpack/test/ops/conv2d.py +++ b/backends/xnnpack/test/ops/conv2d.py @@ -152,8 +152,7 @@ def _test( .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() - .run_method() - .compare_outputs(qtol=1) + .run_method_and_compare_outputs(qtol=1) ) def test_fp16_conv2d(self) -> None: diff --git a/backends/xnnpack/test/ops/div.py b/backends/xnnpack/test/ops/div.py index 007122db981..2882c59b875 100644 --- a/backends/xnnpack/test/ops/div.py +++ b/backends/xnnpack/test/ops/div.py @@ -39,8 +39,7 @@ def _test_div(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_div(self): @@ -64,6 +63,5 @@ def test_fp32_div_single_input(self): .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/elu.py b/backends/xnnpack/test/ops/elu.py index f1f8d7628a6..89fef6f9d4b 100644 --- a/backends/xnnpack/test/ops/elu.py +++ b/backends/xnnpack/test/ops/elu.py @@ -39,8 +39,7 @@ def _test_elu(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171810227 - Missing recomposition for ELU") @@ -74,8 +73,7 @@ def test_qs8_elu(self): ) .to_executorch() 
.serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171810227 - Missing recomposition for ELU") @@ -99,6 +97,5 @@ def test_qs8_elu_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/floor.py b/backends/xnnpack/test/ops/floor.py index 31c3da09b42..cb65ca2aa58 100644 --- a/backends/xnnpack/test/ops/floor.py +++ b/backends/xnnpack/test/ops/floor.py @@ -31,8 +31,7 @@ def _test_floor(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_floor_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_floor(self): diff --git a/backends/xnnpack/test/ops/hardswish.py b/backends/xnnpack/test/ops/hardswish.py index d35e7ab5d78..8f6a190412c 100644 --- a/backends/xnnpack/test/ops/hardswish.py +++ b/backends/xnnpack/test/ops/hardswish.py @@ -41,8 +41,7 @@ def _test_hardswish(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T158969708 - Missing recomposition pass for hardswish") @@ -75,6 +74,5 @@ def test_fp32_hardswish_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/hardtanh.py b/backends/xnnpack/test/ops/hardtanh.py index fdcfb7c7efe..d13624663ca 100644 --- a/backends/xnnpack/test/ops/hardtanh.py +++ b/backends/xnnpack/test/ops/hardtanh.py @@ -38,8 +38,7 @@ def test_fp32_hardtanh(self): .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp32_hardtanh_bound(self): @@ -58,8 +57,7 @@ def test_fp32_hardtanh_bound(self): .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_hardtanh(self): @@ -90,6 +88,5 @@ def test_qs8_hardtanh(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/leaky_relu.py b/backends/xnnpack/test/ops/leaky_relu.py index 477188ed752..ae5f2e3197e 100644 --- a/backends/xnnpack/test/ops/leaky_relu.py +++ b/backends/xnnpack/test/ops/leaky_relu.py @@ -43,8 +43,7 @@ def _test_leaky_relu(self, module, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_leaky_relu(self): @@ -76,8 +75,7 @@ def test_fp32_leaky_relu_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T172863987 - Missing quantizer support.") @@ -107,8 +105,7 @@ def test_qs8_leaky_relu(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T172863987 - Missing quantizer support.") @@ -143,6 +140,5 @@ def test_qs8_leaky_relu_default_slope(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/linear.py b/backends/xnnpack/test/ops/linear.py index b4a9cb62856..85b760e38ad 100644 --- a/backends/xnnpack/test/ops/linear.py +++ b/backends/xnnpack/test/ops/linear.py @@ -26,23 +26,27 @@ class 
TestLinear(unittest.TestCase): def test_fp16_linear(self): for use_bias in (True, False): - self._test_linear( - lambda in_size, out_size: torch.nn.Linear( - in_size, out_size, bias=use_bias # noqa - ), - uses_bias=use_bias, - dtype=torch.float16, - atol=5e-2, - ) + for num_batch_dims in range(1, 3): + self._test_linear( + lambda in_size, out_size: torch.nn.Linear( + in_size, out_size, bias=use_bias # noqa + ), + num_batch_dims=num_batch_dims, + uses_bias=use_bias, + dtype=torch.float16, + atol=5e-2, + ) def test_fp32_linear(self): for use_bias in (True, False): - self._test_linear( - lambda in_size, out_size: torch.nn.Linear( - in_size, out_size, bias=use_bias # noqa - ), - uses_bias=use_bias, - ) + for num_batch_dims in range(1, 3): + self._test_linear( + lambda in_size, out_size: torch.nn.Linear( + in_size, out_size, bias=use_bias # noqa + ), + uses_bias=use_bias, + num_batch_dims=num_batch_dims, + ) def test_fp32_addmm(self): """ @@ -63,24 +67,71 @@ def forward(self, x): uses_bias=True, ) + def test_fp32_linear_fused_relu(self): + class LinearReluModule(torch.nn.Module): + def __init__(self, in_size, out_size, use_bias): + super().__init__() + self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias) + + def forward(self, x): + return torch.nn.functional.relu(self.linear(x)) + + for use_bias in (True, False): + for num_batch_dims in range(1, 3): + self._test_linear( + lambda in_size, out_size: LinearReluModule( + in_size, + out_size, + use_bias, # noqa + ), + uses_bias=use_bias, + num_batch_dims=num_batch_dims, + ) + + def test_qs8_linear_fused_relu(self): + class LinearReluModule(torch.nn.Module): + def __init__(self, in_size, out_size, use_bias): + super().__init__() + self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias) + + def forward(self, x): + return torch.nn.functional.relu(self.linear(x)) + + for use_bias in (True, False): + for num_batch_dims in range(1, 3): + self._test_linear( + lambda in_size, out_size: LinearReluModule( + in_size, + out_size, + use_bias, # noqa + ), + num_batch_dims=num_batch_dims, + uses_bias=use_bias, + quant=True, + ) + def test_qs8_linear(self): for use_bias in (True, False): - self._test_linear( - lambda in_size, out_size: torch.nn.Linear( - in_size, out_size, bias=use_bias # noqa - ), - uses_bias=use_bias, - ) + for num_batch_dims in range(1, 3): + self._test_linear( + lambda in_size, out_size: torch.nn.Linear( + in_size, out_size, bias=use_bias # noqa + ), + uses_bias=use_bias, + num_batch_dims=num_batch_dims, + ) @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.") def test_qd8_per_tensor_linear(self): for uses_bias in (False, True): inputs = (torch.randn(2, 4),) module = torch.nn.Linear(4, 5, bias=uses_bias) + dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},) self._test_dqlinear( module, inputs, + dynamic_shapes=dynamic_shapes, is_per_channel=False, uses_bias=uses_bias, ) @@ -93,6 +144,7 @@ def test_qd8_per_channel_linear(self): self._test_dqlinear( module, inputs, + dynamic_shapes=({0: torch.export.Dim("batch", max=100)},), is_per_channel=True, uses_bias=uses_bias, ) @@ -114,7 +166,7 @@ def test_qd8_per_channel_4w_linear(self): qconfig = self._get_4b_dqconfig() input_channels = [2, 63] output_channels = [1, 8, 127] - batches = [1, 2] + batches = [2, 2] use_bias = [False, True] for bs, bias, ipc, opc in product( @@ -129,13 +181,14 @@ def test_qd8_per_channel_4w_linear(self): self._test_dqlinear( module, inputs, + dynamic_shapes=({0: torch.export.Dim("batch", max=100)},), 
is_per_channel=True, uses_bias=bias, qconfig=qconfig, ) def test_qd8_per_channel_linear_parallel(self): - in_size = 1 + in_size = 2 input_size = 4 output_size = 5 @@ -165,17 +218,39 @@ def forward(self, x, y): torch.rand(in_size, input_size, dtype=torch.float), torch.rand(in_size, input_size, dtype=torch.float), ) + batch_dim = torch.export.Dim("batch", max=100) + dynamic_shapes = ({0: batch_dim}, {0: batch_dim}) self._test_dqlinear( ParallelLinear(), inputs, + dynamic_shapes=dynamic_shapes, linear_count=2, is_per_channel=True, uses_bias=True, ) + def test_qd8_per_channel_linear_with_two_batch(self): + in_size = 2 + input_size = 4 + output_size = 5 + + linear = torch.nn.Linear(input_size, output_size) + inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),) + batch_dim = torch.export.Dim("batch", max=100) + dynamic_shapes = ({0: batch_dim, 1: batch_dim},) + + self._test_dqlinear( + linear, + inputs, + dynamic_shapes=dynamic_shapes, + linear_count=1, + is_per_channel=True, + uses_bias=True, + ) + def test_qd8_per_channel_linear_sequential(self): - in_size = 1 + in_size = 2 input_size = 4 intermediate_size = 5 output_size = 3 @@ -203,17 +278,20 @@ def forward(self, x): return b inputs = (torch.rand(in_size, input_size, dtype=torch.float),) + dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},) self._test_dqlinear( LinearSequential(), inputs, + dynamic_shapes=dynamic_shapes, linear_count=2, is_per_channel=True, uses_bias=True, + atol=1e-1, ) def test_qd8_per_channel_linear_parellel_and_sequential(self): - in_size = 1 + in_size = 2 input_size = 4 intermediate_size = 5 output_size = 3 @@ -252,50 +330,21 @@ def forward(self, x, y): torch.rand(in_size, input_size, dtype=torch.float), torch.rand(in_size, input_size, dtype=torch.float), ) + dynamic_shapes = ( + {0: torch.export.Dim("batch", max=100)}, + {0: torch.export.Dim("batch2", max=100)}, + ) self._test_dqlinear( - LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True + LinearModule(), + inputs, + dynamic_shapes=dynamic_shapes, + linear_count=3, + is_per_channel=True, + uses_bias=True, + atol=1e-1, ) - def test_fp32_linear_fused_relu(self): - class LinearReluModule(torch.nn.Module): - def __init__(self, in_size, out_size, use_bias): - super().__init__() - self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias) - - def forward(self, x): - return torch.nn.functional.relu(self.linear(x)) - - for use_bias in (True, False): - self._test_linear( - lambda in_size, out_size: LinearReluModule( - in_size, - out_size, - use_bias, # noqa - ), - uses_bias=use_bias, - ) - - def test_qs8_linear_fused_relu(self): - class LinearReluModule(torch.nn.Module): - def __init__(self, in_size, out_size, use_bias): - super().__init__() - self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias) - - def forward(self, x): - return torch.nn.functional.relu(self.linear(x)) - - for use_bias in (True, False): - self._test_linear( - lambda in_size, out_size: LinearReluModule( - in_size, - out_size, - use_bias, # noqa - ), - uses_bias=use_bias, - quant=True, - ) - class ManualDQLinear(torch.nn.Module): def __init__( self, @@ -595,8 +644,7 @@ def _test_manual_dq_linear( ) .to_executorch() .serialize() - .run_method() - .compare_outputs(atol=atol, rtol=rtol) + .run_method_and_compare_outputs(atol=atol, rtol=rtol) ) def _run_manual_dqlinear_tests(self, weight_n_bit: int, op_dtype: torch.dtype): @@ -677,6 +725,7 @@ def _test_linear( self, make_module, uses_bias, + num_batch_dims=1, quant=False, dtype: torch.dtype = 
torch.float, atol=1e-03, @@ -693,7 +742,7 @@ def _test_linear( ) ) - in_sizes = [1, 4, 4] + in_sizes = [3, 4, 4] input_sizes = [4, 37, 17] output_sizes = [4, 17, 37] @@ -705,11 +754,19 @@ def _test_linear( in_size = int(in_sizes[i]) input_size = int(input_sizes[i]) output_size = int(output_sizes[i]) + input_shape = [in_size] * num_batch_dims + [input_size] + print(f"Testing input_shape {input_shape} with {output_size} out_channels") module = make_module(input_size, output_size).eval().to(dtype) - inputs = (torch.randn(in_size, input_size).to(dtype),) + inputs = (torch.randn(input_shape).to(dtype),) + dynamic_shape = {} + for i in range(num_batch_dims): + dynamic_shape[i] = torch.export.Dim(f"batch{i}", min=2, max=in_size) + + dynamic_shape = (dynamic_shape,) + print(dynamic_shape) - tester = Tester(module, inputs) + tester = Tester(module, inputs, dynamic_shapes=dynamic_shape) if quant: tester.quantize() @@ -731,18 +788,18 @@ def _test_linear( tester.to_executorch() tester.serialize() - tester.run_method() - tester.compare_outputs(qtol=quant, atol=atol) - print("success") + tester.run_method_and_compare_outputs(qtol=quant, atol=atol) def _test_dqlinear( self, module, inputs, + dynamic_shapes, linear_count=1, is_per_channel=False, uses_bias=False, qconfig: Optional[QuantizationConfig] = None, + atol=5e-02, ): aten_op, edge_op = ( ( @@ -761,13 +818,12 @@ def _test_dqlinear( is_dynamic=True, ) - tester = Tester(module, inputs) + tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes) tester.quantize(Quantize(quantization_config=quant_config)) tester.export() tester.check_count({aten_op: linear_count}) tester.check(["torch.ops.quantized_decomposed"]) - tester.dump_artifact() tester.to_edge() tester.check_count({edge_op: linear_count}) @@ -779,5 +835,4 @@ def _test_dqlinear( tester.to_executorch() tester.serialize() - tester.run_method() - tester.compare_outputs(atol=5e-02) + tester.run_method_and_compare_outputs(atol=atol) diff --git a/backends/xnnpack/test/ops/max_dim.py b/backends/xnnpack/test/ops/max_dim.py index b43d1ce4e82..9cab1236e4c 100644 --- a/backends/xnnpack/test/ops/max_dim.py +++ b/backends/xnnpack/test/ops/max_dim.py @@ -37,8 +37,7 @@ def _test_max_dim(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171468483 - Fails to partition due to index output dtype.") @@ -65,6 +64,5 @@ def test_fp32_max_dim_no_indices(self): .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/maximum.py b/backends/xnnpack/test/ops/maximum.py index 5ce05d33e37..feff02744d3 100644 --- a/backends/xnnpack/test/ops/maximum.py +++ b/backends/xnnpack/test/ops/maximum.py @@ -30,8 +30,7 @@ def _test_maximum(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_maximum(self): @@ -64,6 +63,5 @@ def test_fp32_maximum_broadcast(self): .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/maxpool2d.py b/backends/xnnpack/test/ops/maxpool2d.py index 84c76a6e6c9..7e510dd9155 100644 --- 
a/backends/xnnpack/test/ops/maxpool2d.py +++ b/backends/xnnpack/test/ops/maxpool2d.py @@ -64,8 +64,7 @@ def _test_maxpool2d(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_maxpool2d(self): @@ -135,6 +134,5 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/mean_dim.py b/backends/xnnpack/test/ops/mean_dim.py index b8d7e77a224..750b0e8f508 100644 --- a/backends/xnnpack/test/ops/mean_dim.py +++ b/backends/xnnpack/test/ops/mean_dim.py @@ -33,8 +33,7 @@ def _test_mean_dim(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_mean_dim(self): @@ -85,6 +84,5 @@ def test_qs8_mean_dim(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs(qtol=1) + .run_method_and_compare_outputs(qtol=1) ) diff --git a/backends/xnnpack/test/ops/minimum.py b/backends/xnnpack/test/ops/minimum.py index 5d6f08fd1a2..121fbeb1852 100644 --- a/backends/xnnpack/test/ops/minimum.py +++ b/backends/xnnpack/test/ops/minimum.py @@ -30,8 +30,7 @@ def _test_minimum(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_minimum_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_minimum(self): diff --git a/backends/xnnpack/test/ops/multiply.py b/backends/xnnpack/test/ops/multiply.py index 09f9b39ea60..d151f58bd6a 100644 --- a/backends/xnnpack/test/ops/multiply.py +++ b/backends/xnnpack/test/ops/multiply.py @@ -43,8 +43,7 @@ def _test_mul(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_mul(self): @@ -78,8 +77,7 @@ def test_qs8_mul(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mul2(self): @@ -102,8 +100,7 @@ def test_qs8_mul2(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mul_functional(self): @@ -126,8 +123,7 @@ def test_qs8_mul_functional(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_mul_relu(self): @@ -156,6 +152,5 @@ def test_qs8_mul_relu(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/negate.py b/backends/xnnpack/test/ops/negate.py index b7777136f5a..c4a47bb93ce 100644 --- a/backends/xnnpack/test/ops/negate.py +++ b/backends/xnnpack/test/ops/negate.py @@ -31,8 +31,7 @@ def _test_negate(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_neg_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_negate(self): diff --git a/backends/xnnpack/test/ops/permute.py b/backends/xnnpack/test/ops/permute.py index 3441acb6315..2c995376753 100644 --- a/backends/xnnpack/test/ops/permute.py +++ b/backends/xnnpack/test/ops/permute.py @@ -45,8 +45,7 @@ def _test_permute(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"]) .to_executorch() .serialize() - .run_method() - 
.compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_permute(self): @@ -72,8 +71,7 @@ def test_fp32_permute_copy(self): .check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_permute(self): @@ -102,8 +100,7 @@ def test_qs8_permute(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_permute_copy(self): @@ -132,6 +129,5 @@ def test_qs8_permute_copy(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/pow.py b/backends/xnnpack/test/ops/pow.py index b4bd6b5862c..d99f2c546e6 100644 --- a/backends/xnnpack/test/ops/pow.py +++ b/backends/xnnpack/test/ops/pow.py @@ -34,8 +34,7 @@ def _test_pow2(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_pow2(self): diff --git a/backends/xnnpack/test/ops/prelu.py b/backends/xnnpack/test/ops/prelu.py index a4e9ef7df95..985ddecf363 100644 --- a/backends/xnnpack/test/ops/prelu.py +++ b/backends/xnnpack/test/ops/prelu.py @@ -36,8 +36,7 @@ def _test_prelu(self, module, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T158653285 - Missing recomposition for PReLU") diff --git a/backends/xnnpack/test/ops/quantize_per_tensor.py b/backends/xnnpack/test/ops/quantize_per_tensor.py index 82aaca0b6f7..f912428a8ab 100644 --- a/backends/xnnpack/test/ops/quantize_per_tensor.py +++ b/backends/xnnpack/test/ops/quantize_per_tensor.py @@ -39,8 +39,7 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_dequantize_per_tenstor(self): @@ -76,6 +75,5 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/relu.py b/backends/xnnpack/test/ops/relu.py index c52055e45f1..3ab1c72b57d 100644 --- a/backends/xnnpack/test/ops/relu.py +++ b/backends/xnnpack/test/ops/relu.py @@ -33,6 +33,5 @@ def test_fp32_relu(self): .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/sdpa.py b/backends/xnnpack/test/ops/sdpa.py index 5cf8534c928..d68bcab2086 100644 --- a/backends/xnnpack/test/ops/sdpa.py +++ b/backends/xnnpack/test/ops/sdpa.py @@ -70,8 +70,7 @@ def _test(self, module, inputs, atol=1e-03, rtol=1e-03): ) .to_executorch() .serialize() - .run_method() - .compare_outputs(atol=atol, rtol=rtol) + .run_method_and_compare_outputs(atol=atol, rtol=rtol) ) def test_fp16_sdpa_mask2d(self): diff --git a/backends/xnnpack/test/ops/sigmoid.py b/backends/xnnpack/test/ops/sigmoid.py index be8eda605ee..5ed6fc64402 100644 --- a/backends/xnnpack/test/ops/sigmoid.py +++ b/backends/xnnpack/test/ops/sigmoid.py @@ -32,8 +32,7 @@ def _test_sigmoid(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_sigmoid(self): diff --git 
a/backends/xnnpack/test/ops/slice_copy.py b/backends/xnnpack/test/ops/slice_copy.py index 99b5842313f..2d0f150dd15 100644 --- a/backends/xnnpack/test/ops/slice_copy.py +++ b/backends/xnnpack/test/ops/slice_copy.py @@ -27,8 +27,7 @@ def _test_slice_copy(self, module, inputs, copy_count=1, edge_copy_count=1): .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_slice_copy(self): @@ -143,6 +142,5 @@ def forward(self, x): .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/softmax.py b/backends/xnnpack/test/ops/softmax.py index 43ff89f1206..d3f674d7ae2 100644 --- a/backends/xnnpack/test/ops/softmax.py +++ b/backends/xnnpack/test/ops/softmax.py @@ -38,8 +38,7 @@ def _test_softmax(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_softmax(self): diff --git a/backends/xnnpack/test/ops/sqrt.py b/backends/xnnpack/test/ops/sqrt.py index 99ab8f72340..e2a5f4ac2f6 100644 --- a/backends/xnnpack/test/ops/sqrt.py +++ b/backends/xnnpack/test/ops/sqrt.py @@ -16,6 +16,7 @@ def __init__(self): super().__init__() def forward(self, x): + x = torch.abs(x) z = torch.sqrt(x) return z @@ -31,14 +32,13 @@ def _test_sqrt(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_sqrt_default"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_sqrt(self): - inputs = (torch.randn(20).to(torch.float16).abs(),) + inputs = (torch.randn(20).to(torch.float16),) self._test_sqrt(inputs) def test_fp32_sqrt(self): - inputs = (torch.randn(20).abs(),) + inputs = (torch.randn(20),) self._test_sqrt(inputs) diff --git a/backends/xnnpack/test/ops/square.py b/backends/xnnpack/test/ops/square.py index faad836becf..02dc12e16e4 100644 --- a/backends/xnnpack/test/ops/square.py +++ b/backends/xnnpack/test/ops/square.py @@ -37,8 +37,7 @@ def _test_square(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_square(self): diff --git a/backends/xnnpack/test/ops/static_constant_pad.py b/backends/xnnpack/test/ops/static_constant_pad.py index 6b8563e291d..c836b404ac7 100644 --- a/backends/xnnpack/test/ops/static_constant_pad.py +++ b/backends/xnnpack/test/ops/static_constant_pad.py @@ -99,8 +99,7 @@ def _test_static_constant_pad_functional(self, inputs): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_static_constant_pad_functional(self): @@ -154,8 +153,7 @@ def forward(self, x): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_static_constant_pad_2d(self): @@ -180,6 +178,5 @@ def test_qs8_static_constant_pad_2d(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/ops/sub.py b/backends/xnnpack/test/ops/sub.py index bcb4f389bd6..d3cc6e8aa80 100644 --- a/backends/xnnpack/test/ops/sub.py +++ b/backends/xnnpack/test/ops/sub.py @@ 
-39,8 +39,7 @@ def _test_sub(self, inputs): .check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"]) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp16_sub(self): @@ -75,8 +74,7 @@ def test_qs8_sub(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171957656 - Quantized sub not implemented.") @@ -100,8 +98,7 @@ def test_qs8_sub2(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171957656 - Quantized sub not implemented.") @@ -125,8 +122,7 @@ def test_qs8_sub3(self): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) @unittest.skip("T171957656 - Quantized sub not implemented.") @@ -166,6 +162,5 @@ def forward(self, x, y): ) .to_executorch() .serialize() - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index ab9b02af4bf..06517c526c8 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -40,8 +40,7 @@ def test_fp32_batch_norm_fusion(self): .to_edge() .run_passes(self.PassStage) .check_count({self.bn_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_q8_batch_norm_fusion(self): @@ -52,8 +51,7 @@ def test_q8_batch_norm_fusion(self): .to_edge() .run_passes(self.PassStage) .check_count({self.bn_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_fp32_batch_norm_no_fusion_doesnt_partition(self): diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index abb18a8c0b2..36e566abc36 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -42,8 +42,7 @@ def test_fp32_channels_last_tagged_reshape_pass(self): self.to_copy_name: num_reshape, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_channels_last_tagged_reshape_pass(self): @@ -64,8 +63,7 @@ def test_qs8_channels_last_tagged_reshape_pass(self): ] * num_reshape ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class ConvRelu(torch.nn.Module): @@ -86,8 +84,7 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_relu(self): .check( [self.to_copy_name, self.conv_name, self.relu_name, self.to_copy_name] ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self): @@ -109,8 +106,7 @@ def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self): self.to_copy_name, ] ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class Conv2dBnHardtanhMeanSequenceModule(torch.nn.Module): @@ -175,6 +171,5 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self): self.to_copy_name: 4, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_convert_to_linear.py b/backends/xnnpack/test/passes/test_convert_to_linear.py index 783336a01cd..0fa80246fd6 100644 --- a/backends/xnnpack/test/passes/test_convert_to_linear.py +++ 
b/backends/xnnpack/test/passes/test_convert_to_linear.py @@ -35,6 +35,5 @@ def test_fp32_convert_to_linear(self): .check_count( {"executorch_exir_dialects_edge__ops_aten_linear_default": 1} ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_remove_get_item_pass.py b/backends/xnnpack/test/passes/test_remove_get_item_pass.py index 35bd4d8b966..fa68c403e38 100644 --- a/backends/xnnpack/test/passes/test_remove_get_item_pass.py +++ b/backends/xnnpack/test/passes/test_remove_get_item_pass.py @@ -42,8 +42,7 @@ def test_fp32_max_pool2d_remove_getitem(self): .to_edge() .run_passes(self.PassStage) .check_count({self.max_pool2d_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_q8_max_pool2d_remove_getitem(self): @@ -54,8 +53,7 @@ def test_q8_max_pool2d_remove_getitem(self): .to_edge() .run_passes(self.PassStage) .check_count({self.max_pool2d_name: 1}) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) class MaxModule(torch.nn.Module): @@ -79,8 +77,7 @@ def test_fp32_max_remove_getitem(self): self.amax_name: 1, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) def test_q8_max_remove_getitem(self): @@ -95,6 +92,5 @@ def test_q8_max_remove_getitem(self): self.amax_name: 1, } ) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() ) diff --git a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py b/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py index 97c31c3d43a..dc67a6582df 100644 --- a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py +++ b/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py @@ -55,8 +55,7 @@ def test_tag_implicit_q_dq_test(self): .export() .to_edge() .run_passes(self.PassStage) - .run_method() - .compare_outputs() + .run_method_and_compare_outputs() .get_artifact(Tester.stage_name(self.PassStage)) ) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index ec03fa2529d..e0115a29eef 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -7,6 +7,7 @@ import copy import logging +import random import sys from abc import ABC, abstractmethod from collections import Counter, OrderedDict @@ -26,7 +27,7 @@ ) from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.backend.partitioner import Partitioner -from executorch.exir.passes.spec_prop_pass import SpecPropPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.print_program import pretty_print, print_program logger = logging.getLogger(__name__) @@ -177,11 +178,18 @@ def graph_module(self) -> str: @register_stage class Export(Stage): - def __init__(self): + def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None): self.exported_program = None + self.dynamic_shapes = dynamic_shapes - def run(self, artifact: torch.nn.Module, inputs) -> None: - self.exported_program = export(artifact, inputs) + def run( + self, + artifact: torch.nn.Module, + inputs: Tuple[torch.Tensor], + ) -> None: + self.exported_program = export( + artifact, inputs, dynamic_shapes=self.dynamic_shapes + ) @property def artifact(self) -> ExportedProgram: @@ -261,8 +269,8 @@ def __init__( config: Optional[ExecutorchBackendConfig] = None, ): self.config = config or ExecutorchBackendConfig( - passes=[SpecPropPass()], extract_delegate_segments=True, + 
+            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
         )

         self.executorch_program = None
@@ -334,11 +342,13 @@ def __init__(
         self,
         module: torch.nn.Module,
         inputs: Tuple[torch.Tensor],
+        dynamic_shapes: Optional[Tuple[Any]] = None,
     ):
         module.eval()

         self.original_module = module
         self.inputs = inputs
+        self.dynamic_shapes = dynamic_shapes
         self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys()))
         self.pipeline = {
             self.stage_name(Quantize): [self.stage_name(Export)],
@@ -371,6 +381,59 @@ def __init__(
         # Artifact output from stage
         self.stage_output = None

+    def generate_random_inputs(self):
+        # Get shapes of inputs
+        input_shapes = []
+        if self.dynamic_shapes is None:
+            for tensor_arg in self.inputs:
+                assert isinstance(tensor_arg, torch.Tensor)
+                input_shapes.append(tensor_arg.shape)
+        else:
+            # Random shapes depending on dynamic shape constraint
+            dim_name_to_size = {}
+            for arg_idx in range(len(self.inputs)):
+                assert isinstance(self.inputs[arg_idx], torch.Tensor)
+                ex_shape = list(self.inputs[arg_idx].shape)
+                dynamic_dim_spec = self.dynamic_shapes[arg_idx]
+                for dim_idx, dim_spec in dynamic_dim_spec.items():
+                    assert dim_idx < len(ex_shape)
+                    if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim):
+                        # derived dims are of the form {0: 2 * torch.export.Dim() // 2}
+                        # The root contains the min/max of the export dim and fn contains
+                        # the function to compute the derived dim.
+                        fn = dim_spec.fn
+                        dim_spec = dim_spec.root
+                    elif isinstance(dim_spec, torch.export.dynamic_shapes._Dim):
+                        # Not derived dim so fn is just itself
+                        def fn(x):
+                            return x
+
+                    else:
+                        raise RuntimeError(
+                            f"Expected Dynamic Dims to be of type _DerivedDim or _Dim but got {type(dim_spec)}"
+                        )
+                    dim_name = dim_spec.__name__
+                    if dim_name not in dim_name_to_size:
+                        upper_bound = min(
+                            dim_spec.max, 1000
+                        )  # unbounded int max is too large
+                        lower_bound = (
+                            dim_spec.min if dim_spec.min != 2 else 1
+                        )  # 0/1 specialization means dim_spec.min can never be 1
+                        dim_name_to_size[dim_name] = fn(
+                            random.randint(lower_bound, upper_bound)
+                        )
+                    ex_shape[dim_idx] = dim_name_to_size[dim_spec.__name__]
+                input_shapes.append(torch.Size(ex_shape))
+        # create random tensor inputs with the shapes given above:
+        random_inputs = []
+        for arg_idx in range(len(self.inputs)):
+            random_inputs.append(
+                torch.randn(input_shapes[arg_idx]).to(dtype=self.inputs[arg_idx].dtype)
+            )
+
+        yield tuple(random_inputs)
+
     @staticmethod
     def stage_name(stage) -> str:
         t = stage if isinstance(stage, type) else type(stage)
@@ -406,7 +469,9 @@ def quantize(self, quantize_stage: Optional[Quantize] = None):
         return self._run_stage(quantize_stage or Quantize(), self.inputs)

     def export(self, export_stage: Optional[Export] = None):
-        return self._run_stage(export_stage or Export(), self.inputs)
+        return self._run_stage(
+            export_stage or Export(dynamic_shapes=self.dynamic_shapes), self.inputs
+        )

     def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
         return self._run_stage(to_edge_stage or ToEdge())
@@ -469,21 +534,39 @@ def check_node_count(self, input: Dict[Any, int]):

         return self

-    def run_method(
-        self, stage: Optional[str] = None, inputs: Optional[Tuple[torch.Tensor]] = None
+    def run_method_and_compare_outputs(
+        self,
+        stage: Optional[str] = None,
+        inputs: Optional[Tuple[torch.Tensor]] = None,
+        num_runs=1,
+        atol=1e-03,
+        rtol=1e-03,
+        qtol=0,
     ):
-        inputs_to_run = inputs or self.inputs
-        export_stage = self.stages[self.stage_name(Export)]
-
-        # Reference output (and quantization scale)
-        (
-            self.reference_output,
-            self.quantization_scale,
-        ) = self._calculate_reference_output(export_stage.artifact, inputs_to_run)
+        number_of_runs = 1 if inputs is not None else num_runs
+        reference_stage = self.stages[self.stage_name(Export)]

-        # Output from running artifact at stage
         stage = stage or self.cur
-        self.stage_output = self.stages[stage].run_artifact(inputs_to_run)
+
+        print(f"Comparing Stage {stage} with Stage {reference_stage}")
+        for run_iteration in range(number_of_runs):
+            inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
+            input_shapes = [generated_input.shape for generated_input in inputs_to_run]
+            print(f"Run {run_iteration} with input shapes: {input_shapes}")
+
+            # Reference output (and quantization scale)
+            (
+                reference_output,
+                quantization_scale,
+            ) = self._calculate_reference_output(
+                reference_stage.artifact, inputs_to_run
+            )
+
+            # Output from running artifact at stage
+            stage_output = self.stages[stage].run_artifact(inputs_to_run)
+            self._compare_outputs(
+                reference_output, stage_output, quantization_scale, atol, rtol, qtol
+            )

         return self

@@ -521,33 +604,37 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
                 f"\t Min: {model.min()}, {ref.min()}\n"
             )

-    def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0):
+    @staticmethod
+    def _compare_outputs(
+        reference_output,
+        stage_output,
+        quantization_scale=None,
+        atol=1e-03,
+        rtol=1e-03,
+        qtol=0,
+    ):
         """
         Compares the original of the original nn module with the output of the
         generated artifact. This requres calling run_method before calling
         compare_outputs. As that runs the generated artifact on the sample inputs
         and sets the stage output to be compared against the reference.
         """
-        assert self.reference_output is not None
-        assert self.stage_output is not None
-
         # Wrap both outputs as tuple, since executor output is always a tuple even if single tensor
-        if isinstance(self.reference_output, torch.Tensor):
-            self.reference_output = (self.reference_output,)
-        if isinstance(self.stage_output, torch.Tensor):
-            self.stage_output = (self.stage_output,)
+        if isinstance(reference_output, torch.Tensor):
+            reference_output = (reference_output,)
+        if isinstance(stage_output, torch.Tensor):
+            stage_output = (stage_output,)

         # If a qtol is provided and we found an dequantization node prior to the output, relax the
         # atol by qtol quant units.
-        if self.quantization_scale is not None:
-            atol += self.quantization_scale * qtol
+        if quantization_scale is not None:
+            atol += quantization_scale * qtol

-        self._assert_outputs_equal(
-            self.stage_output,
-            self.reference_output,
+        Tester._assert_outputs_equal(
+            stage_output,
+            reference_output,
             atol=atol,
             rtol=rtol,
         )
-        return self

     @staticmethod
     def _calculate_reference_output(
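
Taken together, the tester changes above let a test opt into dynamic-shape coverage: per-input dim constraints are passed to Tester, and run_method_and_compare_outputs(num_runs=N) regenerates random inputs via generate_random_inputs for every run. A minimal usage sketch follows; the AddOne module, the width bounds, and the run count are illustrative only (not taken from this diff), and the Tester import is assumed to match the test files in this suite:

    import torch

    from executorch.backends.xnnpack.test.tester import Tester


    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1.0


    example_inputs = (torch.randn(1, 3, 16, 16),)

    # One {dim index: Dim} dict per positional input; generate_random_inputs()
    # samples a fresh width in [16, 32] for every run.
    dynamic_shapes = ({3: torch.export.Dim("width", min=16, max=32)},)

    (
        Tester(AddOne(), example_inputs, dynamic_shapes=dynamic_shapes)
        .export()
        .to_edge()
        .partition()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs(num_runs=5)
    )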