diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 1a2c5a9709f..54b4a8b83f3 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -614,6 +614,7 @@ python_unittest(
     typing = True,
     deps = [
         ":typing_stubs",
+        "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:ref_implementations",
         "//caffe2:torch",
     ]
diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index ce0fba47610..507562526c5 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -1449,7 +1449,7 @@ def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
     X_zero_point: torch.Tensor,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -1464,7 +1464,7 @@ def quantized_layer_norm_per_tensor_meta(
     input: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 0cd55326b86..aeb62a19784 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -64,9 +64,9 @@ def quantize_per_tensor(
         f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
     )

-    dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
     return torch.max(
-        torch.min(dequantized, torch.tensor(quant_max)),
+        torch.min(quantized, torch.tensor(quant_max)),
         torch.tensor(quant_min),
     )

@@ -247,12 +247,12 @@ def quantized_linear(
     ).reshape(*leading_dims, N)


-@impl(m, "quantized_layer_norm_per_tensor")
+@impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -283,7 +283,7 @@ def quantized_layer_norm_per_tensor(
         input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
     )
     out = torch.nn.functional.layer_norm(
-        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
+        float_input_tensor, normalized_shape, weight, bias, eps=eps
     )

     return quantize_per_tensor(
@@ -365,7 +365,7 @@ def quantized_conv_per_tensor(
     )


-@impl(m, "quantized_conv_nchw_per_tensor")
+@impl(m, "quantized_conv_nchw.per_tensor")
 def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -421,7 +421,7 @@ def quantized_conv_nchw_per_tensor(
     )


-@impl(m, "quantized_conv_nhwc_per_tensor")
+@impl(m, "quantized_conv_nhwc.per_tensor")
 def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -558,62 +558,62 @@ def variant(
     return decorator


-@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
-@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nchw", torch.int8, torch.int8) def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nchw", torch.uint8, torch.uint8) def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor") +@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") @quantized_conv_variant("nhwc", torch.int8, torch.int8) def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... -@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor") +@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") @quantized_conv_variant("nhwc", torch.uint8, torch.uint8) def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... 
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 4e2829a8460..918324876bf 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -8,31 +8,11 @@
 import typing
 import unittest

+import executorch.backends.cadence.aot.ops_registrations  # noqa
+import executorch.backends.cadence.aot.ref_implementations  # noqa
+
 import numpy as np
 import torch
-
-from executorch.backends.cadence.aot.ref_implementations import (
-    dequantize_per_tensor,
-    quantize_per_tensor,
-    quantized_add,
-    quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor,
-    quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor,
-    quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor,
-    quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor,
-    quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor,
-    quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor,
-    quantized_conv_nchw_per_tensor,
-    quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor,
-    quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor,
-    quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor,
-    quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor,
-    quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor,
-    quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor,
-    quantized_conv_nhwc_per_tensor,
-    quantized_layer_norm_per_tensor,
-    quantized_linear,
-    quantized_relu,
-)
 from executorch.backends.cadence.aot.typing_stubs import expand
@@ -60,7 +40,7 @@ def test_quantize_per_tensor(
         zero_point = round(-f_min * inv_scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)

-        output = quantize_per_tensor(
+        output = torch.ops.cadence.quantize_per_tensor(
             input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
         )

@@ -104,7 +84,7 @@ def test_dequantize_per_tensor(
         zero_point = round(-f_min / scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=torch.float32)

-        output = dequantize_per_tensor(
+        output = torch.ops.cadence.dequantize_per_tensor(
             input_tensor, scale, zero_point, q_min, q_max, torch.float32
         )

@@ -142,7 +122,7 @@ def test_quantized_add(
         Y_tensor = torch.tensor([Y], dtype=dtype)
         expected_output = torch.tensor([expected_value], dtype=dtype)

-        output = quantized_add(
+        output = torch.ops.cadence.quantized_add(
             X_tensor,
             torch.tensor(X_scale),
             torch.tensor(X_zero_point, dtype=dtype),
@@ -238,7 +218,7 @@ def test_quantized_linear(
             .to(expected_output.dtype)
         )
         bias = torch.arange(weight_shape[0]).to(expected_output.dtype)
-        output = quantized_linear(
+        output = torch.ops.cadence.quantized_linear(
             src,
             weight,
             bias,
@@ -266,7 +246,7 @@ def test_quantized_linear(
                 ),  # input: dequantized to [-0.1, 0.1]
                 0.1,  # X_scale
                 0,  # X_zero_point
-                2,  # normalized_shape (last dimension)
+                [2],  # normalized_shape (last dimension)
                 torch.tensor([1.0, 1.0]),  # weight
                 torch.tensor([0.0, 0.0]),  # bias
                 1e-5,  # eps
@@ -282,7 +262,7 @@ def test_quantized_linear(
                 ),  # input: dequantized to [-0.05, 0.05]
                 0.05,  # X_scale
                 128,  # X_zero_point
-                2,  # normalized_shape (last dimension)
+                [2],  # normalized_shape (last dimension)
                 torch.tensor([1.0, 1.0]),  # weight
                 torch.tensor([0.0, 0.0]),  # bias
                 1e-5,  # eps
@@ -298,7 +278,7 @@ def test_quantized_linear(
                 ),  # input: dequantized to [-0.2, 0.2]
                 0.1,  # X_scale
                 0,  # X_zero_point
-                2,  # normalized_shape (last dimension)
+                [2],  # normalized_shape (last dimension)
                 torch.tensor(
                     [2.0, 0.5]
                 ),  # weight: scale first element by 2, second by 0.5
@@ -318,7 +298,7 @@ def test_quantized_layer_norm_per_tensor(
         input_tensor: torch.Tensor,
         X_scale: float,
         X_zero_point: int,
-        normalized_shape: int,
+        normalized_shape: list[int],
         weight: torch.Tensor,
         bias: torch.Tensor,
         eps: float,
@@ -327,7 +307,7 @@ def test_quantized_layer_norm_per_tensor(
         dtype: torch.dtype,
         expected_output: torch.Tensor,
     ) -> None:
-        output = quantized_layer_norm_per_tensor(
+        output = torch.ops.cadence.quantized_layer_norm.per_tensor(
             input_tensor,
             X_scale,
             X_zero_point,
@@ -372,10 +352,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.1,  # output_scale
                 0,  # output_zero_point
-                torch.tensor(
-                    [1073741824], dtype=torch.int32
-                ),  # out_multiplier (0.5 * 2^31)
-                torch.tensor([0], dtype=torch.int8),  # out_shift
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int8,  # dtype
                 torch.tensor(
                     [[[[50]]]], dtype=torch.int8
@@ -403,8 +381,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.25,  # output_scale
                 0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int8,  # dtype
                 torch.tensor(
                     [[[[48, 64], [96, 112]]]], dtype=torch.int8
@@ -432,8 +410,8 @@ def test_quantized_layer_norm_per_tensor(
                 0.1,  # bias_scale
                 0.1,  # output_scale
                 128,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.uint8,  # dtype
                 torch.tensor(
                     [[[[238]]]], dtype=torch.uint8
@@ -463,8 +441,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.5,  # output_scale
                 0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int8,  # dtype
                 torch.tensor(
                     [[[6, 10, 14]]], dtype=torch.int8
@@ -498,8 +476,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.2,  # output_scale
                 0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int8,  # dtype
                 torch.tensor(
                     [[[[25]], [[50]]]], dtype=torch.int8
@@ -539,8 +517,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.1,  # output_scale
                 0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int16,  # dtype
                 torch.tensor(
                     [[[[180]]]], dtype=torch.int16
@@ -592,8 +570,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.05,  # output_scale
                 0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int16,  # dtype
                 torch.tensor([[[[400]], [[200]]]], dtype=torch.int16),
                 memory_format,
@@ -635,8 +613,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.2,  # output_scale
                 0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int8,  # dtype
                 torch.tensor(
                     [[[[50]], [[65]]]], dtype=torch.int8
@@ -674,8 +652,8 @@ def test_quantized_layer_norm_per_tensor(
                 1.0,  # bias_scale
                 0.5,  # output_scale
                 0,  # output_zero_point
-                typing.cast(None, torch.Tensor),
-                typing.cast(None, torch.Tensor),
+                0,  # unused out_multiplier
+                0,  # unused out_shift
                 torch.int8,  # dtype
                 torch.tensor(
                     [[[[2, 10, 8], [28, 68, 40], [26, 58, 32]]]], dtype=torch.int8
@@ -715,9 +693,9 @@ def test_quantized_conv_per_tensor(
         convs = [
             (
-                quantized_conv_nchw_per_tensor
+                torch.ops.cadence.quantized_conv_nchw.per_tensor
                 if memory_format == torch.contiguous_format
-                else quantized_conv_nhwc_per_tensor
+                else torch.ops.cadence.quantized_conv_nhwc.per_tensor
             )
         ]

@@ -725,30 +703,30 @@ def test_quantized_conv_per_tensor(
         if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
             if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
                 optimized_convs = [
-                    quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor,
-                    quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor,
-                    quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor,
+                    torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor,
                 ]
             else:
                 optimized_convs = [
-                    quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor,
-                    quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor,
-                    quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor,
+                    torch.ops.cadence.quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor,
+                    torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
                 ]
         elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
             if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
                 optimized_convs = [
-                    quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor,
-                    quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor,
-                    quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor,
+                    torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor,
                 ]
             else:
                 optimized_convs = [
-                    quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor,
-                    quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor,
-                    quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor,
+                    torch.ops.cadence.quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor,
+                    torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor,
                 ]

         convs.extend(optimized_convs)
@@ -851,7 +829,7 @@ def test_quantized_relu(
         dtype: torch.dtype,
         expected_output: torch.Tensor,
     ) -> None:
-        output = quantized_relu(
+        output = torch.ops.cadence.quantized_relu(
             X, X_zero_point, out_zero_point, out_multiplier, out_shift
         )
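
With the direct imports replaced by module-level `# noqa` imports, the tests above now go through the dispatcher instead of calling the Python functions directly, which also exercises the registered schemas (for example `normalized_shape` as an int list). The snippet below is a minimal sketch of that calling convention, assuming an environment where the executorch Cadence AoT modules are importable; the input values are illustrative and simply follow the reference math `round(x * scale + zero_point)` clamped to `[quant_min, quant_max]`.

```python
import torch

# Imported only for their side effect of registering the cadence op schemas and
# the Python reference kernels.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
import executorch.backends.cadence.aot.ref_implementations  # noqa: F401

x = torch.tensor([0.0, 0.5, 1.0])
out = torch.ops.cadence.quantize_per_tensor(
    x,
    2.0,          # scale argument; the reference kernel multiplies by this value
    0,            # zero_point
    -128,         # quant_min
    127,          # quant_max
    torch.int8,   # target dtype
)
print(out)  # expected: tensor([0, 1, 2], dtype=torch.int8)
```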