@@ -81,16 +81,30 @@ def index(
8181 source_ir : Optional [SourceIR ],
8282 name : str ,
8383 input : TRTTensor ,
84- index : Union [TRTTensor , Sequence [TRTTensor ]],
84+ index : Union [
85+ TRTTensor ,
86+ Sequence [TRTTensor ],
87+ np .ndarray ,
88+ Sequence [np .ndarray ],
89+ torch .Tensor ,
90+ Sequence [torch .Tensor ],
91+ ],
8592) -> TRTTensor :
8693 adv_indx_indices = []
8794 tensor_indices = []
8895 # _LOGGER.debug(f"The index shape is {index.shape}")
8996 # check if the input is dynamic
9097 dynamic_shape = has_dynamic_shape (input .shape )
91- #is_numpy is a flag to specify if input isa numpy
92- is_numpy = False
93-
98+ # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
99+ # If any is not this flag will be set to False
100+ is_numpy = True
101+ _LOGGER .debug (f"Checking for the is_numpy flag" )
102+ for i , ind in enumerate (index ):
103+ if ind is None :
104+ continue
105+ if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
106+ is_numpy = False
107+ break
94108 # here we need to check if all the index are broadcastable
95109 # if no, then we need to broadcast
96110 last_index = None
@@ -101,7 +115,7 @@ def index(
101115 # torch.nn.parameter.Parameter=> numpy array
102116 # numpy array is kept as numpy
103117 # other cases are kept as TRTTensor
104- if ( isinstance ( ind , torch . Tensor ) or ( ind , np . ndarray )) :
118+ if is_numpy :
105119 ind = to_numpy (ind )
106120 is_numpy = True
107121 else :
@@ -119,8 +133,9 @@ def index(
119133 set_layer_name (identity_layer , target , name + "_index_identity" , source_ir )
120134 return identity_layer .get_output (0 )
121135 elif len (tensor_indices ) == 1 :
122- # This case works
123- indices_tensor = tensor_indices [0 ]
136+ indices_tensor = get_trt_tensor (
137+ ctx , tensor_indices [0 ], name + f"_parameter_to_fp32_tensor"
138+ )
124139 index = adv_indx_indices [0 ]
125140 _LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
126141 gather_layer = ctx .net .add_gather (input , indices_tensor , index )
@@ -136,15 +151,15 @@ def index(
136151 rank = len (input_shape )
137152 adv_indx_count = len (adv_indx_indices )
138153 dim_tensor_list = []
154+ dim_list = []
139155
140156 for i in range (rank ):
141157 dim = input_shape [i ]
142158 dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
143159 # dim_tensor_list is a list of tensors or numpy
144- if (is_numpy ):
145- dim_tensor_list .append (dim )
146- else :
147- dim_tensor_list .append (dim_tensor )
160+ if is_numpy :
161+ dim_list .append (dim )
162+ dim_tensor_list .append (dim_tensor )
148163
149164 # for cases like
150165 # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -163,13 +178,9 @@ def index(
163178 new_order .append (i )
164179 _LOGGER .debug (f"The new transpose order is { new_order } " )
165180
166- transpose_tensor = None
167- if (is_numpy ):
168- transpose_tensor = input [new_order ]
169- else :
170- transpose_layer .second_transpose = tuple (new_order )
171- set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
172- transpose_tensor = transpose_layer .get_output (0 )
181+ transpose_layer .second_transpose = tuple (new_order )
182+ set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
183+ transpose_tensor = transpose_layer .get_output (0 )
173184
174185 # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
175186 # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
@@ -182,36 +193,34 @@ def index(
182193 for i in range (adv_indx_count , rank ):
183194 mult_d1 = mult_d1 * transpose_tensor_shape [i ]
184195
185- flatten_tensor = None
186- if (is_numpy ):
187- flatten_tensor = transpose_tensor .reshape (mult_d0 , mult_d1 )
188- else :
189- concat_tensor_layer = ctx .net .add_concatenation (
190- [
191- get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
192- get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
193- ]
194- )
195- set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
196- concat_tensor = concat_tensor_layer .get_output (0 )
196+ concat_tensor_layer = ctx .net .add_concatenation (
197+ [
198+ get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
199+ get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
200+ ]
201+ )
202+ set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
203+ concat_tensor = concat_tensor_layer .get_output (0 )
197204
198- reshape_layer = ctx .net .add_shuffle (transpose_tensor )
199- reshape_layer .set_input (1 , concat_tensor )
200- flatten_tensor = reshape_layer .get_output (0 )
205+ reshape_layer = ctx .net .add_shuffle (transpose_tensor )
206+ reshape_layer .set_input (1 , concat_tensor )
207+ flatten_tensor = reshape_layer .get_output (0 )
201208
202209 _LOGGER .debug (f"The flatten tensor shape is { flatten_tensor .shape } " )
203210
204211 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
205212 # // j dimension of input x.
206- if ( is_numpy ) :
207- multiplier = dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]]
213+ if is_numpy :
214+ multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
208215 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
209216 for i in range (adv_indx_count - 2 , - 1 , - 1 ):
210217 adv_index = multiplier * tensor_indices [i ]
211218 cum_adv_index = cum_adv_index + adv_index
212- multiplier = multiplier * dim_tensor_list [adv_indx_indices [i ]]
219+ multiplier = multiplier * dim_list [adv_indx_indices [i ]]
220+ cum_adv_index = get_trt_tensor (
221+ ctx , cum_adv_index , name + f"_index_sum_intermediate"
222+ )
213223 else :
214-
215224 multiplier = get_trt_tensor (
216225 ctx ,
217226 dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]],
@@ -269,36 +278,29 @@ def index(
269278 == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
270279 ):
271280 _LOGGER .debug (f"The indices are continuous in this case" )
272- if (is_numpy ):
273- concat_tensor_reshape .append (- 1 )
274- else :
275- concat_tensor_reshape .append (
276- get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
277- )
281+ concat_tensor_reshape .append (
282+ get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
283+ )
278284 for i in range (0 , rank ):
279285 if i not in adv_indx_indices :
280286 curr_dim = dim_tensor_list [i ]
281287 concat_tensor_reshape .append (curr_dim )
282288
283- unfold_tensor = None
284- if (is_numpy ):
285- unfold_tensor = gather_out .reshape (concat_tensor )
286- else :
287- concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
288- set_layer_name (
289- concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
290- )
291- concat_tensor = concat_tensor_layer .get_output (0 )
289+ concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
290+ set_layer_name (
291+ concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
292+ )
293+ concat_tensor = concat_tensor_layer .get_output (0 )
292294
293- regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
294- regular_index_shuffle_layer .set_input (1 , concat_tensor )
295- set_layer_name (
296- regular_index_shuffle_layer ,
297- target ,
298- name + "_index_regular_index" ,
299- source_ir ,
300- )
301- unfold_tensor = regular_index_shuffle_layer .get_output (0 )
295+ regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
296+ regular_index_shuffle_layer .set_input (1 , concat_tensor )
297+ set_layer_name (
298+ regular_index_shuffle_layer ,
299+ target ,
300+ name + "_index_regular_index" ,
301+ source_ir ,
302+ )
303+ unfold_tensor = regular_index_shuffle_layer .get_output (0 )
302304 _LOGGER .debug (f"The tensor is unfolded now" )
303305 _LOGGER .debug (f"The unfolded tensor shape is { unfold_tensor .shape } " )
304306
@@ -312,18 +314,14 @@ def index(
312314 new_order .append (i )
313315 _LOGGER .debug (f"Transposing the indices to correct position { new_order } " )
314316
315- transpose_tensor = None
316- if (is_numpy ):
317- transpose_tensor = unfold_tensor [new_order ]
318- else :
319- transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
320- set_layer_name (
321- transpose_advanced_shuffle_layer ,
322- target ,
323- name + "_index_advanced_shuffle_transpose" ,
324- source_ir ,
325- )
326- transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
317+ transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
318+ set_layer_name (
319+ transpose_advanced_shuffle_layer ,
320+ target ,
321+ name + "_index_advanced_shuffle_transpose" ,
322+ source_ir ,
323+ )
324+ transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
327325
328326 # unfold advanced layer
329327 concat_final_tensor = []
@@ -337,29 +335,25 @@ def index(
337335 current_dim = dim_tensor_list [i ]
338336 concat_final_tensor .append (current_dim )
339337
340- reshape_output = []
341- if (is_numpy ):
342- reshape_output = transpose_tensor .reshape (concat_final_tensor )
343- else :
344- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
345- set_layer_name (
346- concat_final_shape_layer ,
347- target ,
348- name + "_index_continuous_concat_final_shape_layer" ,
349- source_ir ,
350- )
351- concat_final_tensor = concat_final_shape_layer .get_output (0 )
352-
353- unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
354- # check this
355- unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
356- set_layer_name (
357- unfold_advanced_shuffle_layer ,
358- target ,
359- name + "_unfold_advanced_index" ,
360- source_ir ,
361- )
362- reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
338+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
339+ set_layer_name (
340+ concat_final_shape_layer ,
341+ target ,
342+ name + "_index_continuous_concat_final_shape_layer" ,
343+ source_ir ,
344+ )
345+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
346+
347+ unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
348+ # check this
349+ unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
350+ set_layer_name (
351+ unfold_advanced_shuffle_layer ,
352+ target ,
353+ name + "_unfold_advanced_index" ,
354+ source_ir ,
355+ )
356+ reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
363357
364358 else :
365359 _LOGGER .debug (f"The indices are not continuous in this case" )
@@ -370,27 +364,23 @@ def index(
370364 curr_dim = dim_tensor_list [i ]
371365 concat_final_tensor .append (curr_dim )
372366
373- reshape_output = None
374- if (is_numpy ):
375- reshape_output = gather_out .reshape (concat_final_tensor )
376- else :
377- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
378- set_layer_name (
379- concat_final_shape_layer ,
380- target ,
381- name + "_index_non_continuous_concat_final_shape_layer" ,
382- source_ir ,
383- )
384- concat_final_tensor = concat_final_shape_layer .get_output (0 )
385-
386- reshape_layer = ctx .net .add_shuffle (gather_out )
387- reshape_layer .set_input (1 , concat_final_tensor )
388- set_layer_name (
389- reshape_layer ,
390- target ,
391- name + "_index_non_continuous_shuffle_final_shape_layer" ,
392- source_ir ,
393- )
394- reshape_output = reshape_layer .get_output (0 )
367+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
368+ set_layer_name (
369+ concat_final_shape_layer ,
370+ target ,
371+ name + "_index_non_continuous_concat_final_shape_layer" ,
372+ source_ir ,
373+ )
374+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
375+
376+ reshape_layer = ctx .net .add_shuffle (gather_out )
377+ reshape_layer .set_input (1 , concat_final_tensor )
378+ set_layer_name (
379+ reshape_layer ,
380+ target ,
381+ name + "_index_non_continuous_shuffle_final_shape_layer" ,
382+ source_ir ,
383+ )
384+ reshape_output = reshape_layer .get_output (0 )
395385
396386 return reshape_output
0 commit comments