@@ -88,9 +88,16 @@ def index(
8888 # _LOGGER.debug(f"The index shape is {index.shape}")
8989 # check if the input is dynamic
9090 dynamic_shape = has_dynamic_shape (input .shape )
91- #is_numpy is a flag to specify if input isa numpy
92- is_numpy = False
93-
91+ # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
92+ # If any is not this flag will be set to False
93+ is_numpy = True
94+ _LOGGER .debug (f"Checking for the is_numpy flag" )
95+ for i , ind in enumerate (index ):
96+ if ind is None :
97+ continue
98+ if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
99+ is_numpy = False
100+ break
94101 # here we need to check if all the index are broadcastable
95102 # if no, then we need to broadcast
96103 last_index = None
@@ -101,7 +108,7 @@ def index(
101108 # torch.nn.parameter.Parameter=> numpy array
102109 # numpy array is kept as numpy
103110 # other cases are kept as TRTTensor
104- if ( isinstance ( ind , torch . Tensor ) or ( ind , np . ndarray )) :
111+ if is_numpy :
105112 ind = to_numpy (ind )
106113 is_numpy = True
107114 else :
@@ -119,8 +126,9 @@ def index(
119126 set_layer_name (identity_layer , target , name + "_index_identity" , source_ir )
120127 return identity_layer .get_output (0 )
121128 elif len (tensor_indices ) == 1 :
122- # This case works
123- indices_tensor = tensor_indices [0 ]
129+ indices_tensor = get_trt_tensor (
130+ ctx , tensor_indices [0 ], name + f"_parameter_to_fp32_tensor"
131+ )
124132 index = adv_indx_indices [0 ]
125133 _LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
126134 gather_layer = ctx .net .add_gather (input , indices_tensor , index )
@@ -136,15 +144,15 @@ def index(
136144 rank = len (input_shape )
137145 adv_indx_count = len (adv_indx_indices )
138146 dim_tensor_list = []
147+ dim_list = []
139148
140149 for i in range (rank ):
141150 dim = input_shape [i ]
142151 dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
143152 # dim_tensor_list is a list of tensors or numpy
144- if (is_numpy ):
145- dim_tensor_list .append (dim )
146- else :
147- dim_tensor_list .append (dim_tensor )
153+ if is_numpy :
154+ dim_list .append (dim )
155+ dim_tensor_list .append (dim_tensor )
148156
149157 # for cases like
150158 # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
@@ -164,12 +172,9 @@ def index(
164172 _LOGGER .debug (f"The new transpose order is { new_order } " )
165173
166174 transpose_tensor = None
167- if (is_numpy ):
168- transpose_tensor = input [new_order ]
169- else :
170- transpose_layer .second_transpose = tuple (new_order )
171- set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
172- transpose_tensor = transpose_layer .get_output (0 )
175+ transpose_layer .second_transpose = tuple (new_order )
176+ set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
177+ transpose_tensor = transpose_layer .get_output (0 )
173178
174179 # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n]
175180 # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor)
@@ -182,36 +187,34 @@ def index(
182187 for i in range (adv_indx_count , rank ):
183188 mult_d1 = mult_d1 * transpose_tensor_shape [i ]
184189
185- flatten_tensor = None
186- if (is_numpy ):
187- flatten_tensor = transpose_tensor .reshape (mult_d0 , mult_d1 )
188- else :
189- concat_tensor_layer = ctx .net .add_concatenation (
190- [
191- get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
192- get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
193- ]
194- )
195- set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
196- concat_tensor = concat_tensor_layer .get_output (0 )
190+ concat_tensor_layer = ctx .net .add_concatenation (
191+ [
192+ get_trt_tensor (ctx , mult_d0 , name + "_d0_shape" ),
193+ get_trt_tensor (ctx , mult_d1 , name + "_d1_shape" ),
194+ ]
195+ )
196+ set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
197+ concat_tensor = concat_tensor_layer .get_output (0 )
197198
198- reshape_layer = ctx .net .add_shuffle (transpose_tensor )
199- reshape_layer .set_input (1 , concat_tensor )
200- flatten_tensor = reshape_layer .get_output (0 )
199+ reshape_layer = ctx .net .add_shuffle (transpose_tensor )
200+ reshape_layer .set_input (1 , concat_tensor )
201+ flatten_tensor = reshape_layer .get_output (0 )
201202
202203 _LOGGER .debug (f"The flatten tensor shape is { flatten_tensor .shape } " )
203204
204205 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
205206 # // j dimension of input x.
206- if ( is_numpy ) :
207- multiplier = dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]]
207+ if is_numpy :
208+ multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
208209 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
209210 for i in range (adv_indx_count - 2 , - 1 , - 1 ):
210211 adv_index = multiplier * tensor_indices [i ]
211212 cum_adv_index = cum_adv_index + adv_index
212- multiplier = multiplier * dim_tensor_list [adv_indx_indices [i ]]
213+ multiplier = multiplier * dim_list [adv_indx_indices [i ]]
214+ cum_adv_index = get_trt_tensor (
215+ ctx , cum_adv_index , name + f"_index_sum_intermediate"
216+ )
213217 else :
214-
215218 multiplier = get_trt_tensor (
216219 ctx ,
217220 dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]],
@@ -269,36 +272,31 @@ def index(
269272 == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
270273 ):
271274 _LOGGER .debug (f"The indices are continuous in this case" )
272- if (is_numpy ):
273- concat_tensor_reshape .append (- 1 )
274- else :
275- concat_tensor_reshape .append (
276- get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
277- )
275+ concat_tensor_reshape .append (
276+ get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
277+ )
278278 for i in range (0 , rank ):
279279 if i not in adv_indx_indices :
280280 curr_dim = dim_tensor_list [i ]
281281 concat_tensor_reshape .append (curr_dim )
282282
283283 unfold_tensor = None
284- if (is_numpy ):
285- unfold_tensor = gather_out .reshape (concat_tensor )
286- else :
287- concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
288- set_layer_name (
289- concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
290- )
291- concat_tensor = concat_tensor_layer .get_output (0 )
292284
293- regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
294- regular_index_shuffle_layer .set_input (1 , concat_tensor )
295- set_layer_name (
296- regular_index_shuffle_layer ,
297- target ,
298- name + "_index_regular_index" ,
299- source_ir ,
300- )
301- unfold_tensor = regular_index_shuffle_layer .get_output (0 )
285+ concat_tensor_layer = ctx .net .add_concatenation (concat_tensor_reshape )
286+ set_layer_name (
287+ concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
288+ )
289+ concat_tensor = concat_tensor_layer .get_output (0 )
290+
291+ regular_index_shuffle_layer = ctx .net .add_shuffle (gather_out )
292+ regular_index_shuffle_layer .set_input (1 , concat_tensor )
293+ set_layer_name (
294+ regular_index_shuffle_layer ,
295+ target ,
296+ name + "_index_regular_index" ,
297+ source_ir ,
298+ )
299+ unfold_tensor = regular_index_shuffle_layer .get_output (0 )
302300 _LOGGER .debug (f"The tensor is unfolded now" )
303301 _LOGGER .debug (f"The unfolded tensor shape is { unfold_tensor .shape } " )
304302
@@ -313,17 +311,15 @@ def index(
313311 _LOGGER .debug (f"Transposing the indices to correct position { new_order } " )
314312
315313 transpose_tensor = None
316- if (is_numpy ):
317- transpose_tensor = unfold_tensor [new_order ]
318- else :
319- transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
320- set_layer_name (
321- transpose_advanced_shuffle_layer ,
322- target ,
323- name + "_index_advanced_shuffle_transpose" ,
324- source_ir ,
325- )
326- transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
314+
315+ transpose_advanced_shuffle_layer .second_transpose = tuple (new_order )
316+ set_layer_name (
317+ transpose_advanced_shuffle_layer ,
318+ target ,
319+ name + "_index_advanced_shuffle_transpose" ,
320+ source_ir ,
321+ )
322+ transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
327323
328324 # unfold advanced layer
329325 concat_final_tensor = []
@@ -338,28 +334,25 @@ def index(
338334 concat_final_tensor .append (current_dim )
339335
340336 reshape_output = []
341- if (is_numpy ):
342- reshape_output = transpose_tensor .reshape (concat_final_tensor )
343- else :
344- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
345- set_layer_name (
346- concat_final_shape_layer ,
347- target ,
348- name + "_index_continuous_concat_final_shape_layer" ,
349- source_ir ,
350- )
351- concat_final_tensor = concat_final_shape_layer .get_output (0 )
352-
353- unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
354- # check this
355- unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
356- set_layer_name (
357- unfold_advanced_shuffle_layer ,
358- target ,
359- name + "_unfold_advanced_index" ,
360- source_ir ,
361- )
362- reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
337+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
338+ set_layer_name (
339+ concat_final_shape_layer ,
340+ target ,
341+ name + "_index_continuous_concat_final_shape_layer" ,
342+ source_ir ,
343+ )
344+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
345+
346+ unfold_advanced_shuffle_layer = ctx .net .add_shuffle (transpose_tensor )
347+ # check this
348+ unfold_advanced_shuffle_layer .set_input (1 , concat_final_tensor )
349+ set_layer_name (
350+ unfold_advanced_shuffle_layer ,
351+ target ,
352+ name + "_unfold_advanced_index" ,
353+ source_ir ,
354+ )
355+ reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
363356
364357 else :
365358 _LOGGER .debug (f"The indices are not continuous in this case" )
@@ -371,26 +364,24 @@ def index(
371364 concat_final_tensor .append (curr_dim )
372365
373366 reshape_output = None
374- if (is_numpy ):
375- reshape_output = gather_out .reshape (concat_final_tensor )
376- else :
377- concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
378- set_layer_name (
379- concat_final_shape_layer ,
380- target ,
381- name + "_index_non_continuous_concat_final_shape_layer" ,
382- source_ir ,
383- )
384- concat_final_tensor = concat_final_shape_layer .get_output (0 )
385367
386- reshape_layer = ctx .net .add_shuffle (gather_out )
387- reshape_layer .set_input (1 , concat_final_tensor )
388- set_layer_name (
389- reshape_layer ,
390- target ,
391- name + "_index_non_continuous_shuffle_final_shape_layer" ,
392- source_ir ,
393- )
394- reshape_output = reshape_layer .get_output (0 )
368+ concat_final_shape_layer = ctx .net .add_concatenation (concat_final_tensor )
369+ set_layer_name (
370+ concat_final_shape_layer ,
371+ target ,
372+ name + "_index_non_continuous_concat_final_shape_layer" ,
373+ source_ir ,
374+ )
375+ concat_final_tensor = concat_final_shape_layer .get_output (0 )
376+
377+ reshape_layer = ctx .net .add_shuffle (gather_out )
378+ reshape_layer .set_input (1 , concat_final_tensor )
379+ set_layer_name (
380+ reshape_layer ,
381+ target ,
382+ name + "_index_non_continuous_shuffle_final_shape_layer" ,
383+ source_ir ,
384+ )
385+ reshape_output = reshape_layer .get_output (0 )
395386
396387 return reshape_output
0 commit comments