diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py index fc6b5217f0c..121d527cb46 100644 --- a/backends/vulkan/test/op_tests/utils/codegen.py +++ b/backends/vulkan/test/op_tests/utils/codegen.py @@ -165,7 +165,7 @@ def prepack_ref(self, ref: ValueRef) -> bool: else: return ref.supports_prepack and self.should_prepack - def create_value_for(self, ref: ValueRefList) -> str: + def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901 if isinstance(ref, list): ret_str = "" for r in ref: diff --git a/docs/source/build-run-xtensa.md b/docs/source/build-run-xtensa.md index 296d9ac1193..7827ea5e36d 100644 --- a/docs/source/build-run-xtensa.md +++ b/docs/source/build-run-xtensa.md @@ -68,13 +68,14 @@ examples/xtensa/ ├── aot ├── kernels ├── ops +├── tests ├── third-party └── utils ``` ***AoT (Ahead-of-Time) Components***: -The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) defines a model and some example inputs (set to a vector of ones), and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders. +The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders. ***Operators***: @@ -97,17 +98,31 @@ cd executorch python3 -m examples.portable.scripts.export --model_name="add" ``` -***Quantized Linear***: +***Quantized Operators***: -The second, more complex model is a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py#L88). Linear is the backbone of most Automatic Speech Recognition (ASR) models. +The other, more complex model are custom operators, including: + - a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_linear_example.py#L28). 
Linear is the backbone of most Automatic Speech Recognition (ASR) models. + - a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake word and many denoising models. -The generated file is called `XtensaDemoModel.pte`. +In both cases the generated file is called `XtensaDemoModel.pte`. + +```bash +cd executorch +python3 -m examples.xtensa.tests.quantized__example +``` + +***Small Model: RNNT predictor***: + +The torchaudio [RNNT-emformer](https://pytorch.org/audio/stable/tutorials/online_asr_tutorial.html) model is an Automatic Speech Recognition (ASR) model, comprised of three different submodels: an encoder, a predictor and a joiner. +The predictor is a sequence of basic ops (embedding, ReLU, linear, layer norm) and can be exported using: ```bash cd executorch -python3 -m examples.xtensa.aot.export_example +python3 -m examples.xtensa.tests.rnnt_predictor_quantized_example ``` +The generated file is called `XtensaDemoModel.pte`. + ### Runtime **Building the DSP firmware image** @@ -139,12 +154,14 @@ cmake -DBUCK2=buck2 \ -DCMAKE_TOOLCHAIN_FILE=/examples/xtensa/xtensa.cmake \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE=Debug \ + -DPYTHON_EXECUTABLE=python3 \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_HOST_TARGETS=ON \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ + -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \ + -DEXECUTORCH_BUILD_CPUINFO=OFF \ -DEXECUTORCH_BUILD_FLATC=OFF \ -DFLATC_EXECUTABLE="$(which flatc)" \ - -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ - -DPYTHON_EXECUTABLE=python3 \ -Bcmake-out . cmake --build cmake-out -j8 --target install --config Debug @@ -196,6 +213,6 @@ First 20 elements of output 0 In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime and run this model on the Xtensa HiFi4 DSP chip. -The model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model in [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels). +The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels). Other models can be created following the same structure, always assuming that operators and kernels are available. diff --git a/examples/xtensa/aot/compiler.py b/examples/xtensa/aot/compiler.py new file mode 100644 index 00000000000..c9df9ef2bab --- /dev/null +++ b/examples/xtensa/aot/compiler.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
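The AoT flow described above reduces to a few lines of user code: define an `nn.Module`, pick representative inputs, and hand both to `export_xtensa_model`, which quantizes the model, fuses the custom Xtensa ops, lowers to ExecuTorch, and writes `XtensaDemoModel.pte`. A minimal sketch of that flow, mirroring the new test scripts in this patch; the absolute import paths (and running from the `executorch` repo root) are assumptions:

```python
# Minimal sketch of the documented AoT flow. Assumes it is run from the
# executorch repo root so the absolute imports below resolve.
import torch

# Registers the custom xtensa ops with PyTorch (side-effect import).
from examples.xtensa.aot.meta_registrations import *  # noqa: F401,F403
from examples.xtensa.aot.export_example import export_xtensa_model


class TinyLinear(torch.nn.Module):
    """A small fp32 model; the quantizer decides what gets quantized."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear()
model.eval()
example_inputs = (torch.ones(64, 32),)

# Quantizes, fuses the custom ops, lowers to ExecuTorch, and saves
# XtensaDemoModel.pte.
export_xtensa_model(model, example_inputs)
```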
+ +import logging +from typing import Any, Callable + +import torch + +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge + +from torch.export import export +from torch.export.exported_program import ExportedProgram + + +def export_program( + model: Callable, + inputs: Any, + pt2_quant: bool = False, +) -> ExportedProgram: + # we don't support training mode. Make it eval + if hasattr(model, "eval"): + if pt2_quant: + # pyre-fixme[6]: Incompatible parameter type. + torch.ao.quantization.move_exported_model_to_eval(model) + else: + # pyre-fixme[16]: Anonymous callable has no attribute `eval`. + model.eval() + + # if it's already an ExportedProgram, just return it + if isinstance(model, ExportedProgram): + return model + + assert isinstance(model, torch.nn.Module), "model should be an nn.Module" + + # Prevent mkldnn decompositions + torch._C._set_mkldnn_enabled(False) + + # else: capture the model and return it. + return export(model, inputs) + + +# Export the model and lower it it edge IR. +def export_to_edge( + model: Callable, + inputs: Any, + pt2_quant: bool = False, + dump_graphs: bool = False, +) -> EdgeProgramManager: + # Export the model into an ExportedProgram. + expo_program = export_program(model, inputs, pt2_quant) + + if dump_graphs: + logging.info(f"Exported graph:\n{expo_program.graph_module.graph}") + + # Call to_edge to convert the graph to edge IR. + edge_prog_manager = to_edge( + expo_program, compile_config=EdgeCompileConfig(_check_ir_validity=False) + ) + + if dump_graphs: + logging.info( + f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}" + ) + + return edge_prog_manager diff --git a/examples/xtensa/aot/export_example.py b/examples/xtensa/aot/export_example.py index b51f5c9b498..509538f5437 100644 --- a/examples/xtensa/aot/export_example.py +++ b/examples/xtensa/aot/export_example.py @@ -10,18 +10,17 @@ from .meta_registrations import * # noqa -import torch -from executorch.exir import EdgeCompileConfig from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from ...portable.utils import export_to_edge, save_pte_program +from ...portable.utils import save_pte_program +from .compiler import export_to_edge from .quantizer import ( QuantFusion, ReplacePT2DequantWithXtensaDequant, ReplacePT2QuantWithXtensaQuant, - XtensaQuantizer, + XtensaBaseQuantizer, ) @@ -29,28 +28,9 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) -if __name__ == "__main__": - in_features = 32 - out_features = 16 - bias = True - shape = [64, in_features] - - class QuantizedLinear(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, bias: bool): - super().__init__() - self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias) - - def forward(self, x: torch.Tensor): - output_linear_out = self.output_linear(x) - return output_linear_out - - model = QuantizedLinear(in_features, out_features, bias) - model.eval() - - example_inputs = (torch.ones(shape),) - +def export_xtensa_model(model, example_inputs): # Quantizer - quantizer = XtensaQuantizer() + quantizer = XtensaBaseQuantizer() # Export model_exp = capture_pre_autograd_graph(model, example_inputs) @@ -66,29 +46,20 @@ def forward(self, x: torch.Tensor): patterns = [q.pattern for q in quantizer.quantizers] QuantFusion(patterns)(converted_model) - # pre-autograd export. 
eventually this will become torch.export - converted_model_exp = capture_pre_autograd_graph(converted_model, example_inputs) + # Get edge program (note: the name will change to export_to_xtensa in future PRs) + edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True) - converted_model_exp = torch.ao.quantization.move_exported_model_to_eval( - converted_model_exp + # Run a couple required passes for quant/dequant ops + xtensa_prog_manager = edge_prog_manager.transform( + [ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()], + check_ir_validity=False, ) - exec_prog = ( - export_to_edge( - converted_model_exp, - example_inputs, - edge_compile_config=EdgeCompileConfig( - _check_ir_validity=False, - ), - ) - .transform( - [ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()], - check_ir_validity=False, - ) - .to_executorch() - ) + exec_prog = xtensa_prog_manager.to_executorch() - logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}") + logging.info( + f"Final exported graph module:\n{exec_prog.exported_program().graph_module}" + ) # Save the program as XtensaDemoModel.pte save_pte_program(exec_prog, "XtensaDemoModel") diff --git a/examples/xtensa/aot/meta_registrations.py b/examples/xtensa/aot/meta_registrations.py index aa6014dc9cf..d62334fda4b 100644 --- a/examples/xtensa/aot/meta_registrations.py +++ b/examples/xtensa/aot/meta_registrations.py @@ -4,10 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional, Tuple + import torch from executorch.exir.scalar_type import ScalarType from torch.library import impl, Library +from .utils import get_conv1d_output_size + lib = Library("xtensa", "DEF") lib.define( @@ -25,10 +29,31 @@ ) lib.define( - "quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)" + "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)" +) + +lib.define( + "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)" +) + +lib.define( + "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" +) +lib.define( + "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)") + +lib.define( + "quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) 
out) -> Tensor (a!)" +) + +lib.define( + "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)" ) lib.define( - "quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" + "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) m = Library("xtensa", "IMPL", "Meta") @@ -58,18 +83,17 @@ def dequantize_per_tensor_meta( return input.new_empty(input.size(), dtype=torch.float) -@impl(m, "quantized_linear_pt2") -def quantized_linear_pt2_meta( +@impl(m, "quantized_linear") +def quantized_linear_meta( src: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, - in_scale: float, in_zero_point: int, - weight_scale: float, - weight_zero_point: int, - out_multiplier: int, - out_shift: int, + weight_zero_point: torch.Tensor, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, out_zero_point: int, + offset: Optional[torch.Tensor], ): # src comes in shape [leading_dims, in_dim] # weight comes in shape [out_dim, in_dim] @@ -79,3 +103,58 @@ def quantized_linear_pt2_meta( assert len(weight_size) == 2 out_size[-1] = weight_size[0] return src.new_empty(out_size, dtype=torch.uint8) + + +@impl(m, "quantized_conv") +def quantized_conv_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: torch.Tensor, + bias_scale: torch.Tensor, + output_scale: float, + output_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, + channel_last: bool = False, +): + out_channels, _in_channels, *kernel_size = weight.shape + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = get_conv1d_output_size( + in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0] + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@impl(m, "quantized_layer_norm") +def quantized_layer_norm_meta( + input: torch.Tensor, + X_scale: torch.Tensor, + X_zero_point: torch.Tensor, + normalized_shape: int, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + output_scale: float, + output_zero_point: int, +): + return input.new_empty(input.size(), dtype=torch.uint8) + + +@impl(m, "quantized_relu") +def quantized_relu_meta( + X: torch.Tensor, + X_zero_point: torch.Tensor, +): + return X.new_empty(X.size(), dtype=torch.uint8) diff --git a/examples/xtensa/aot/quantizer.py b/examples/xtensa/aot/quantizer.py index 618d374853f..e664e0fb59f 100644 --- a/examples/xtensa/aot/quantizer.py +++ b/examples/xtensa/aot/quantizer.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from math import frexp, isclose, trunc -from typing import Any, Callable, List, Optional, Tuple, Type +from 
typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -24,13 +24,12 @@ QuantizationAnnotation, QuantizationConfig, QuantizationSpec, + SharedQuantizationSpec, ) from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.utils.fuser_utils import legalize_graph -# torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib") - def quantize_tensor_multiplier( requantize_scale_tensor: torch.Tensor, @@ -115,6 +114,264 @@ def _no_outside_users(fused_partition) -> bool: return True +# Helper function to get the weight node for both quantized and unquantized weights +# TODO(matthiascremon): get a better test! +def get_weight_node(weights_inputs: fx.Node, dequants_weights: fx.Node) -> fx.Node: + """ + Returns the weight node. + """ + weight_node = ( + weights_inputs + if weights_inputs.name.endswith("_frozen_param") + else dequants_weights + ) + return weight_node + + +# Helper function to get the args and kwargs for the linear replacement op +def get_args_and_kwargs_linear( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + other_inputs: List[fx.Node], + weights_inputs: List[fx.Node], + dequants_weights: List[fx.Node], + bias_inputs: List[fx.Node], + quant_node: fx.Node, +) -> Tuple[Tuple[Any], Dict[str, Any]]: + """ + Returns the args and kwargs for the linear replacement op. + """ + weight_scale = get_weight_node(weights_inputs[0], dequants_weights[0]).args[1] + # pyre-fixme[58]: Unsupported operand types + bias_scale = dequants_inputs[0].args[1] * weight_scale + requantize_scale = bias_scale / quant_node.args[1] + requantize_scale_t = torch.tensor([requantize_scale]) + + (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) + + # If bias is not available, create a bias tensor with the shape of weight[0] + if not bias_inputs: + weight_node = get_weight_node(weights_inputs[0], dequants_weights[0]).args[0] + # pyre-fixme[16]: Undefined attribute + attr_node = getattr(graph_module, weight_node.target) + weight_shape = list(attr_node.shape) + bias_shape = weight_shape[0] + bias = graph_module.graph.call_function( + torch.ops.aten.full.default, ([bias_shape], 0.0) + ) + else: + bias = bias_inputs[0] + + bias_int32_quant = graph_module.graph.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ( + bias, + bias_scale, + 0, + -(2**31), + (2**31) - 1, + torch.int32, + ), + ) + + # Create single element tensors for weight_zero_point, out_multiplier, out_shift. + # Note that the function expects int32_t, when it would default to int64_t, so + # we explicitly require that type. 
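For context on the tensors being built here: the float `requantize_scale` computed above is folded into an integer multiplier/shift pair so the DSP kernel can requantize without floating-point math. The exact convention of `quantize_tensor_multiplier` is not shown in this patch; the sketch below illustrates the common gemmlowp-style decomposition (consistent with the `frexp` import in this file) and is an assumption, not the actual helper:

```python
from math import frexp
from typing import Tuple


def decompose_requantize_scale(scale: float) -> Tuple[int, int]:
    # Write scale as m * 2**e with m in [0.5, 1), then scale m up to a
    # 31-bit integer multiplier. Assumed convention:
    #   scale ~= multiplier * 2 ** (shift - 31)
    m, e = frexp(scale)
    multiplier = round(m * (1 << 31))
    shift = e
    if multiplier == (1 << 31):  # rounding pushed m up to 1.0; renormalize
        multiplier >>= 1
        shift += 1
    return multiplier, shift


# Example: a small requantization scale becomes an int32 multiplier and a shift.
mult, sh = decompose_requantize_scale(0.0003)
assert abs(mult * 2.0 ** (sh - 31) - 0.0003) < 1e-9
```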
+ weight_zero_point_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], dequants_weights[0].args[2]), + {"dtype": torch.int32}, + ) + out_multiplier_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_multiplier[0].item()), + {"dtype": torch.int32}, + ) + out_shift_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_shift[0].item()), + {"dtype": torch.int32}, + ) + + args = tuple(inputs_inputs + weights_inputs + other_inputs + [bias_int32_quant]) + kwargs = { + "src_zero_point": dequants_inputs[0].args[2], + "weight_zero_point": weight_zero_point_, + "out_multiplier": out_multiplier_, + "out_shift": out_shift_, + "out_zero_point": quant_node.args[2], + "offset": None, + } + return args, kwargs + + +# Helper function to get the args and kwargs for the layer norm replacement op +def get_args_and_kwargs_layer_norm( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + other_inputs: List[fx.Node], + weights_init_inputs: List[fx.Node], + bias_inputs: List[fx.Node], + quant_node: fx.Node, +) -> Tuple[Tuple[Any], Dict[str, Any]]: + """ + Returns the args and kwargs for the layer norm replacement op. + """ + # Check if the input is per-channel quantized + # TODO(matthiascremon): add proper support and testing for per-channel quantization + assert isinstance(dequants_inputs[0].args[1], float) and isinstance( + dequants_inputs[0].args[2], int + ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars" + + # Make the scale and zero_point tensors + scale_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + [1], + dequants_inputs[0].args[1], + ), + ) + zero_point_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + [1], + dequants_inputs[0].args[2], + ), + ) + + # Make the args and kwargs for the replacement op + args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor]) + kwargs = { + "normalized_shape": other_inputs[0], + "weight": weights_init_inputs[0], + "bias": bias_inputs[0], + "eps": 1e-05, + "output_scale": quant_node.args[1], + "output_zero_point": quant_node.args[2], + } + return args, kwargs + + +def get_conv_args(arg, first_val: int) -> List[fx.Node]: + return arg if len(arg) == 2 else [first_val, arg[0]] + + +def get_args_and_kwargs_conv1d( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + other_inputs: List[fx.Node], + weights_inputs: List[fx.Node], + dequants_weights: List[fx.Node], + bias_inputs: List[fx.Node], + quant_node: fx.Node, + op_node: fx.Node, +): + weight_scale = get_weight_node(weights_inputs[0], dequants_weights[0]).args[1] + weight_zero_point = get_weight_node(weights_inputs[0], dequants_weights[0]).args[2] + # pyre-fixme[58]: Unsupported operand types + bias_scale = dequants_inputs[0].args[1] * weight_scale + stride = [1, 1] if len(op_node.args) < 4 else get_conv_args(op_node.args[3], 1) + padding = [0, 0] if len(op_node.args) < 5 else get_conv_args(op_node.args[4], 0) + dilation = [1, 1] if len(op_node.args) < 6 else get_conv_args(op_node.args[5], 1) + groups = 1 if len(op_node.args) < 7 else op_node.args[6] + # If bias is not available, create a bias tensor with the shape of weight[0] + if not bias_inputs: + weight_node = get_weight_node(weights_inputs[0], dequants_weights[0]).args[0] + # pyre-fixme[16]: Undefined attribute + attr_node = getattr(graph_module, weight_node.target) + weight_shape 
= list(attr_node.shape) + bias_shape = weight_shape[0] + bias = graph_module.graph.call_function( + torch.ops.aten.full.default, ([bias_shape], 0.0) + ) + else: + bias = bias_inputs[0] + # The bias is quantized to int32_t + bias_int32_quant = graph_module.graph.call_function( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ( + bias, + bias_scale, + 0, + -(2**31), + (2**31) - 1, + torch.int32, + ), + ) + + # Compute the out multiplier and out shift. They are used when the conv op is + # replaced by quantized linear, we compute them a priori for simplicity but + # may revisit the decision. + requantize_scale = bias_scale / quant_node.args[1] + requantize_scale_t = torch.tensor([requantize_scale]) + + (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) + + out_multiplier_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_multiplier[0].item()), + {"dtype": torch.int32}, + ) + out_shift_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_shift[0].item()), + {"dtype": torch.int32}, + ) + + # Create a single element tensor for the weight zero point + weight_zero_point_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], weight_zero_point), + {"dtype": torch.int32}, + ) + + # Create a single element tensor for the bias scale + bias_scale_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], bias_scale), + {"dtype": torch.float32}, + ) + + # Make the args and kwargs for the replacement op + args = tuple(inputs_inputs + weights_inputs + other_inputs + [bias_int32_quant]) + kwargs = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "groups": groups, + "input_zero_point": dequants_inputs[0].args[2], + "weight_zero_point": weight_zero_point_tensor, + "bias_scale": bias_scale_tensor, + "out_scale": quant_node.args[1], + "out_zero_point": quant_node.args[2], + "out_multiplier": out_multiplier_, + "out_shift": out_shift_, + "channel_last": False, + } + return args, kwargs + + +def get_args_and_kwargs_relu( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], +): + # Make the args and kwargs for the replacement op + args = tuple(inputs_inputs) + + X_zero_point = graph_module.graph.call_function( + torch.ops.aten.full.default, ([1], dequants_inputs[0].args[2]) + ) + + kwargs = { + "X_zero_point": X_zero_point, + } + return args, kwargs + + @dataclass class PartitionAnchors: """ @@ -132,7 +389,9 @@ class PartitionAnchors: biases: List[Tuple[fx.Node, int]] = field(default_factory=list) others: List[Tuple[fx.Node, int]] = field(default_factory=list) literals: List[Tuple[fx.Node, int]] = field(default_factory=list) - output: Optional[fx.Node] = None + output: List[Union[Tuple[fx.Node], Tuple[fx.Node, QuantizationSpec]]] = field( + default_factory=list + ) class QuantizationPattern(ABC): @@ -174,11 +433,147 @@ def get_anchors( inputs=[(linear_node, 0)], weights=[(linear_node, 1)], biases=bias, - output=linear_node, + output=[(linear_node,)], + ) + + def replacement_op(self): + return torch.ops.xtensa.quantized_linear.default + + +class LinearFunctionalPattern(QuantizationPattern): + def partition_types(self): + return [torch.nn.functional.linear] + + def get_anchors( + self, gm: GraphModule, fused_partition: List[GraphModule] + ) -> PartitionAnchors: + linear_node = fused_partition[0].nodes[-1] + + return PartitionAnchors( + inputs=[(linear_node, 0)], + weights=[(linear_node, 1)], + 
biases=[(linear_node, 2)], + output=[(linear_node,)], + ) + + def replacement_op(self): + return torch.ops.xtensa.quantized_linear.default + + +class LayerNormPattern(QuantizationPattern): + def partition_types(self): + return [torch.nn.LayerNorm] + + def get_anchors(self, gm, fused_partition) -> PartitionAnchors: + layer_norm_node = fused_partition[0].nodes[-1] + + return PartitionAnchors( + inputs=[(layer_norm_node, 0)], + weights=[(layer_norm_node, 2)], + biases=[(layer_norm_node, 3)], + others=[(layer_norm_node, 1)], + output=[(layer_norm_node,)], + ) + + def replacement_op(self): + return torch.ops.xtensa.quantized_layer_norm.default + + +class Conv1dPattern(QuantizationPattern): + def partition_types(self) -> List[Type[torch.nn.Module]]: + return [torch.nn.Conv1d] + + def get_anchors( + self, gm: GraphModule, fused_partition: List[GraphModule] + ) -> PartitionAnchors: + conv1d_node = fused_partition[0].nodes[-1] + + # If bias is None, replace it with an empty list. + bias = ( + [(conv1d_node, 2)] + if len(conv1d_node.args) > 2 and conv1d_node.args[2] + else [] + ) + + return PartitionAnchors( + inputs=[(conv1d_node, 0)], + weights=[(conv1d_node, 1)], + biases=bias, + output=[(conv1d_node,)], + ) + + def replacement_op(self): + return torch.ops.xtensa.quantized_conv.default + + +class Conv2dPattern(QuantizationPattern): + def partition_types(self) -> List[Type[torch.nn.Module]]: + return [torch.nn.Conv2d] + + def get_anchors( + self, gm: GraphModule, fused_partition: List[GraphModule] + ) -> PartitionAnchors: + conv2d_node = fused_partition[0].nodes[-1] + + # If bias is None, replace it with an empty list. + bias = ( + [(conv2d_node, 2)] + if len(conv2d_node.args) > 2 and conv2d_node.args[2] + else [] + ) + + return PartitionAnchors( + inputs=[(conv2d_node, 0)], + weights=[(conv2d_node, 1)], + biases=bias, + output=[(conv2d_node,)], + ) + + def replacement_op(self): + return torch.ops.xtensa.quantized_conv.default + + +class AddmmPattern(QuantizationPattern): + def partition_types(self) -> List[Type[torch.nn.Module]]: + return [torch.addmm] + + def get_anchors( + self, gm: GraphModule, fused_partition: List[GraphModule] + ) -> PartitionAnchors: + addmm_node = fused_partition[0].nodes[-1] + + return PartitionAnchors( + inputs=[(addmm_node, 1)], + weights=[(addmm_node, 2)], + biases=[(addmm_node, 0)], + output=[(addmm_node,)], ) def replacement_op(self): - return torch.ops.xtensa.quantized_linear_pt2.default + return torch.ops.xtensa.quantized_linear.default + + +class ReluPattern(QuantizationPattern): + def partition_types(self) -> List[Type[torch.nn.Module]]: + return [torch.nn.ReLU] + + def get_anchors( + self, gm: GraphModule, fused_partition: List[GraphModule] + ) -> PartitionAnchors: + relu_node = fused_partition[0].nodes[-1] + + return PartitionAnchors( + inputs=[(relu_node, 0)], + weights=[], + biases=[], + # pyre-fixme[6]: Incompatible parameter type + output=[ + (relu_node, SharedQuantizationSpec((relu_node.args[0], relu_node))) + ], + ) + + def replacement_op(self): + return torch.ops.xtensa.quantized_relu.default class GenericQuantizer(Quantizer): @@ -206,15 +601,21 @@ def annotate(self, model): if not anchors: continue if _is_annotated( - [x[0] for x in anchors.inputs + anchors.weights + anchors.biases] - + [anchors.output] + [ + x[0] + for x in anchors.inputs + + anchors.weights + + anchors.biases + + anchors.output + ] ): continue - anchors.output.meta["quantization_annotation"] = QuantizationAnnotation( - output_qspec=output_act_qspec, - _annotated=True, - ) + for 
output, *custom_spec in anchors.output: + output.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=custom_spec[0] if custom_spec else output_act_qspec, + _annotated=True, + ) def annotate_inputs(inputs, spec): for node, idx in inputs: @@ -256,7 +657,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]: ) -class XtensaQuantizer(ComposableQuantizer): +class XtensaBaseQuantizer(ComposableQuantizer): def __init__(self): static_qconfig = QuantizationConfig( act_qspec, @@ -264,9 +665,21 @@ def __init__(self): wgt_qspec, None, ) + static_qconfig_no_wgt = QuantizationConfig( + act_qspec, + act_qspec, + None, + None, + ) super().__init__( [ + GenericQuantizer(AddmmPattern(), static_qconfig), + GenericQuantizer(Conv1dPattern(), static_qconfig), + GenericQuantizer(Conv2dPattern(), static_qconfig), + GenericQuantizer(LayerNormPattern(), static_qconfig_no_wgt), + GenericQuantizer(LinearFunctionalPattern(), static_qconfig), GenericQuantizer(LinearPattern(), static_qconfig), + GenericQuantizer(ReluPattern(), static_qconfig), ] ) @@ -276,7 +689,7 @@ def __init__(self, patterns): super().__init__() self.patterns = patterns - def call(self, graph_module: fx.GraphModule) -> PassResult: + def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 for pattern in self.patterns: fused_partitions = find_sequential_partitions( graph_module, @@ -309,82 +722,81 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: inputs_inputs = [node.args[0] for node in dequants_inputs] weights_inputs = [node.args[0] for node in dequants_weights] + weights_init_inputs = [node.args[idx] for node, idx in anchors.weights] bias_inputs = [node.args[idx] for node, idx in anchors.biases] other_inputs = [node.args[idx] for node, idx in anchors.others] - assert len(anchors.output.users) == 1 - quant_node = list(anchors.output.users.keys())[0] + # The node is the first index of the list and first of the tuple + op_node = anchors.output[0][0] + + assert len(op_node.users) == 1 + quant_node = list(op_node.users.keys())[0] - with graph_module.graph.inserting_after(anchors.output): + with graph_module.graph.inserting_after(op_node): args = tuple( inputs_inputs + weights_inputs + other_inputs + bias_inputs ) kwargs = {} - if ( - pattern.replacement_op() - == torch.ops.xtensa.quantized_linear_pt2.default + if isinstance(pattern, Conv1dPattern) or isinstance( + pattern, Conv2dPattern ): - weight_scale = ( - weights_inputs[0].args[1] - if weights_inputs[0].name[:13] != "_frozen_param" - else dequants_weights[0].args[1] - ) - bias_scale = inputs_inputs[0].args[1] * weight_scale - requantize_scale = bias_scale / quant_node.args[1] - requantize_scale_t = torch.tensor([requantize_scale]) - - (out_multiplier, out_shift) = quantize_tensor_multiplier( - requantize_scale_t + args, kwargs = get_args_and_kwargs_conv1d( + graph_module, + inputs_inputs, + dequants_inputs, + other_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + quant_node, + op_node, ) - bias_shape = weights_inputs - node = ( - weights_inputs[0].args[0] - if weights_inputs[0].name[:13] != "_frozen_param" - else dequants_weights[0].args[0] - ) - attr_node = getattr(graph_module, node.target) - weight_shape = list(attr_node.shape) - bias_shape = weight_shape[0] - bias = ( - bias_inputs[0] - if bias_inputs - else graph_module.graph.call_function( - torch.ops.aten.full.default, ([bias_shape], 0.0) - ) + elif isinstance(pattern, LinearPattern) or isinstance( + pattern, LinearFunctionalPattern + ): + args, kwargs = 
get_args_and_kwargs_linear( + graph_module, + inputs_inputs, + dequants_inputs, + other_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + quant_node, ) - bias_int32_quant = graph_module.graph.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ( - bias, - bias_scale, - 0, - -(2**31), - (2**31) - 1, - torch.int32, - ), + elif isinstance(pattern, LayerNormPattern): + args, kwargs = get_args_and_kwargs_layer_norm( + graph_module, + inputs_inputs, + dequants_inputs, + other_inputs, + weights_init_inputs, + bias_inputs, + quant_node, ) - - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, ([1], out_multiplier[0].item()) + elif isinstance(pattern, AddmmPattern): + # Transpose the weight tensor + transposed_weights = graph_module.graph.call_function( + torch.ops.aten.transpose.int, + (weights_inputs[0], 0, 1), ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, ([1], out_shift[0].item()) + # Call linear with transposed weight + args, kwargs = get_args_and_kwargs_linear( + graph_module, + inputs_inputs, + dequants_inputs, + other_inputs, + [transposed_weights], + dequants_weights, + bias_inputs, + quant_node, ) - args = tuple( - inputs_inputs - + weights_inputs - + other_inputs - + [bias_int32_quant] + elif isinstance(pattern, ReluPattern): + args, kwargs = get_args_and_kwargs_relu( + graph_module, + inputs_inputs, + dequants_inputs, ) - kwargs = { - "src_scale": dequants_inputs[0].args[1], - "src_zero_point": dequants_inputs[0].args[2], - "weight_scale": dequants_weights[0].args[1], - "weight_zero_point": dequants_weights[0].args[2], - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, - "out_zero_point": quant_node.args[2], - } fused = graph_module.graph.call_function( pattern.replacement_op(), args, diff --git a/examples/xtensa/aot/utils.py b/examples/xtensa/aot/utils.py new file mode 100644 index 00000000000..73f863eed9f --- /dev/null +++ b/examples/xtensa/aot/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
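The `AddmmPattern` branch above transposes the weight because `addmm` is just a linear layer with the weight stored as `[in_dim, out_dim]` instead of `[out_dim, in_dim]`. A quick eager-mode check of that equivalence, using only public PyTorch APIs:

```python
import torch

x = torch.randn(4, 8)    # activations
w = torch.randn(8, 16)   # addmm weight, laid out [in_dim, out_dim]
b = torch.randn(16)

out_addmm = torch.addmm(b, x, w)
# F.linear expects the weight as [out_dim, in_dim], hence the transpose.
out_linear = torch.nn.functional.linear(x, w.t(), b)

assert torch.allclose(out_addmm, out_linear, atol=1e-6)
```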
+ +import torch + + +# Get the output size of a 1D convolution given the input size and parameters +def get_conv1d_output_size( + in_size: torch.Size, + out_channels: int, + stride: int, + padding: int, + dilation: int, + kernel_size: int, +) -> torch.Size: + assert len(in_size) == 3 + N, C, L = in_size + + # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1 + + return torch.Size((in_size[0], out_channels, lout)) diff --git a/examples/xtensa/kernels/kernels.cpp b/examples/xtensa/kernels/kernels.cpp index 5fcc545a540..2f29b25ac82 100644 --- a/examples/xtensa/kernels/kernels.cpp +++ b/examples/xtensa/kernels/kernels.cpp @@ -16,6 +16,11 @@ namespace impl { namespace HiFi { namespace kernels { +__attribute__((always_inline)) void +memcpy(void* dst, const void* src, size_t num_bytes) { + MEMCPY_8b(dst, src, num_bytes); +} + // Quantize a fp32 value to an int8_t/uint8_t value template __attribute__((always_inline)) T diff --git a/examples/xtensa/kernels/kernels.h b/examples/xtensa/kernels/kernels.h index 6a5a255c0ad..13e0470b382 100644 --- a/examples/xtensa/kernels/kernels.h +++ b/examples/xtensa/kernels/kernels.h @@ -16,6 +16,8 @@ namespace impl { namespace HiFi { namespace kernels { +void memcpy(void* dst, const void* src, size_t num_bytes); + WORD32 matmul_asym8uxasym8u_asym8u( UWORD8* __restrict__ p_out, // output uint8 matrix const UWORD8* __restrict__ p_mat1, // weight uint8 matrix @@ -35,6 +37,12 @@ WORD32 matmul_asym8uxasym8u_asym8u( WORD32 out_zero_bias, bool per_channel_quantized = false); // per-channel quantized weight +template +T quantize(const float x, float scale, int32_t zero_point); + +template +float dequantize(const T x, float scale, int32_t zero_point); + template void quantize( T* __restrict__ y, diff --git a/examples/xtensa/ops/CMakeLists.txt b/examples/xtensa/ops/CMakeLists.txt index 215de49f20c..abcac83d283 100644 --- a/examples/xtensa/ops/CMakeLists.txt +++ b/examples/xtensa/ops/CMakeLists.txt @@ -26,11 +26,14 @@ endif() # ATen compliant ops that are needed to run this model. set(_aten_ops__srcs "${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_embedding.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/op_full.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_view_copy.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp") add_library(aten_ops_xtensa ${_aten_ops__srcs}) target_link_libraries(aten_ops_xtensa PUBLIC executorch) +target_link_libraries(aten_ops_xtensa PRIVATE xtensa_kernels) # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) @@ -41,8 +44,9 @@ target_include_directories(aten_ops_xtensa PUBLIC ${ROOT_DIR}/.. # Custom ops that are needed to run the test model. add_library( - custom_ops "quantized_linear_out.cpp" "quantize_per_tensor.cpp" - "dequantize_per_tensor.cpp") + custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp" + "quantized_relu_out.cpp" "quantized_layer_norm.cpp" + "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp") target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} ${_common_include_directories}) diff --git a/examples/xtensa/ops/functions.yaml b/examples/xtensa/ops/functions.yaml index 07093d3ed24..91ffd18063a 100644 --- a/examples/xtensa/ops/functions.yaml +++ b/examples/xtensa/ops/functions.yaml @@ -11,20 +11,33 @@ # by this file. 
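The `get_conv1d_output_size` helper above applies the standard Conv1d length formula. As a quick check, here it is evaluated with the same parameters as the new conv1d test (L=33, kernel=3, stride=2, padding=4, dilation=3) and compared against eager PyTorch:

```python
import torch

L, k, s, p, d = 33, 3, 2, 4, 3
lout = (L + 2 * p - d * (k - 1) - 1) // s + 1  # -> 18

conv = torch.nn.Conv1d(8, 16, k, stride=s, padding=p, dilation=d)
assert conv(torch.randn(1, 8, L)).shape == (1, 16, lout)
```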
+# aten ops - op: add.out kernels: - arg_meta: null kernel_name: torch::executor::add_out +- op: embedding.out + kernels: + - arg_meta: null + kernel_name: torch::executor::embedding_out + - op: full.out kernels: - arg_meta: null kernel_name: torch::executor::full_out -- func: xtensa::quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) +- op: view_copy.out kernels: - arg_meta: null - kernel_name: impl::HiFi::quantized_linear_pt2_out + kernel_name: torch::executor::view_copy_out + +# custom ops +- func: xtensa::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_out - func: xtensa::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function @@ -32,8 +45,17 @@ - arg_meta: null kernel_name: impl::HiFi::dequantize_per_tensor_out -- func: xtensa::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) - variants: function +- func: xtensa::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null - kernel_name: impl::HiFi::quantize_per_tensor_out + kernel_name: impl::HiFi::quantized_conv_out + +- func: xtensa::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_linear_out + +- func: xtensa::quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantized_relu_out diff --git a/examples/xtensa/ops/op_embedding.cpp b/examples/xtensa/ops/op_embedding.cpp new file mode 100644 index 00000000000..b4100feacc1 --- /dev/null +++ b/examples/xtensa/ops/op_embedding.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include "kernels.h" + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + +void embedding_out( + RuntimeContext& ctx, + const Tensor& weight, + const Tensor& indices, + int64_t padding_idx, + bool scale_grad_by_freq, + bool sparse, + Tensor& out) { + int64_t nbytes_per_entry = weight.size(1) * weight.element_size(); + const char* w_data = weight.const_data_ptr(); + char* out_data = out.mutable_data_ptr(); + const int64_t* indices_ptr = indices.const_data_ptr(); + + for (int i = 0, e = indices.numel(); i < e; i++) { + // memcpy(dest, src, nbytes); + impl::HiFi::kernels::memcpy( + out_data, w_data + nbytes_per_entry * indices_ptr[i], nbytes_per_entry); + out_data += nbytes_per_entry; + } +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/examples/xtensa/ops/op_view_copy.cpp b/examples/xtensa/ops/op_view_copy.cpp new file mode 100644 index 00000000000..e856c1592cb --- /dev/null +++ b/examples/xtensa/ops/op_view_copy.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "kernels.h" + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + +Tensor& view_copy_out( + RuntimeContext& ctx, + const Tensor& input, + const IntArrayRef size, + Tensor& out) { + impl::HiFi::kernels::memcpy( + out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes()); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/examples/xtensa/ops/quantized_conv_out.cpp b/examples/xtensa/ops/quantized_conv_out.cpp new file mode 100644 index 00000000000..23e189e6bcb --- /dev/null +++ b/examples/xtensa/ops/quantized_conv_out.cpp @@ -0,0 +1,227 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "kernels.h" + +#include +#include +#include + +namespace impl { +namespace HiFi { +namespace native { + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + +// This implements a generic 2d conv kernel that operates on raw pointers. +// The version handles both quantized and fp32 convolutions. 
+// The input is of shape [n x c x h x w] +// The weight is of shape [oc x wc x wh x ww], where wc == c +// The output is of shape [n x oc x oh x ow] +// The bias is of shape [oc] +template +__attribute__((noinline)) void conv2d_nchw_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + const int32_t* __restrict__ weight_zero_point = nullptr, + const float* __restrict__ bias_scale = nullptr, + float out_scale = 1, + OT out_zero_point = 0, + bool per_tensor_quantized = true) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * h * w; + OT* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * oh * ow; + const WT* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? 
weight_zero_point[0] : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1 < w))) { + int ioff = + (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point[0] : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = + (per_tensor_quantized ? bias_scale[0] : bias_scale[_oc]) * + acc; + out_plane[_oh * ow + _ow] = + kernels::quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + } +} + +// The quantized convolution kernel. in_scale and weight_scale are implicit in +// bias_scale, since it is a product of the two. The kernel will branch to +// quantized::conv1d or quantized::conv2d based on the dimensionality of +// activation tensor. +void quantized_conv_out( + RuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + exec_aten::IntArrayRef stride, + exec_aten::IntArrayRef padding, + exec_aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + bool channel_last, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? out.size(2) : out.size(3); + + // Bool flag to check if weight tensor is quantized per-tensor or + // per-channel + bool per_tensor_quantized = bias_scale.numel() == 1; + + conv2d_nchw_core_generic( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + in_zero_point, + weight_zero_point.const_data_ptr(), + bias_scale.const_data_ptr(), + output_scale, + (uint8_t)output_zero_point, + per_tensor_quantized); +} + +}; // namespace native +}; // namespace HiFi +}; // namespace impl diff --git a/examples/xtensa/ops/quantized_layer_norm.cpp b/examples/xtensa/ops/quantized_layer_norm.cpp new file mode 100644 index 00000000000..27d86e56227 --- /dev/null +++ b/examples/xtensa/ops/quantized_layer_norm.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include "kernels.h" + +#include +#include +#include + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + +namespace impl { +namespace HiFi { +namespace native { + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_( + const Tensor& input, + float input_scale, + int64_t input_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Get the raw pointers to input, output, weight, and bias + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + const float* __restrict__ weight_data = weight.const_data_ptr(); + const float* __restrict__ bias_data = bias.const_data_ptr(); + + float output_inv_scale = XT_RECIP_S(output_scale); + + size_t last_dim = input.size(input.dim() - 1); + size_t leading_dims = getLeadingDims(input, input.dim() - 1); + + // Visualize the input tensor as a set of 1d vectors, and compute the + // layer_norm for each vector. + for (size_t i = 0; i < leading_dims; ++i) { + const T* __restrict__ x = in_data + i * last_dim; + T* __restrict__ y = out_data + i * last_dim; + + // compute sum and squared sum. The fp32 sum can be approximated as: + // (X_1 - in_zero_point) * in_scale + (X_2 - in_zero_point) * in_scale + ... + // (X_N - in_zero_point) * in_scale. + int32_t sum = 0; + int32_t sq_sum = last_dim * input_zero_point * input_zero_point; +#pragma simd + for (size_t j = 0; j < last_dim; ++j) { + int32_t val = x[j]; + sum += val; + sq_sum += val * val; + } + sq_sum -= (2 * sum * input_zero_point); + sum -= (last_dim * input_zero_point); + + float mean = XT_DIV_S(XT_MUL_S(input_scale, sum), last_dim); + float variance = + XT_DIV_S( + XT_MUL_S(sq_sum, XT_MUL_S(input_scale, input_scale)), last_dim) - + XT_MUL_S(mean, mean); + float inv_std = XT_RECIP_S(XT_SQRT_S(XT_ADD_S(variance, (float)eps))); + + // y = (x - mean) / std * kGamma + kBeta +#pragma simd + for (size_t j = 0; j < last_dim; ++j) { + // Since X is quantized, we dequantize it, compute fp32 result, and + // quantize the result to an int8/uint8 value. + float val = kernels::dequantize(x[j], input_scale, input_zero_point); + val = (val - mean) * inv_std * weight_data[j] + bias_data[j]; + y[j] = kernels::quantize(val, output_inv_scale, output_zero_point); + } + } +} + +// Compute quantized layer_norm. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_layer_norm_( + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + // Extract the zero point and scale for input tensor. 
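The per-tensor overload here feeds into the integer-domain statistics trick used earlier in this file: for `x_fp = (x_q - zp) * s`, the sums of `x_q` and `x_q**2` are enough to recover the floating-point mean and population variance. A small sketch of that identity with arbitrary example values:

```python
import torch

s, zp = 0.05, 3                                     # per-tensor scale / zero point
x_q = torch.randint(0, 255, (1, 64), dtype=torch.int32)
x_fp = (x_q - zp).float() * s                       # dequantized reference

n = x_q.numel()
sum_q = int(x_q.sum()) - n * zp                     # sum of (x_q - zp)
sq_sum_q = int((x_q * x_q).sum()) - 2 * zp * int(x_q.sum()) + n * zp * zp

mean = s * sum_q / n
var = (s * s) * sq_sum_q / n - mean * mean          # population variance

assert abs(mean - x_fp.mean().item()) < 1e-4
assert abs(var - ((x_fp - x_fp.mean()) ** 2).mean().item()) < 1e-4
```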
+ float input_scale = in_scale.const_data_ptr()[0]; + int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; + + // Call other overload + quantized_layer_norm_( + input, + input_scale, + input_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); +} + +void quantized_layer_norm_out( + RuntimeContext& ctx, + const Tensor& input, + const Tensor& in_scale, + const Tensor& in_zero_point, + const exec_aten::IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + if (input.scalar_type() == exec_aten::ScalarType::Byte) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else if (input.scalar_type() == exec_aten::ScalarType::Char) { + quantized_layer_norm_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else { + ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); + } +} + +}; // namespace native +}; // namespace HiFi +}; // namespace impl diff --git a/examples/xtensa/ops/quantized_linear_out.cpp b/examples/xtensa/ops/quantized_linear_out.cpp index 2acf36e33d8..7aa488f282c 100644 --- a/examples/xtensa/ops/quantized_linear_out.cpp +++ b/examples/xtensa/ops/quantized_linear_out.cpp @@ -19,18 +19,7 @@ namespace native { using Tensor = exec_aten::Tensor; using RuntimeContext = torch::executor::RuntimeContext; -namespace linear_util { -// This function compute the product of dim[0:dim] where dim is not inclusive -size_t getLeadingDims(const Tensor& tensor, int64_t dim) { - size_t dims = 1; - for (size_t i = 0; i < dim; ++i) { - dims *= tensor.size(i); - } - return dims; -} -} // namespace linear_util - -void quantized_linear_pt2_out( +void quantized_linear_out( RuntimeContext& ctx, const Tensor& src, const Tensor& weight, @@ -47,7 +36,7 @@ void quantized_linear_pt2_out( // weight comes in shape [out_dim, in_dim] // output comes in empty with shape [leading_dims, out_dim] // Perform matrix multiply (M x N) x (N x P)' => M x P - int64_t leading_dims = linear_util::getLeadingDims(src, src.dim() - 1); + int64_t leading_dims = getLeadingDims(src, src.dim() - 1); int64_t out_dim = weight.size(0); // = out_dim int64_t in_dim = weight.size(1); // = in_dim diff --git a/examples/xtensa/ops/quantized_relu_out.cpp b/examples/xtensa/ops/quantized_relu_out.cpp new file mode 100644 index 00000000000..1643747baec --- /dev/null +++ b/examples/xtensa/ops/quantized_relu_out.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include "kernels.h" + +namespace impl { +namespace HiFi { +namespace native { + +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + +// Note: this kernel assumes that the input and output share quantization +// parameters. If that is not the case, it will produce incorrect results. +template +void quantized_relu_( + const Tensor& input, + const Tensor& in_zero_point, + Tensor& output) { + T q_zero_point = in_zero_point.const_data_ptr()[0]; + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + out[i] = in[i] > q_zero_point ? 
in[i] : q_zero_point; + } +} + +void quantized_relu_out( + RuntimeContext& ctx, + const Tensor& input, + const Tensor& in_zero_point, + Tensor& output) { + if (input.scalar_type() == exec_aten::ScalarType::Byte) { + quantized_relu_(input, in_zero_point, output); + } else if (input.scalar_type() == exec_aten::ScalarType::Char) { + quantized_relu_(input, in_zero_point, output); + } else { + ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); + } +} + +}; // namespace native +}; // namespace HiFi +}; // namespace impl diff --git a/examples/xtensa/tests/quantized_conv1d_example.py b/examples/xtensa/tests/quantized_conv1d_example.py new file mode 100644 index 00000000000..aa29c85c166 --- /dev/null +++ b/examples/xtensa/tests/quantized_conv1d_example.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from ..aot.meta_registrations import * # noqa + +import torch + +from ..aot.export_example import export_xtensa_model + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +if __name__ == "__main__": + ( + shape, + in_channels, + out_channels, + kernel, + stride, + padding, + dilation, + depthwise, + bias, + channel_last, + ) = [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, True, False] + + class QuantizedConv(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = torch.nn.Conv1d( + in_channels, + out_channels, + kernel, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels if depthwise else 1, + bias=bias, + ) + + def forward(self, x: torch.Tensor): + return self.conv1d(x) + + model = QuantizedConv() + model.eval() + + example_inputs = (torch.randn(shape),) + + export_xtensa_model(model, example_inputs) diff --git a/examples/xtensa/tests/quantized_linear_example.py b/examples/xtensa/tests/quantized_linear_example.py new file mode 100644 index 00000000000..410a825ee92 --- /dev/null +++ b/examples/xtensa/tests/quantized_linear_example.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
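The `quantized_relu_` kernel above works directly in the quantized domain because `ReluPattern` gives the output a `SharedQuantizationSpec` tied to its input: with identical scale and zero point, requantizing `relu(dequantize(x))` collapses to `max(x, zero_point)`. A short sketch of that identity with arbitrary uint8-range values:

```python
import torch

scale, zp = 0.1, 128
x_q = torch.randint(0, 256, (1000,), dtype=torch.int32)

# Reference path: dequantize, relu in fp32, requantize with the same qparams.
x_fp = (x_q - zp).float() * scale
ref = torch.clamp(torch.round(torch.relu(x_fp) / scale) + zp, 0, 255).to(torch.int32)

# Kernel's shortcut: clamp at the zero point without leaving the quantized domain.
fast = torch.maximum(x_q, torch.tensor(zp, dtype=torch.int32))

assert torch.equal(ref, fast)
```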
+ +# Example script for exporting simple models to flatbuffer + +import logging + +from ..aot.meta_registrations import * # noqa + +import torch + +from ..aot.export_example import export_xtensa_model + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +if __name__ == "__main__": + in_features = 32 + out_features = 16 + bias = True + shape = [64, in_features] + + class QuantizedLinear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool): + super().__init__() + self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor): + output_linear_out = self.output_linear(x) + return output_linear_out + + model = QuantizedLinear(in_features, out_features, bias) + model.eval() + + example_inputs = (torch.ones(shape),) + + export_xtensa_model(model, example_inputs) diff --git a/examples/xtensa/tests/rnnt_predictor_quantized_example.py b/examples/xtensa/tests/rnnt_predictor_quantized_example.py new file mode 100644 index 00000000000..cfc1531cb7c --- /dev/null +++ b/examples/xtensa/tests/rnnt_predictor_quantized_example.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +import torch + +from ..aot.meta_registrations import * # noqa + +from typing import Tuple + +from ..aot.export_example import export_xtensa_model + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +if __name__ == "__main__": + + class Predictor(torch.nn.Module): + def __init__( + self, + num_symbols: int, + symbol_embedding_dim: int, + ) -> None: + super().__init__() + self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim) + self.relu = torch.nn.ReLU() + self.linear = torch.nn.Linear(symbol_embedding_dim, symbol_embedding_dim) + self.layer_norm = torch.nn.LayerNorm(symbol_embedding_dim) + + def forward( + self, + input: torch.Tensor, + lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input_tb = input.permute(1, 0) + embedding_out = self.embedding(input_tb) + relu_out = self.relu(embedding_out) + linear_out = self.linear(relu_out) + layer_norm_out = self.layer_norm(linear_out) + return layer_norm_out.permute(1, 0, 2), lengths + + # Predictor + model = Predictor(128, 256) + model.eval() + + # Batch size + batch_size = 1 + + num_symbols = 128 + max_target_length = 10 + + # Dummy inputs + predictor_input = torch.randint(0, num_symbols, (batch_size, max_target_length)) + predictor_lengths = torch.randint(1, max_target_length + 1, (batch_size,)) + + example_inputs = ( + predictor_input, + predictor_lengths, + ) + + export_xtensa_model(model, example_inputs)
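Before exporting, the predictor above can be sanity-checked eagerly to confirm the shapes its submodules produce; this sketch assumes the `Predictor` class and the dummy-input shapes defined in the script above:

```python
import torch

num_symbols, symbol_embedding_dim = 128, 256
batch_size, max_target_length = 1, 10

predictor = Predictor(num_symbols, symbol_embedding_dim)  # class from the script above
predictor.eval()

tokens = torch.randint(0, num_symbols, (batch_size, max_target_length))
lengths = torch.randint(1, max_target_length + 1, (batch_size,))

with torch.no_grad():
    out, out_lengths = predictor(tokens, lengths)

# [batch, time, embedding_dim] after the final permute.
assert out.shape == (batch_size, max_target_length, symbol_embedding_dim)
assert torch.equal(out_lengths, lengths)
```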