 from typing import Optional, Sequence, Union, cast
 
 import numpy as np
+import torch
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -87,6 +88,8 @@ def index(
     # _LOGGER.debug(f"The index shape is {index.shape}")
     # check if the input is dynamic
     dynamic_shape = has_dynamic_shape(input.shape)
+    # is_numpy: True when the indices are constants, so the index math can stay in numpy
+    is_numpy = False
 
     # here we need to check if all the indices are broadcastable;
     # if not, then we need to broadcast them
@@ -95,8 +98,14 @@ def index(
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
-            # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
+            # torch.nn.parameter.Parameter / torch.Tensor => numpy array;
+            # numpy arrays are kept as numpy,
+            # other cases are converted to TRTTensor
+            if isinstance(ind, (torch.Tensor, np.ndarray)):
+                ind = to_numpy(ind)
+                is_numpy = True
+            else:
+                ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
                 assert broadcastable(
                     ind, last_index
@@ -131,8 +140,11 @@ def index(
     for i in range(rank):
         dim = input_shape[i]
         dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
-        # dim_tensor_list is a list of tensors
-        dim_tensor_list.append(dim_tensor)
+        # dim_tensor_list holds TRTTensors, or plain ints in the numpy path
+        if is_numpy:
+            dim_tensor_list.append(dim)
+        else:
+            dim_tensor_list.append(dim_tensor)
 
     # for cases like
     # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -150,9 +162,14 @@ def index(
         if i not in adv_indx_indices:
             new_order.append(i)
     _LOGGER.debug(f"The new transpose order is {new_order}")
-    transpose_layer.second_transpose = tuple(new_order)
-    set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
-    transpose_tensor = transpose_layer.get_output(0)
+
+    if is_numpy:
+        # numpy path: permute the axes directly
+        transpose_tensor = np.transpose(input, new_order)
+    else:
+        transpose_layer.second_transpose = tuple(new_order)
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)
 
     # Flatten [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
     # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
@@ -164,58 +181,71 @@ def index(
     mult_d1 = 1
     for i in range(adv_indx_count, rank):
         mult_d1 = mult_d1 * transpose_tensor_shape[i]
 
-    concat_tensor_layer = ctx.net.add_concatenation(
-        [
-            get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
-            get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
-        ]
-    )
-    set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
-    concat_tensor = concat_tensor_layer.get_output(0)
-
-    reshape_layer = ctx.net.add_shuffle(transpose_tensor)
-    # check this
-    reshape_layer.set_input(1, concat_tensor)
-    flatten_tensor = reshape_layer.get_output(0)
+    if is_numpy:
+        flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1)
+    else:
+        concat_tensor_layer = ctx.net.add_concatenation(
+            [
+                get_trt_tensor(ctx, mult_d0, name + "_d0_shape"),
+                get_trt_tensor(ctx, mult_d1, name + "_d1_shape"),
+            ]
+        )
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)
+
+        reshape_layer = ctx.net.add_shuffle(transpose_tensor)
+        reshape_layer.set_input(1, concat_tensor)
+        flatten_tensor = reshape_layer.get_output(0)
+
     _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")
 
     # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m x_j),
     # where ind_i is input indices[i] and x_j is the j-th dimension of input x
-    multiplier = get_trt_tensor(
-        ctx,
-        dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
-        name + "_dim_last",
-    )
-    cum_adv_index = tensor_indices[adv_indx_count - 1]
-    for i in range(adv_indx_count - 2, -1, -1):
-        adv_index = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_intermediate_{i}",
-            trt.ElementWiseOperation.PROD,
-            multiplier,
-            tensor_indices[i],
-        )
-        cum_adv_index = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_sum_intermediate_{i}",
-            trt.ElementWiseOperation.SUM,
-            cum_adv_index,
-            adv_index,
-        )
-        multiplier = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            name + f"_index_intermediate_xj_{i}",
-            trt.ElementWiseOperation.PROD,
-            multiplier,
-            dim_tensor_list[adv_indx_indices[i]],
-        )
+    if is_numpy:
+        # constant indices: fold the multiply-accumulate directly in numpy
+        multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = multiplier * tensor_indices[i]
+            cum_adv_index = cum_adv_index + adv_index
+            multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]]
+    else:
+        multiplier = get_trt_tensor(
+            ctx,
+            dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+            name + "_dim_last",
+        )
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                tensor_indices[i],
+            )
+            cum_adv_index = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_sum_intermediate_{i}",
+                trt.ElementWiseOperation.SUM,
+                cum_adv_index,
+                adv_index,
+            )
+            multiplier = convert_binary_elementwise(
+                ctx,
+                target,
+                source_ir,
+                name + f"_index_intermediate_xj_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                dim_tensor_list[adv_indx_indices[i]],
+            )
 
     gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
     set_layer_name(
@@ -239,29 +269,36 @@ def index(
         == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
     ):
         _LOGGER.debug("The indices are continuous in this case")
-        concat_tensor_reshape.append(
-            get_trt_tensor(ctx, -1, name + "_dynamic_concat")
-        )
+        if is_numpy:
+            concat_tensor_reshape.append(-1)
+        else:
+            concat_tensor_reshape.append(
+                get_trt_tensor(ctx, -1, name + "_dynamic_concat")
+            )
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
                 concat_tensor_reshape.append(curr_dim)
 
-        concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
-        set_layer_name(
-            concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
-        )
-        concat_tensor = concat_tensor_layer.get_output(0)
-
-        regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
-        regular_index_shuffle_layer.set_input(1, concat_tensor)
-        set_layer_name(
-            regular_index_shuffle_layer,
-            target,
-            name + "_index_regular_index",
-            source_ir,
-        )
-        unfold_tensor = regular_index_shuffle_layer.get_output(0)
+        if is_numpy:
+            # numpy path: reshape using the dim list computed above
+            unfold_tensor = gather_out.reshape(concat_tensor_reshape)
+        else:
+            concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
+            concat_tensor = concat_tensor_layer.get_output(0)
+
+            regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
         _LOGGER.debug("The tensor is unfolded now")
         _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")
 
@@ -274,15 +311,19 @@ def index(
         for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count + 1):
             new_order.append(i)
         _LOGGER.debug(f"Transposing the indices to correct position {new_order}")
-
-        transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
-        set_layer_name(
-            transpose_advanced_shuffle_layer,
-            target,
-            name + "_index_advanced_shuffle_transpose",
-            source_ir,
-        )
-        transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+
+        if is_numpy:
+            # numpy path: permute the axes directly
+            transpose_tensor = np.transpose(unfold_tensor, new_order)
+        else:
+            transpose_advanced_shuffle_layer.second_transpose = tuple(new_order)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
 
         # unfold advanced layer
         concat_final_tensor = []
@@ -296,25 +337,29 @@ def index(
                 current_dim = dim_tensor_list[i]
                 concat_final_tensor.append(current_dim)
 
-        concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-        set_layer_name(
-            concat_final_shape_layer,
-            target,
-            name + "_index_continuous_concat_final_shape_layer",
-            source_ir,
-        )
-        concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-        unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
-        # check this
-        unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
-        set_layer_name(
-            unfold_advanced_shuffle_layer,
-            target,
-            name + "_unfold_advanced_index",
-            source_ir,
-        )
-        reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+        if is_numpy:
+            reshape_output = transpose_tensor.reshape(concat_final_tensor)
+        else:
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor)
+            # reshape to the final unfolded shape computed above
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                unfold_advanced_shuffle_layer,
+                target,
+                name + "_unfold_advanced_index",
+                source_ir,
+            )
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
     else:
         _LOGGER.debug("The indices are not continuous in this case")
@@ -324,24 +369,28 @@ def index(
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
                 concat_final_tensor.append(curr_dim)
 
-        concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
-        set_layer_name(
-            concat_final_shape_layer,
-            target,
-            name + "_index_non_continuous_concat_final_shape_layer",
-            source_ir,
-        )
-        concat_final_tensor = concat_final_shape_layer.get_output(0)
-
-        reshape_layer = ctx.net.add_shuffle(gather_out)
-        reshape_layer.set_input(1, concat_final_tensor)
-        set_layer_name(
-            reshape_layer,
-            target,
-            name + "_index_non_continuous_shuffle_final_shape_layer",
-            source_ir,
-        )
-        reshape_output = reshape_layer.get_output(0)
+        if is_numpy:
+            reshape_output = gather_out.reshape(concat_final_tensor)
+        else:
+            concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_non_continuous_concat_final_shape_layer",
+                source_ir,
+            )
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = ctx.net.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_non_continuous_shuffle_final_shape_layer",
+                source_ir,
+            )
+            reshape_output = reshape_layer.get_output(0)
 
     return reshape_output
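
For reference, the `is_numpy` fast path above implements the flatten-and-gather identity from the comment, tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m x_j). Here is a minimal standalone numpy sketch of that arithmetic; all names and shapes are illustrative, not part of the converter:

```python
import numpy as np

# hypothetical constant input and advanced indices for dims 0 and 1
x = np.arange(24).reshape(2, 3, 4)
ind0 = np.array([0, 1])
ind1 = np.array([2, 0])

# flatten: mult_d0 = 2 * 3 (product of advanced dims), mult_d1 = 4 (the rest)
flatten_tensor = x.reshape(2 * 3, 4)

# cumulative index as in the is_numpy branch:
# multiplier starts at the last advanced dim (3), so cum = ind1 + 3 * ind0
cum_adv_index = ind1 + 3 * ind0

# a single gather along dim 0 reproduces advanced indexing
gather_out = flatten_tensor[cum_adv_index]
assert np.array_equal(gather_out, x[ind0, ind1])
```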
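The continuous/non-continuous split mirrors the usual advanced-indexing placement rule: when the advanced indices sit on adjacent dimensions, the resulting index dimensions stay in place; otherwise they move to the front. A quick numpy illustration of that rule (hypothetical shapes):

```python
import numpy as np

x = np.zeros((2, 3, 4, 5))
i = np.array([0, 1])

# advanced indices on adjacent dims 1 and 2: result dims stay in place
print(x[:, i, i, :].shape)  # (2, 2, 5)

# advanced indices on non-adjacent dims 0 and 2: result dims move to the front
print(x[i, :, i, :].shape)  # (2, 3, 5)
```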