@@ -217,6 +217,12 @@ def getNodeArgs(node):
217217 return [tosa_mapping .TosaArg (arg ) for arg in node .args ]
218218
219219
def getQuantNodeArgs(node):
    """Return the ``(scale, zero_point)`` pair carried by ``node``.

    Every entry of ``node.args`` is wrapped in ``tosa_mapping.TosaArg``;
    the ``number`` attributes of the wrapped arguments at positions 1 and 2
    are then returned as a 2-tuple.

    NOTE(review): presumably ``node`` is a quantize/dequantize node whose
    args are laid out as ``(input, scale, zero_point, ...)`` -- confirm
    against the callers.
    """
    wrapped_args = list(map(tosa_mapping.TosaArg, node.args))
    scale_arg = wrapped_args[1]
    zp_arg = wrapped_args[2]
    return scale_arg.number, zp_arg.number
224+
225+
220226@final
221227class ArmBackend (BackendDetails ):
222228 @staticmethod
@@ -253,6 +259,7 @@ def preprocess( # noqa: C901
253259 outp = tosa_mapping .TosaArg (node )
254260
255261 is_quant_node = tosa_quant_utils .isQuantNode (node )
262+
256263 if is_quant_node :
257264 tosa_fb .currRegion .currBasicBlock .addTensor (
258265 outp .name , outp .shape , ts .DType .INT8
@@ -345,13 +352,17 @@ def preprocess( # noqa: C901
345352 elif exir_ops .edge .aten .addmm .default == node .target :
346353 bias , input , weight = inputs
347354
355+ output_dtype = ts .DType .INT8 if is_quant_node else outp .dtype
356+
348357 # Reshape input, weight, bias tensors
349358 input_reshape_res = promote_shape (
350- tosa_fb , input , (1 ,) + input .shape , outp . dtype
359+ tosa_fb , input , (1 ,) + input .shape , output_dtype
351360 )
352361 weight_reshape_res = promote_shape (
353- tosa_fb , weight , (1 ,) + weight .shape , outp . dtype
362+ tosa_fb , weight , (1 ,) + weight .shape , output_dtype
354363 )
364+
365+ bias_dtype = ts .DType .INT32 if is_quant_node else outp .dtype
355366 bias_reshape_res = promote_shape (
356367 tosa_fb ,
357368 bias ,
@@ -360,36 +371,87 @@ def preprocess( # noqa: C901
360371 1 ,
361372 )
362373 + bias .shape ,
363- outp . dtype ,
374+ bias_dtype ,
364375 )
365376
366377 # Add dummy batch 1 to mm_shape
367378 mm_shape = (1 , input .shape [0 ], weight .shape [1 ])
368379 # Define Intermediate tensor for MatMul res
369- mm_res = tosa_fb .addIntermediate (mm_shape , outp .dtype )
380+ mm_res = tosa_fb .addIntermediate (
381+ mm_shape , ts .DType .INT32 if is_quant_node else output_dtype
382+ )
370383
371384 # Add MatMulOp
385+ attr_matmul = ts .TosaSerializerAttribute ()
386+ a_zp , b_zp = (- 128 , 0 ) if is_quant_node else (0 , 0 )
387+ attr_matmul .MatMulAttribute (a_zp , b_zp )
372388 tosa_fb .addOperator (
373389 TosaOp .Op ().MATMUL ,
374390 [input_reshape_res .name , weight_reshape_res .name ],
375391 [mm_res .name ],
376- attr_torch_to_tosa ( TosaOp . Op (). MATMUL , node ) ,
392+ attr_matmul ,
377393 )
378394
379395 # Add AddOp
380- add_res = tosa_fb .addIntermediate (mm_shape , outp .dtype )
396+ add_res = tosa_fb .addIntermediate (
397+ mm_shape , ts .DType .INT32 if is_quant_node else output_dtype
398+ )
399+
381400 tosa_fb .addOperator (
382401 TosaOp .Op ().ADD ,
383402 [bias_reshape_res .name , mm_res .name ],
384403 [add_res .name ],
385404 None ,
386405 )
387406
407+ if is_quant_node :
408+ # Read inputs' parent nodes
409+ #
410+ _ , input_node , weight_node = node .all_input_nodes
411+ input_scale , _ = getQuantNodeArgs (input_node )
412+ weight_node_q_node = weight_node .all_input_nodes [0 ]
413+ weight_scale , _ = getQuantNodeArgs (weight_node_q_node )
414+
415+ consumer_node = list (node .users )[0 ]
416+ consumer_node_scale , consumer_node_node_zp = getQuantNodeArgs (
417+ consumer_node
418+ )
419+
420+ output_rescale_scale = (
421+ input_scale * weight_scale
422+ ) / consumer_node_scale
423+ (
424+ multiplier_output ,
425+ shift_output ,
426+ ) = tosa_quant_utils .computeMultiplierAndShift (
427+ output_rescale_scale
428+ )
429+
430+ attr_rescale_output = ts .TosaSerializerAttribute ()
431+ attr_rescale_output .RescaleAttribute (
432+ input_zp = 0 ,
433+ output_zp = consumer_node_node_zp ,
434+ multiplier = [multiplier_output ],
435+ shift = [shift_output ],
436+ scale32 = True ,
437+ double_round = True ,
438+ per_channel = False ,
439+ )
440+ add_res_int8 = tosa_fb .addIntermediate (mm_shape , ts .DType .INT8 )
441+ tosa_fb .addOperator (
442+ TosaOp .Op ().RESCALE ,
443+ [add_res .name ],
444+ [add_res_int8 .name ],
445+ attr_rescale_output ,
446+ )
388447 # Reshape final result to original shape
389448 attr_out = ts .TosaSerializerAttribute ()
390449 attr_out .ReshapeAttribute (outp .shape )
391450 tosa_fb .addOperator (
392- TosaOp .Op ().RESHAPE , [add_res .name ], [outp .name ], attr_out
451+ TosaOp .Op ().RESHAPE ,
452+ [add_res_int8 .name if is_quant_node else add_res .name ],
453+ [outp .name ],
454+ attr_out ,
393455 )
394456 elif exir_ops .edge .aten .permute_copy .default == node .target :
395457 attr = ts .TosaSerializerAttribute ()
@@ -700,20 +762,11 @@ def preprocess( # noqa: C901
700762 [outp .name ],
701763 attr_mul ,
702764 )
703- elif operator .getitem == node .target :
704- item_name = inputs [0 ].name
705- ## Simply add an identityOp
706- tosa_fb .addOperator (TosaOp .Op ().IDENTITY , [item_name ], [outp .name ])
707- elif (
708- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
709- == node .target
710- ):
711- item_name = inputs [0 ].name
712- tosa_fb .addOperator (TosaOp .Op ().IDENTITY , [item_name ], [outp .name ])
713- elif (
714- exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
715- == node .target
716- ):
765+ elif node .target in [
766+ operator .getitem ,
767+ tosa_quant_utils .q_op ,
768+ tosa_quant_utils .dq_op ,
769+ ]:
717770 item_name = inputs [0 ].name
718771 ## Simply add an identityOp
719772 tosa_fb .addOperator (TosaOp .Op ().IDENTITY , [item_name ], [outp .name ])
@@ -740,9 +793,54 @@ def preprocess( # noqa: C901
740793
741794 assert isinstance (p_data , torch .Tensor ), "Expect Attr to be tensor"
742795 weight_values = p_data .detach ().numpy ()
743- tosa_fb .addConst (
744- inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
745- )
796+
797+ # Check if they're for quantized nodes
798+ consumer_node = list (node .users )[0 ]
799+ if consumer_node .target in tosa_quant_utils .dq_q_ops :
800+ _ , weight_node_scale , weight_node_zp , _ , _ , _ = getNodeArgs (
801+ consumer_node
802+ )
803+
804+ weight_values_quantized = (
805+ (weight_values / weight_node_scale .number )
806+ + weight_node_zp .number
807+ ).astype (np .int8 )
808+ tosa_fb .addConst (
809+ inputs [0 ].shape ,
810+ ts .DType .INT8 ,
811+ weight_values_quantized ,
812+ name = out ,
813+ )
814+ elif (
815+ consumer_node .target == exir_ops .edge .aten .addmm .default
816+ and list (consumer_node .users )[0 ].target == tosa_quant_utils .q_op
817+ ):
818+ (
819+ _ ,
820+ input_node ,
821+ weight_node_permuted ,
822+ ) = consumer_node .all_input_nodes
823+ weight_node = weight_node_permuted .all_input_nodes [0 ]
824+
825+ input_node_scale , _ = getQuantNodeArgs (input_node )
826+ weight_node_scale , weight_node_zp = getQuantNodeArgs (
827+ weight_node
828+ )
829+
830+ weight_values_quantized = (
831+ weight_values / (input_node_scale * weight_node_scale )
832+ ).astype (np .int32 )
833+
834+ tosa_fb .addConst (
835+ inputs [0 ].shape ,
836+ ts .DType .INT32 ,
837+ weight_values_quantized ,
838+ name = out ,
839+ )
840+ else :
841+ tosa_fb .addConst (
842+ inputs [0 ].shape , inputs [0 ].dtype , weight_values , name = out
843+ )
746844 elif out in edge_program .graph_signature .inputs_to_buffers :
747845 parameter_name = edge_program .graph_signature .inputs_to_buffers [
748846 node .name
0 commit comments