diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py
index 7a022b54395..357800865cb 100644
--- a/backends/arm/operators/op_add.py
+++ b/backends/arm/operators/op_add.py
@@ -53,10 +53,9 @@ def define_node(
             [ts.DType.INT8, ts.DType.INT32],
             output.tosa_spec,
         )
 
-        scale_back = 1.0
         if inputs[0].dtype == ts.DType.INT8:
-            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
                 tosa_graph, inputs, node, self.tosa_spec
             )
         else:
@@ -85,7 +84,12 @@ def define_node(
             # Scale output back to 8 bit
             # pyre-ignore
             tqutils.insert_rescale_op_to_int8(
-                tosa_graph, add_output, scale_back, node, self.tosa_spec
+                tosa_graph,
+                add_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
             )  # type: ignore[possibly-undefined]
diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py
index 18b3c853271..4701c488967 100644
--- a/backends/arm/operators/op_sub.py
+++ b/backends/arm/operators/op_sub.py
@@ -56,7 +56,7 @@ def define_node(
 
         scale_back = 1.0
         if inputs[0].dtype == ts.DType.INT8:
-            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
                 tosa_graph, inputs, node, self.tosa_spec
             )
         else:
@@ -86,7 +86,12 @@ def define_node(
             # Scale output back to 8 bit
             # pyre-ignore
             tqutils.insert_rescale_op_to_int8(
-                tosa_graph, sub_output, scale_back, node, self.tosa_spec
+                tosa_graph,
+                sub_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
             )  # type: ignore[possibly-undefined]
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index e6eac4d80eb..ff1ad50e517 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -473,6 +473,10 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in (
+        torch.ops.aten.add.Tensor,
+        torch.ops.aten.add_.Tensor,
+        torch.ops.aten.sub.Tensor,
+        torch.ops.aten.sub_.Tensor,
         torch.ops.aten.matmul.default,
         torch.ops.aten.mm.default,
         torch.ops.aten.bmm.default,
@@ -485,10 +489,6 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in (
-        torch.ops.aten.add.Tensor,
-        torch.ops.aten.add_.Tensor,
-        torch.ops.aten.sub.Tensor,
-        torch.ops.aten.sub_.Tensor,
         torch.ops.aten.minimum.default,
         torch.ops.aten.maximum.default,
     ):
diff --git a/backends/arm/test/misc/test_conv_relu_residual_add.py b/backends/arm/test/misc/test_conv_relu_residual_add.py
new file mode 100644
index 00000000000..fdd6ec972a6
--- /dev/null
+++ b/backends/arm/test/misc/test_conv_relu_residual_add.py
@@ -0,0 +1,110 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import pytest
+
+import torch
+import torch.nn as nn
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+)
+
+
+# Model with a Conv1D-ReLU sequence and a residual add.
+# Tests the annotation of Conv1D-ReLU (to be fused) and the annotation of add.
+# ReLU outputs only positive numbers, while the linear layer outputs both positive and negative
+# numbers, so they should have different quantisation parameters. If the ReLU gets the wrong
+# quantisation parameters (e.g. qmin != zp) because of an observer shared with a following
+# operator (e.g. add), the Conv1D-ReLU sequence is not fused, is left in FP32, and the test fails.
+class AddDifferentRanges(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, input_dim):
+        super().__init__()
+        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size)
+        self.relu = torch.nn.ReLU()
+        self.linear = nn.Linear(out_channels, out_channels)
+
+    def forward(self, x):
+        # Permute: (N, T, C) -> (N, C, T)
+        x = x.permute(0, 2, 1)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = x.permute(0, 2, 1)
+        out = x + self.linear(x)
+        return out
+
+
+input_t = Tuple[torch.Tensor]
+model = AddDifferentRanges(in_channels=3, out_channels=16, kernel_size=3, input_dim=10)
+model_inputs = (torch.randn(1, 10, 3),)
+quant_test_data = {
+    "per_channel_quantization=true": True,
+    "per_channel_quantization=false": False,
+}
+
+
+def test_tosa_FP():
+    pipeline = TosaPipelineFP[input_t](
+        model,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("per_channel_quantization", quant_test_data)
+def test_tosa_INT(per_channel_quantization):
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        use_to_edge_transform_and_lower=True,
+        per_channel_quantization=per_channel_quantization,
+        qtol=0,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@common.XfailIfNoCorstone300
+@common.parametrize("per_channel_quantization", quant_test_data)
+def test_tosa_u55_INT(per_channel_quantization):
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        model_inputs,
+        [],
+        [],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+        per_channel_quantization=per_channel_quantization,
+        qtol=0,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@common.XfailIfNoCorstone320
+@common.parametrize("per_channel_quantization", quant_test_data)
+def test_tosa_u85_INT(per_channel_quantization):
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        model_inputs,
+        [],
+        [],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+        per_channel_quantization=per_channel_quantization,
+        qtol=0,
+    )
+    pipeline.run()
diff --git a/backends/arm/test/models/test_inception_v3_arm.py b/backends/arm/test/models/test_inception_v3_arm.py
index f69022de712..f973521c1fa 100644
--- a/backends/arm/test/models/test_inception_v3_arm.py
+++ b/backends/arm/test/models/test_inception_v3_arm.py
@@ -51,7 +51,7 @@ def test_ic3_tosa_BI():
         aten_op=[],
         exir_op=[],
         use_to_edge_transform_and_lower=True,
-        atol=0.6,
+        atol=0.65,
         qtol=1,
     )
     pipeline.run()
diff --git a/backends/arm/test/models/test_resnet18.py b/backends/arm/test/models/test_resnet18.py
new file mode 100644
index 00000000000..6e965daeb8b
--- /dev/null
+++ b/backends/arm/test/models/test_resnet18.py
@@ -0,0 +1,99 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import pytest
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+)
+
+from torchvision import transforms  # type: ignore[import-untyped]
+from torchvision.models import resnet18, ResNet18_Weights
+
+model = resnet18(weights=ResNet18_Weights.DEFAULT)
+model = model.eval()
+normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+model_inputs = (normalize(torch.randn((1, 3, 224, 224))),)
+
+input_t = Tuple[torch.Tensor]
+
+
+quant_test_data = {
+    "per_channel_quantization=true": True,
+    "per_channel_quantization=false": False,
+}
+
+
+def test_resnet_tosa_FP():
+    pipeline = TosaPipelineFP[input_t](
+        model,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        use_to_edge_transform_and_lower=True,
+    )
+    pipeline.run()
+
+
+@common.parametrize("per_channel_quantization", quant_test_data)
+def test_resnet_tosa_INT(per_channel_quantization):
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        model_inputs,
+        aten_op=[],
+        exir_op=[],
+        use_to_edge_transform_and_lower=True,
+        per_channel_quantization=per_channel_quantization,
+        atol=0.5,
+        qtol=1,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@common.XfailIfNoCorstone300
+@common.parametrize("per_channel_quantization", quant_test_data)
+def test_resnet_u55_INT(per_channel_quantization):
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        model_inputs,
+        aten_ops=[],
+        exir_ops=[],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+        per_channel_quantization=per_channel_quantization,
+        atol=0.5,
+        qtol=1,
+    )
+    pipeline.run()
+
+
+@pytest.mark.slow
+@pytest.mark.xfail(
+    reason="For resnet18 on Ethos-U85 the SRAM memory footprint is very high. The compiler team is investigating."
+)
+@common.XfailIfNoCorstone320
+@common.parametrize("per_channel_quantization", quant_test_data)
+def test_resnet_u85_INT(per_channel_quantization):
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        model_inputs,
+        aten_ops=[],
+        exir_ops=[],
+        run_on_fvp=True,
+        use_to_edge_transform_and_lower=True,
+        per_channel_quantization=per_channel_quantization,
+        atol=0.5,
+        qtol=1,
+    )
+    pipeline.run()
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index 0e825d57894..5c27761e926 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -57,13 +57,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
     "4d_randn_1": lambda: (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
     "4d_randn_2": lambda: (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
     "4d_randn_big": lambda: (
-        10000 * torch.randn(1, 1, 4, 4),
+        (1 << 30) * torch.randn(1, 1, 4, 4),
         torch.randn(1, 1, 4, 1),
     ),
     "4d_randn_1_mutltiple_broadcasts": lambda: (
         torch.randn(1, 4, 4, 1),
         torch.ones(1, 1, 4, 4),
     ),
+    "4d_big_small": lambda: (
+        (10e10) * torch.randn(1, 10, 20, 30),
+        torch.randn(1, 10, 20, 30),
+    ),
 }
@@ -86,7 +90,7 @@ def test_add_tensor_tosa_FP(test_data: input_t1):
 
 @common.parametrize("test_data", Add.test_data)
 def test_add_tensor_tosa_INT(test_data: input_t1):
-    pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op)
+    pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op, qtol=0)
     pipeline.run()
@@ -111,9 +115,16 @@ def test_add_tensor_tosa_INT_i32(test_data: input_t1):
         quant_max=2**31 - 1,
         quant_min=-(2**31),
     )
+    output_act_qspec = QuantizationSpec(
+        torch.int32,
+        observer,
+        qscheme=torch.per_tensor_symmetric,
+        quant_max=2**31 - 1,
+        quant_min=-(2**31),
+    )
     # This quantization_config will be set as global config.
     quantization_config = arm_quantizer.QuantizationConfig(
-        input_act_qspec, None, None, None
+        input_act_qspec, output_act_qspec, None, None
     )
     quantize_stage = Quantize(quantizer, quantization_config)
     pipeline.change_args("quantize", quantize_stage)
@@ -157,13 +168,13 @@ def test_add_tensor_tosa_FP_3(test_data: input_t2):
 
 @common.parametrize("test_data", Add3.test_data)
 def test_add_tensor_tosa_INT_3(test_data: input_t2):
-    pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op)
+    pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op, qtol=0)
     pipeline.run()
 
 
 @common.parametrize("test_data", Add2.test_data)
 def test_add_tensor_tosa_INT_2(test_data: input_t2):
-    pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op)
+    pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op, qtol=0)
     pipeline.run()
diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py
index 76502daf45c..a7a031468ea 100644
--- a/backends/arm/test/ops/test_conv_combos.py
+++ b/backends/arm/test/ops/test_conv_combos.py
@@ -47,28 +47,28 @@ def __init__(self):
         # (t, c, n, s) = (6, 96, 1, 1)
         # 1. 1x1 CONV2d + ReLU6 (Pointwise)
         self.pointwise_conv2d = torch.nn.Conv2d(
-            in_channels=32, out_channels=128, kernel_size=1, stride=1, groups=1
+            in_channels=16, out_channels=96, kernel_size=1, stride=1, groups=1
         )  ## (1, 128, 81, 81)
-        self.batch_norm2d_16 = torch.nn.BatchNorm2d(128, affine=False)
+        self.batch_norm2d_16 = torch.nn.BatchNorm2d(96, affine=False)
         self.relu6 = torch.nn.ReLU6()
 
         # 2. 3x3 DepthwiseConv2d + ReLu6
         self.depthwise_conv2d = torch.nn.Conv2d(
-            in_channels=128,
-            out_channels=128,
+            in_channels=96,
+            out_channels=96,
             kernel_size=3,
             padding=1,
             stride=1,
-            groups=128,
+            groups=96,
         )  ## (1, 128, H, W)
 
         # 3. Linear 1x1 Conv2d
         self.pointwise_conv2d_linear = torch.nn.Conv2d(
-            in_channels=128, out_channels=32, kernel_size=1, stride=1, groups=1
+            in_channels=96, out_channels=16, kernel_size=1, stride=1, groups=1
         )  ## (1, 32, 81, 81)
 
     def get_inputs(self) -> Tuple[torch.Tensor]:
-        return (torch.randn(1, 32, 81, 81),)
+        return (torch.randn(1, 16, 81, 81),)
 
     def forward(self, x):
         input = x
diff --git a/backends/arm/test/ops/test_group_norm.py b/backends/arm/test/ops/test_group_norm.py
index 5fa4cd328de..0f314064548 100644
--- a/backends/arm/test/ops/test_group_norm.py
+++ b/backends/arm/test/ops/test_group_norm.py
@@ -102,6 +102,9 @@ def test_native_group_norm_tosa_INT(test_data):
     "test_data",
     test_data_suite,
     xfails={
+        "rand_4_6_8_groups_2_eps_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
+        "rand_4_6_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
+        "rand_4_6_groups_2": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
         "randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
         "rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
         "rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
diff --git a/backends/arm/test/ops/test_leaky_relu.py b/backends/arm/test/ops/test_leaky_relu.py
index c18255a73c0..432c4da7ecc 100644
--- a/backends/arm/test/ops/test_leaky_relu.py
+++ b/backends/arm/test/ops/test_leaky_relu.py
@@ -30,10 +30,10 @@ def forward(self, x: torch.Tensor):
 
 test_data: dict[str, input_t1] = {
     "zeros": lambda: ((torch.zeros(1, 1, 5, 5),), 0.01),
-    "ones": lambda: ((torch.ones(1, 32, 112, 112),), 0.01),
-    "rand": lambda: ((torch.rand(1, 96, 56, 56),), 0.2),
+    "ones": lambda: ((torch.ones(1, 16, 96, 96),), 0.01),
+    "rand": lambda: ((torch.rand(1, 64, 56, 56),), 0.2),
     "3Dtensor": lambda: ((torch.rand(5, 5, 5),), 0.001),
-    "negative_slope": lambda: ((torch.rand(1, 16, 128, 128),), -0.002),
+    "negative_slope": lambda: ((torch.rand(1, 16, 96, 96),), -0.002),
 }
diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py
index 6872cf2b86b..5a775612d4f 100644
--- a/backends/arm/test/ops/test_sigmoid_16bit.py
+++ b/backends/arm/test/ops/test_sigmoid_16bit.py
@@ -176,7 +176,7 @@ def test_sigmoid_u85_INT(test_data):
         "ramp": "AssertionError: Output 0 does not match reference output. MLETORCH-787"
     },
 )
-@pytest.mark.flaky(reruns=5)  # MLETORCH-787: Investigate int16-int8 rescaling precision
+@pytest.mark.xfail  # MLETORCH-787: Investigate int16-int8 rescaling precision
 @common.XfailIfNoCorstone320
 def test_sigmoid_u85_INT_add_sigmoid(test_data):
     pipeline = EthosU85PipelineINT(
diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py
index ab6612393b8..c691506beb2 100644
--- a/backends/arm/test/ops/test_sub.py
+++ b/backends/arm/test/ops/test_sub.py
@@ -5,9 +5,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
+import math
 from typing import Tuple
 
 import torch
+
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -36,6 +39,10 @@
     "rand_2D_4x4": lambda: (torch.rand(4, 4), torch.rand(4, 4)),
     "rand_3D_4x4x4": lambda: (torch.rand(4, 2, 2), torch.rand(4, 2, 2)),
     "rand_4D_2x2x4x4": lambda: (torch.rand(2, 2, 4, 4), torch.rand(2, 2, 4, 4)),
+    "rand_4D_big_small": lambda: (
+        (10e30) * torch.randn(1, 20, 30, 40),
+        torch.randn(1, 20, 30, 40),
+    ),
     "zeros": lambda: (torch.rand(4, 4), torch.zeros(4, 4)),
     "randn_4D_mutltiple_broadcasts": lambda: (
         torch.randn(1, 4, 4, 1),
@@ -45,6 +52,16 @@
     "rand_3d_Scalar": lambda: (torch.rand(1, 6, 2), 1),
 }
 
+# Sub followed by tan: tan has a very steep curve just before pi/2 and a discontinuity at pi/2,
+# so any inaccuracy in the sub result is amplified by the tan.
+sub_tan_test_data = {
+    "rand_4D_pi": lambda: (
+        torch.randn(1, 10, 20, 30) * math.pi / 2,
+        torch.randn(1, 10, 20, 30) * math.pi / 2,
+    ),
+    "rand_3D_pi": lambda: (torch.randn(1, 30, 40) * math.pi / 2, torch.rand(1, 30, 40)),
+}
+
 
 class Sub(torch.nn.Module):
     def forward(self, x: torch.Tensor):
@@ -56,6 +73,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         return x - y
 
 
+class SubTan(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        z = x - y
+        t = torch.tan(z)
+        return t
+
+
 input_t1 = Tuple[torch.Tensor]  # Input x
 input_t2 = Tuple[torch.Tensor, torch.Tensor]  # Input x, y
 
@@ -87,23 +112,22 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
 @common.parametrize("test_data", sub_test_data)
 def test_sub_tensor_tosa_INT(test_data):
     """Test Subtraction (TOSA INT)"""
-    pipeline = TosaPipelineINT[input_t1](
-        Sub(),
-        test_data(),
-        aten_op,
-        exir_op,
-    )
+    pipeline = TosaPipelineINT[input_t1](Sub(), test_data(), aten_op, exir_op, qtol=0)
     pipeline.run()
 
 
 @common.parametrize("test_data", sub2_test_data)
 def test_sub_tensor_tosa_INT_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
+    """Test Two-Operand Subtraction (TOSA INT)"""
+    pipeline = TosaPipelineINT[input_t2](Sub2(), test_data(), aten_op, exir_op, qtol=0)
+    pipeline.run()
+
+
+@common.parametrize("test_data", sub_tan_test_data)
+def test_sub_tensor_tosa_INT_3(test_data: Tuple[torch.Tensor, torch.Tensor]):
     """Test Two-Operand Subtraction (TOSA INT)"""
     pipeline = TosaPipelineINT[input_t2](
-        Sub2(),
-        test_data(),
-        aten_op,
-        exir_op,
+        SubTan(), test_data(), aten_op, exir_op, qtol=0
     )
     pipeline.run()
diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index ae549ee9345..3edf40f9eaa 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -15,11 +15,69 @@
 import torch.fx
 import torch.fx.node
 
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+    get_output_qparams,
+)
+
 from executorch.backends.arm.tosa_mapping import TosaArg
 from torch.fx import Node
 from tosa.RoundingMode import RoundingMode  # type: ignore
 
 
+def insert_rescale_ops_to_int32_maxscale(
+    tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
+) -> tuple[list[Any], float]:
+    """For ADD and SUB, we rescale to int32 using a different common scale (2 * max(left scale, right scale))
+    than in all other cases. We also multiply the left and right scales by 1 << 20, which gives us extra precision
+    for the computation without overflowing.
+
+    Returns a list of the rescaled nodes and the scale factor used,
+    needed by insert_rescale_op_to_int8.
+    """
+
+    if len(inputs) > 2:
+        raise ValueError("More than two inputs not supported")
+
+    tensors = inputs.copy()
+    # Reshape tensor according to TOSA dim order
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    input_qparams = get_input_qparams(node)
+    lhs_qparams, rhs_qparams = input_qparams.values()
+    lhs_scale = lhs_qparams.get_scale_per_tensor()
+    rhs_scale = rhs_qparams.get_scale_per_tensor()
+    # Common scale for the two numbers
+    max_scale_2x = 2 * max(lhs_scale, rhs_scale)
+    SHIFT_INT8 = 20
+    # We are adding two int8 numbers. If the zero points are non-zero, the result lies in [-255, 255], so we need 9 bits for it.
+    # We have a 32-bit accumulator, so we can shift left by 20 bits without overflowing. In reality, because we also divide by 2 * max(lhs_scale, rhs_scale), the effective shift is 19 bits.
+    lhs_factor = (1 << SHIFT_INT8) * lhs_scale / max_scale_2x
+    rhs_factor = (1 << SHIFT_INT8) * rhs_scale / max_scale_2x
+    rescaled_lhs = build_rescale_to_int32(
+        tosa_graph,
+        tensors[0],
+        lhs_qparams.get_zp_per_tensor(),
+        lhs_factor,
+        tosa_spec=tosa_spec,
+    )
+    rescaled_rhs = build_rescale_to_int32(
+        tosa_graph,
+        tensors[1],
+        rhs_qparams.get_zp_per_tensor(),
+        rhs_factor,
+        tosa_spec=tosa_spec,
+    )
+    out_qparam = get_output_qparams(node)[0]
+    out_scale = out_qparam.get_scale_per_tensor()
+    back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT8))
+
+    return [rescaled_lhs, rescaled_rhs], back_scale
+
+
 def insert_rescale_ops_to_int32(
     tosa_graph: Any,
     inputs: list[TosaArg],
@@ -71,6 +129,7 @@ def insert_rescale_op_to_int8(
     last_tensor: TosaArg,
     scale: float,
     node: Node,
+    compute_rescale=True,
     tosa_spec=None,
 ) -> None:
     """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
@@ -78,6 +137,7 @@
     node: The original node that is being handled by the rescales.
     last_tensor:the tosa tensor to rescale back.
     scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
+    compute_rescale: whether to divide 'scale' by the output quantization scale; pass False when the output scale is already folded into 'scale'.
     tosa_graph: the tosa_graph to manipulate.
 
     This functions is used in serialization to TOSA for target ops that are
@@ -89,10 +149,14 @@
     )
 
     output_qparams = get_output_qparams(node)
-    assert len(output_qparams) == 1, "More than one output not supported"
+    if len(output_qparams) != 1:
+        raise ValueError("More than one output not supported")
 
     qargs_out = output_qparams[0]
-    output_rescale_scale = scale / qargs_out.get_scale_per_tensor()
+    if compute_rescale:
+        output_rescale_scale = scale / qargs_out.get_scale_per_tensor()
+    else:
+        output_rescale_scale = scale
 
     # Rescale Back to INT8
     build_rescale_from_int32(
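
Below is a minimal, self-contained sketch (not part of the patch) of the arithmetic that the new insert_rescale_ops_to_int32_maxscale and insert_rescale_op_to_int8 with compute_rescale=False implement together. All names here are illustrative only, and the sketch emulates the math on plain tensors, whereas the real code emits TOSA RESCALE operators into the graph and reads scales and zero points from the node's annotated qparams.

import torch


def emulate_int8_add(lhs_q, lhs_scale, lhs_zp, rhs_q, rhs_scale, rhs_zp, out_scale, out_zp):
    # Common scale for both operands: 2 * max(lhs_scale, rhs_scale).
    max_scale_2x = 2.0 * max(lhs_scale, rhs_scale)
    # int8 + int8 needs at most 9 bits, so a left shift of 20 still fits
    # in a 32-bit accumulator (9 + 20 < 32).
    SHIFT_INT8 = 20

    # Stage 1: rescale each int8 operand into the shared int32 domain.
    lhs_factor = (1 << SHIFT_INT8) * lhs_scale / max_scale_2x
    rhs_factor = (1 << SHIFT_INT8) * rhs_scale / max_scale_2x
    lhs32 = torch.round((lhs_q.to(torch.int32) - lhs_zp) * lhs_factor).to(torch.int32)
    rhs32 = torch.round((rhs_q.to(torch.int32) - rhs_zp) * rhs_factor).to(torch.int32)

    acc = lhs32 + rhs32  # the int32 ADD

    # Stage 2: a single rescale back to int8. The output quantization scale is
    # already folded into back_scale, which is why the operators now call
    # insert_rescale_op_to_int8 with compute_rescale=False.
    back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT8))
    out = torch.round(acc.to(torch.float64) * back_scale) + out_zp
    return out.clamp(-128, 127).to(torch.int8)


# Usage: quantize two fp32 tensors, add them in the integer domain, and check
# against the fp32 reference quantized to the output grid (within qtol = 1).
x, y = torch.randn(8), torch.randn(8)
x_scale, y_scale, out_scale = 0.02, 0.5, 0.52
x_q = torch.round(x / x_scale).clamp(-128, 127).to(torch.int8)
y_q = torch.round(y / y_scale).clamp(-128, 127).to(torch.int8)
z_q = emulate_int8_add(x_q, x_scale, 0, y_q, y_scale, 0, out_scale, 0)
ref = (x_q.float() * x_scale + y_q.float() * y_scale) / out_scale
ref_q = torch.round(ref).clamp(-128, 127)
assert (z_q.float() - ref_q).abs().max() <= 1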