diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index d93d470194e..057d3ea93d2 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -20,6 +20,8 @@
 from torch.fx import Node
 from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
@@ -213,6 +215,24 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
             _annotated=True,
         )
 
+    def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
+        act_node = node.args[0]
+        weight_node = node.args[2]
+
+        # TODO: currently only supports 16a16w
+        annotate_input_qspec_map(
+            node,
+            act_node,
+            quantization_config.input_activation,
+        )
+
+        annotate_input_qspec_map(
+            node,
+            weight_node,
+            quantization_config.input_activation,
+        )
+        annotate_output_qspec(node, quantization_config.output_activation)
+
     def annotate_single_in_single_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
@@ -287,6 +307,9 @@ def annotate_matmul_input1(node: Node):
         elif node.target == torch.ops.aten.flatten.using_ints:
             annotate_single_in_share_out(node, quantization_config_8a8w)
             node = node.args[0]
+        elif node.target == torch.ops.aten.rms_norm.default:
+            annotate_rms_norm(node, quantization_config_8a8w)
+            node = node.args[0]
         elif node.target == torch.ops.aten.cat.default:
             annotate_cat(node, quantization_config_8a8w)
             # For v, we tag 8a until conv op.
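
For context on how the new branch gets exercised: `annotate_matmul_input1` walks backwards from a matmul's second input, so an `aten.rms_norm.default` node on that path now receives the 8a8w annotation via the added `annotate_rms_norm`. Below is a minimal, hypothetical sketch of that flow; `ToyAttention`, the tensor shapes, and the direct call to `annotate_matmul_16a8w` on an exported module are illustrative assumptions (in the usual export flow the callable is instead registered with `QnnQuantizer`), and it assumes a PyTorch build where `torch.export` preserves `aten.rms_norm.default` in the captured graph.

```python
# Hypothetical sketch, not part of the patch: a tiny graph whose matmul
# input1 chain contains aten.rms_norm.default, so the new elif branch in
# annotate_matmul_input1 is reached. ToyAttention and the shapes are made up.
import torch

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)


class ToyAttention(torch.nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        # RMSNorm with elementwise affine gives rms_norm a weight argument,
        # which annotate_rms_norm reads from node.args[2].
        self.norm = torch.nn.RMSNorm(dim)

    def forward(self, q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # The matmul's second argument is the rms_norm output, so the
        # backwards walk in annotate_matmul_input1 hits the new branch.
        return torch.matmul(q, self.norm(v))


q = torch.randn(1, 8, 32)
v = torch.randn(1, 32, 64)
gm = torch.export.export(ToyAttention(), (q, v)).module()

# Apply the custom annotation pass directly to the captured GraphModule.
# In the usual export flow this callable is instead handed to QnnQuantizer
# via add_custom_quant_annotations((annotate_matmul_16a8w,)).
annotate_matmul_16a8w(gm)
```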