@@ -88,9 +88,16 @@ def index(
8888 # _LOGGER.debug(f"The index shape is {index.shape}")
8989 # check if the input is dynamic
9090 dynamic_shape = has_dynamic_shape (input .shape )
91- #is_numpy is a flag to specify if input isa numpy
92- is_numpy = False
93-
91+ # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
92+ # If any is not this flag will be set to False
93+ is_numpy = True
94+ _LOGGER .debug (f"Checking for the is_numpy flag" )
95+ for i , ind in enumerate (index ):
96+ if ind is None :
97+ continue
98+ if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
99+ is_numpy = False
100+ break
94101 # here we need to check if all the index are broadcastable
95102 # if no, then we need to broadcast
96103 last_index = None
@@ -101,7 +108,7 @@ def index(
101108 # torch.nn.parameter.Parameter=> numpy array
102109 # numpy array is kept as numpy
103110 # other cases are kept as TRTTensor
104- if ( isinstance ( ind , torch . Tensor ) or ( ind , np . ndarray )) :
111+ if is_numpy :
105112 ind = to_numpy (ind )
106113 is_numpy = True
107114 else :
@@ -119,8 +126,9 @@ def index(
119126 set_layer_name (identity_layer , target , name + "_index_identity" , source_ir )
120127 return identity_layer .get_output (0 )
121128 elif len (tensor_indices ) == 1 :
122- # This case works
123- indices_tensor = tensor_indices [0 ]
129+ indices_tensor = get_trt_tensor (
130+ ctx , tensor_indices [0 ], name + f"_parameter_to_fp32_tensor"
131+ )
124132 index = adv_indx_indices [0 ]
125133 _LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
126134 gather_layer = ctx .net .add_gather (input , indices_tensor , index )
@@ -136,15 +144,15 @@ def index(
136144 rank = len (input_shape )
137145 adv_indx_count = len (adv_indx_indices )
138146 dim_tensor_list = []
147+ dim_list = []
139148
140149 for i in range (rank ):
141150 dim = input_shape [i ]
142151 dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
143152 # dim_tensor_list is a list of tensors or numpy
144- if (is_numpy ):
145- dim_tensor_list .append (dim )
146- else :
147- dim_tensor_list .append (dim_tensor )
153+ if is_numpy :
154+ dim_list .append (dim )
155+ dim_tensor_list .append (dim_tensor )
148156
149157 # for cases like
150158 # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -163,13 +171,9 @@ def index(
163171 new_order .append (i )
164172 _LOGGER .debug (f"The new transpose order is { new_order } " )
165173
166- transpose_tensor = None
167- if (is_numpy ):
168- transpose_tensor = input [new_order ]
169- else :
170- transpose_layer .second_transpose = tuple (new_order )
171- set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
172- transpose_tensor = transpose_layer .get_output (0 )
174+ transpose_layer .second_transpose = tuple (new_order )
175+ set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
176+ transpose_tensor = transpose_layer .get_output (0 )
173177
174178 # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
175179 # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
@@ -182,36 +186,34 @@ def index(
182186 for i in range (adv_indx_count , rank ):
183187 mult_d1 = mult_d1 * transpose_tensor_shape [i ]
184188
185- flatten_tensor = None
186- if (is_numpy ):
187- flatten_tensor = transpose_tensor .reshape (mult_d0 , mult_d1 )
188- else :
189- concat_tensor_layer = ctx .net .add_concatenation (
190- [
191- get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
192- get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
193- ]
194- )
195- set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
196- concat_tensor = concat_tensor_layer .get_output (0 )
189+ concat_tensor_layer = ctx .net .add_concatenation (
190+ [
191+ get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
192+ get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
193+ ]
194+ )
195+ set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
196+ concat_tensor = concat_tensor_layer .get_output (0 )
197197
198- reshape_layer = ctx .net .add_shuffle (transpose_tensor )
199- reshape_layer .set_input (1 , concat_tensor )
200- flatten_tensor = reshape_layer .get_output (0 )
198+ reshape_layer = ctx .net .add_shuffle (transpose_tensor )
199+ reshape_layer .set_input (1 , concat_tensor )
200+ flatten_tensor = reshape_layer .get_output (0 )
201201
202202 _LOGGER .debug (f"The flatten tensor shape is { flatten_tensor .shape } " )
203203
204204 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
205205 # // j dimension of input x.
206- if ( is_numpy ) :
207- multiplier = dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]]
206+ if is_numpy :
207+ multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
208208 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
209209 for i in range (adv_indx_count - 2 , - 1 , - 1 ):
210210 adv_index = multiplier * tensor_indices [i ]
211211 cum_adv_index = cum_adv_index + adv_index
212- multiplier = multiplier * dim_tensor_list [adv_indx_indices [i ]]
212+ multiplier = multiplier * dim_list [adv_indx_indices [i ]]
213+ cum_adv_index = get_trt_tensor (
214+ ctx , cum_adv_index , name + f"_index_sum_intermediate"
215+ )
213216 else :
214-
215217 multiplier = get_trt_tensor (
216218 ctx ,
217219 dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]],
@@ -269,36 +271,29 @@ def index(
269271 == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
270272 ):
271273 _LOGGER .debug (f"The indices are continuous in this case" )
272- if (is_numpy ):
273- concat_tensor_reshape .append (- 1 )
274- else :
275- concat_tensor_reshape .append (
276- get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
277- )
274+ concat_tensor_reshape .append (
275+ get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
276+ )
278277 for i in range (0 , rank ):
279278 if i not in adv_indx_indices :
280279 curr_dim = dim_tensor_list [i ]
281280 concat_tensor_reshape .append (curr_dim )
282281
283- unfold_tensor = None
284- if (is_numpy ):
285- unfold_tensor = gather_out .reshape (concat_tensor )
286- else :
287- concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
288- set_layer_name (
289- concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
290- )
291- concat_tensor = concat_tensor_layer .get_output (0 )
282+ concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
283+ set_layer_name (
284+ concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
285+ )
286+ concat_tensor = concat_tensor_layer .get_output (0 )
292287
293- regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
294- regular_index_shuffle_layer .set_input (1 , concat_tensor )
295- set_layer_name (
296- regular_index_shuffle_layer ,
297- target ,
298- name + "_index_regular_index" ,
299- source_ir ,
300- )
301- unfold_tensor = regular_index_shuffle_layer .get_output (0 )
288+ regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
289+ regular_index_shuffle_layer .set_input (1 , concat_tensor )
290+ set_layer_name (
291+ regular_index_shuffle_layer ,
292+ target ,
293+ name + "_index_regular_index" ,
294+ source_ir ,
295+ )
296+ unfold_tensor = regular_index_shuffle_layer .get_output (0 )
302297 _LOGGER .debug (f"The tensor is unfolded now" )
303298 _LOGGER .debug (f"The unfolded tensor shape is { unfold_tensor .shape } " )
304299
@@ -312,18 +307,14 @@ def index(
312307 new_order .append (i )
313308 _LOGGER .debug (f"Transposing the indices to correct position { new_order } " )
314309
315- transpose_tensor = None
316- if (is_numpy ):
317- transpose_tensor = unfold_tensor [new_order ]
318- else :
319- transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
320- set_layer_name (
321- transpose_advanced_shuffle_layer ,
322- target ,
323- name + "_index_advanced_shuffle_transpose" ,
324- source_ir ,
325- )
326- transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
310+ transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
311+ set_layer_name (
312+ transpose_advanced_shuffle_layer ,
313+ target ,
314+ name + "_index_advanced_shuffle_transpose" ,
315+ source_ir ,
316+ )
317+ transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
327318
328319 # unfold advanced layer
329320 concat_final_tensor = []
@@ -337,29 +328,25 @@ def index(
337328 current_dim = dim_tensor_list [i ]
338329 concat_final_tensor .append (current_dim )
339330
340- reshape_output = []
341- if (is_numpy ):
342- reshape_output = transpose_tensor .reshape (concat_final_tensor )
343- else :
344- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
345- set_layer_name (
346- concat_final_shape_layer ,
347- target ,
348- name + "_index_continuous_concat_final_shape_layer" ,
349- source_ir ,
350- )
351- concat_final_tensor = concat_final_shape_layer .get_output (0 )
352-
353- unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
354- # check this
355- unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
356- set_layer_name (
357- unfold_advanced_shuffle_layer ,
358- target ,
359- name + "_unfold_advanced_index" ,
360- source_ir ,
361- )
362- reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
331+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
332+ set_layer_name (
333+ concat_final_shape_layer ,
334+ target ,
335+ name + "_index_continuous_concat_final_shape_layer" ,
336+ source_ir ,
337+ )
338+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
339+
340+ unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
341+ # check this
342+ unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
343+ set_layer_name (
344+ unfold_advanced_shuffle_layer ,
345+ target ,
346+ name + "_unfold_advanced_index" ,
347+ source_ir ,
348+ )
349+ reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
363350
364351 else :
365352 _LOGGER .debug (f"The indices are not continuous in this case" )
@@ -370,27 +357,23 @@ def index(
370357 curr_dim = dim_tensor_list [i ]
371358 concat_final_tensor .append (curr_dim )
372359
373- reshape_output = None
374- if (is_numpy ):
375- reshape_output = gather_out .reshape (concat_final_tensor )
376- else :
377- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
378- set_layer_name (
379- concat_final_shape_layer ,
380- target ,
381- name + "_index_non_continuous_concat_final_shape_layer" ,
382- source_ir ,
383- )
384- concat_final_tensor = concat_final_shape_layer .get_output (0 )
385-
386- reshape_layer = ctx .net .add_shuffle (gather_out )
387- reshape_layer .set_input (1 , concat_final_tensor )
388- set_layer_name (
389- reshape_layer ,
390- target ,
391- name + "_index_non_continuous_shuffle_final_shape_layer" ,
392- source_ir ,
393- )
394- reshape_output = reshape_layer .get_output (0 )
360+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
361+ set_layer_name (
362+ concat_final_shape_layer ,
363+ target ,
364+ name + "_index_non_continuous_concat_final_shape_layer" ,
365+ source_ir ,
366+ )
367+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
368+
369+ reshape_layer = ctx .net .add_shuffle (gather_out )
370+ reshape_layer .set_input (1 , concat_final_tensor )
371+ set_layer_name (
372+ reshape_layer ,
373+ target ,
374+ name + "_index_non_continuous_shuffle_final_shape_layer" ,
375+ source_ir ,
376+ )
377+ reshape_output = reshape_layer .get_output (0 )
395378
396379 return reshape_output
0 commit comments