+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
 import torch
 
-from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
+from executorch.exir.dialects._ops import ops as exir_ops
 
-from .node_visitor import NodeVisitor
+from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP
 from .node_visitor_manager import register_node_visitor
-from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .qnn_constants import (
+    OpConcat,
+    OpReshape,
+    OpScatterNd,
+    OpTile,
+    QNN_OP_PACKAGE_NAME_QTI_AISW,
+)
 
 
 @register_node_visitor
@@ -22,6 +31,7 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+        op_wrapper_list = []
         input_node = self.get_node(node.args[0])
         # Because args[0] of the index_put op is not annotated, fill in its quant_attr from this node here.
         if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
@@ -35,38 +45,206 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        indicies_node = node.args[1]
-        indices_list = [
-            self.get_tensor(idx, idx) for idx in indicies_node if idx is not None
-        ]
-
-        # Unpack the tuple
-        indices_unpacked = [torch.flatten(idx) for idx in indices_list]
-
-        # Convert to 2-D tensor
-        indices_qnn = torch.cat(indices_unpacked).unsqueeze(0)
-        indice_node = [n for n in indicies_node if isinstance(n, torch.fx.Node)]
-        # TODO consider to write a pass to combine to one input tensor for indices
-        assert len(indice_node) == 1, "Not support multiple indices tensor"
 
+        indicies_node = node.args[1]
+        index_node_dim = None
+        index_nodes = []
+        index_tensors = []
+        target_index = []
+        # A None entry in the list means the full range of that dimension.
+        # E.g., indicies_node: [None, None, aten__to_copy_default_1]
+        if isinstance(indicies_node, list):
+            for index, idx_node in enumerate(indicies_node):
+                # First, collect the index node and the None entries to
+                # construct the shape of the index tensor.
+                # E.g., shape of input: [1, 1024, 12, 64]
+                # For the "None" axes (assume indicies_node: [None, None, aten__to_copy_default_1]),
+                # target_index: [1, 1024, x], where x is the size of the index tensor, and index_node_dim: 2
+                if isinstance(idx_node, torch.fx.Node):
+                    index_nodes.append(idx_node)
+                    index_tensors.append(self.get_tensor(idx_node, idx_node))
+                    target_index.extend(index_tensors[-1].size())
+                    index_node_dim = index
+                elif idx_node is None and index_node_dim is None:
+                    # E.g., indicies_node: [None, aten__to_copy_default_1, None]
+                    # None entries after the index node need no handling here.
+                    target_index.append(input_tensor.size(index))
+                else:
+                    warnings.warn(
+                        f"[QNN Delegate Op Builder]: Got an index {idx_node} that is neither a node nor None",
+                        stacklevel=1,
+                    )
+                    return
+        # Assume there is only one index node in the list.
+        assert len(index_nodes) == 1, "Multiple indices tensors are not supported"
+        indice_node = index_nodes[0]
+        indice_tensor = index_tensors[0]
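+        # Register the original indices tensor with QNN; it is reshaped (and,
+        # for "None" axes, tiled and concatenated) below into ScatterND form.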
         indices_tensor_wrapper = self.define_tensor(
-            indice_node[0],
+            indice_node,
             node,
-            indices_qnn,
+            indice_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        value_node = self.get_node(node.args[2])
 
-        value_tensor = self.get_tensor(value_node, node)
+        # Need to reconstruct the index tensor,
+        # based on the ScatterND op def in the QNN docs.
+        # E.g., given that
+        # shape of input: [1, 12, 1024, 64]
+        # indicies_node: [None, None, aten__to_copy_default_1]
+        # shape of aten__to_copy_default_1: [1]
+        # the shape of the index tensor should be [1, 12, 1, 3].
+        # The index tensor is treated as a 3-dimensional tensor of 3-tuples,
+        # where each 3-tuple is a partial index into input.
+        # Reference code for QNN ScatterNd:
+        # output = np.copy(input)
+        # update_indices = indices.shape[:-1]
+        # for idx in np.ndindex(update_indices):
+        #     output[indices[idx]] = updates[idx]
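+        # With the shapes above, indices[0, d1, 0] is the 3-tuple (0, d1, idx[0]),
+        # so each slice input[0, d1, idx[0], :] receives its corresponding update.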
+
+        # Append one dimension to hold the x-tuple of partial indices.
+        index_shape = target_index + [1]
+        # Reshape the index node for the tile op.
+        reshape_shape = [
+            shape if id == index_node_dim else 1 for id, shape in enumerate(index_shape)
+        ]
+        reshape_output_tensor = indice_tensor.reshape(reshape_shape)
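+        # E.g., with input [1, 12, 1024, 64] and index_node_dim = 2:
+        # index_shape = [1, 12, x, 1] and reshape_shape = [1, 1, x, 1], so the
+        # 1-D index tensor of length x sits alone on its own axis.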
+        reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+            node_name=node.name + "_reshape",
+            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype],
+            quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+            quant_configs={},
+            dims=reshape_output_tensor.size(),
+            tensor=reshape_output_tensor,
+            is_fake_tensor=True,
+            nodes_to_wrappers=nodes_to_wrappers,
+        )
+        reshape_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReshape.op_name,
+        )
+        reshape_op.AddInputTensors([indices_tensor_wrapper])
+        reshape_op.AddOutputTensors([reshape_output_tensor_wrapper])
+        op_wrapper_list.append(reshape_op)
+        index_put_index_input_tensor_wrapper = reshape_output_tensor_wrapper
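+        # When every entry in indicies_node is a node (no None), the reshaped
+        # indices feed ScatterND directly; otherwise the tile/concat path below
+        # replaces this wrapper.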
+
+        # Tile the index node and concat with the target index.
+        if None in indicies_node:
+            tile_output_tensor = reshape_output_tensor.expand(index_shape)
+            # Tile the index node to align with the shape of target_index.
+            # Only the "None" axes need to be tiled.
+            # E.g., indicies_node: [None, None, aten__to_copy_default_1]
+            # The first two dimensions should be tiled.
+            multiples = [
+                shape if id != index_node_dim else 1
+                for id, shape in enumerate(index_shape)
+            ]
+            multiples_shape = [len(index_shape)]
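+            # E.g., index_shape [1, 12, x, 1] gives multiples [1, 12, 1, 1],
+            # tiling the reshaped [1, 1, x, 1] indices up to [1, 12, x, 1].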
+            tile_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+                node_name=node.name + "_tile",
+                tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype],
+                quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+                quant_configs={},
+                dims=tile_output_tensor.size(),
+                tensor=tile_output_tensor,
+                is_fake_tensor=True,
+                nodes_to_wrappers=nodes_to_wrappers,
+            )
+            tile_op = PyQnnWrapper.PyQnnOpWrapper(
+                node.name,
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpTile.op_name,
+            )
+            tile_op.AddInputTensors([reshape_output_tensor_wrapper])
+            tile_op.AddOutputTensors([tile_output_tensor_wrapper])
+            tile_op.AddTensorParam(
+                OpTile.param_multiples,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                len(multiples_shape),
+                multiples_shape,
+                np.array(multiples, dtype=np.uint32),
+                True,
+            )
+            op_wrapper_list.append(tile_op)
+
+            # Generate the indices for the "None" axes in indicies_node.
+            ranges = [
+                torch.arange(dim, dtype=indice_tensor.dtype)
+                for dim in target_index[:-1]
+            ]
+            target_index_shape = target_index + [len(ranges)]
+            target_index_tensor = torch.cartesian_prod(*ranges)
+            reshape_target_index_shape = [
+                shape if id != index_node_dim else 1
+                for id, shape in enumerate(target_index_shape)
+            ]
+            target_index_tensor = target_index_tensor.reshape(
+                reshape_target_index_shape
+            )
+            target_index_tensor = target_index_tensor.expand(
+                target_index_shape
+            ).contiguous()
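+            # E.g., for target_index [1, 12, x]: ranges cover sizes 1 and 12,
+            # cartesian_prod yields the 12 pairs (d0, d1) as a [12, 2] tensor,
+            # which is reshaped to [1, 12, 1, 2] and expanded to [1, 12, x, 2].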
+            # Wrap the precomputed indices in an FX node so they can be
+            # registered as a static tensor.
+            target_index_node = torch.fx.Node(
+                node.graph,
+                node.name + "_target_index",
+                "call_function",
+                exir_ops.edge.aten.tensor.default,
+                (),  # args
+                {},  # kwargs
+            )
+            target_index_tensor_wrapper = self.define_tensor(
+                target_index_node,
+                node,
+                target_index_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+            )
 
+            # Concat target_index and the tile output to reconstruct the index tensor.
+            # Cannot use QNN Pack (stack), since QNN Pack does not support the int32 dtype.
+            concat_output_tensor = torch.concat(
+                (target_index_tensor, tile_output_tensor), dim=-1
+            )
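+            # E.g., concatenating [1, 12, x, 2] with [1, 12, x, 1] gives
+            # [1, 12, x, 3]; each innermost length-3 vector is the (d0, d1, d2)
+            # partial index into input.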
+            concat_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+                node_name=node.name + "_concat",
+                tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                dtype=QNN_TENSOR_TYPE_MAP[concat_output_tensor.dtype],
+                quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+                quant_configs={},
+                dims=concat_output_tensor.size(),
+                tensor=concat_output_tensor,
+                is_fake_tensor=True,
+                nodes_to_wrappers=nodes_to_wrappers,
+            )
+            concat_op = PyQnnWrapper.PyQnnOpWrapper(
+                node.name,
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpConcat.op_name,
+            )
+            concat_op.AddInputTensors(
+                [target_index_tensor_wrapper, tile_output_tensor_wrapper]
+            )
+            concat_op.AddOutputTensors([concat_output_tensor_wrapper])
+            concat_op.AddScalarParam(
+                OpConcat.param_axis,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                {QCOM_DATA: np.uint32(concat_output_tensor.dim() - 1)},
+            )
+            op_wrapper_list.append(concat_op)
+            index_put_index_input_tensor_wrapper = concat_output_tensor_wrapper
+
+        value_node = self.get_node(node.args[2])
+        value_tensor = self.get_tensor(value_node, node)
         value_tensor_wrapper = self.define_tensor(
             value_node,
             node,
             value_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
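+        # Per the ScatterND reference above, value_tensor supplies updates[idx]
+        # for each idx in np.ndindex(indices.shape[:-1]).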
+
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
@@ -82,8 +260,12 @@ def define_node(
             OpScatterNd.op_name,
         )
         index_put_op.AddInputTensors(
-            [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
+            [
+                input_tensor_wrapper,
+                index_put_index_input_tensor_wrapper,
+                value_tensor_wrapper,
+            ]
         )
         index_put_op.AddOutputTensors([output_tensor_wrapper])
-
-        return index_put_op
+        op_wrapper_list.append(index_put_op)
+        return op_wrapper_list