@@ -89,16 +89,16 @@ def index(
8989 # if no, then we need to broadcast
9090
9191 last_index = None
92- broadcast_shape_len = 0
9392 for i , ind in enumerate (index ):
9493 if ind is not None :
9594 _LOGGER .debug (f"Shape of { i } index is { ind .shape } " )
9695 adv_indx_indices .append (i )
9796 # torch.nn.parameter.Parameter=> torch.Tensor
98- ind = get_trt_tensor (network , ind , f"parameter_to_fp32_tensor_ { i } " )
97+ ind = get_trt_tensor (network , ind , name + f"_parameter_to_fp32_tensor_ { i } " )
9998 if last_index is not None :
100- if not (broadcastable (ind , last_index )):
101- assert "The indices should be broadcastable"
99+ assert broadcastable (
100+ ind , last_index
101+ ), "The indices should be broadcastable!"
102102 last_index = ind
103103 tensor_indices .append (ind )
104104
@@ -128,7 +128,7 @@ def index(
128128
129129 for i in range (rank ):
130130 dim = input_shape [i ]
131- dim_tensor = get_trt_tensor (network , dim , f"individual_dim_ { i } " )
131+ dim_tensor = get_trt_tensor (network , dim , name + f"_individual_dim_ { i } " )
132132 # dim_tensor_list is a list of tensors
133133 dim_tensor_list .append (dim_tensor )
134134
@@ -165,8 +165,8 @@ def index(
165165
166166 concat_tensor_layer = network .add_concatenation (
167167 [
168- get_trt_tensor (network , mult_d0 , "d0_shape " ),
169- get_trt_tensor (network , mult_d1 , "d1_shape " ),
168+ get_trt_tensor (network , mult_d0 , name + "_d0_shape " ),
169+ get_trt_tensor (network , mult_d1 , name + "_d1_shape " ),
170170 ]
171171 )
172172 set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
@@ -181,15 +181,17 @@ def index(
181181 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
182182 # // j dimension of input x.
183183 multiplier = get_trt_tensor (
184- network , dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]], "dim_last"
184+ network ,
185+ dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]],
186+ name + "_dim_last" ,
185187 )
186188 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
187189 for i in range (adv_indx_count - 2 , - 1 , - 1 ):
188190 adv_index = convert_binary_elementwise (
189191 network ,
190192 target ,
191193 source_ir ,
192- name + "index_intermediate " ,
194+ name + f"_index_intermediate_ { i } " ,
193195 trt .ElementWiseOperation .PROD ,
194196 multiplier ,
195197 tensor_indices [i ],
@@ -198,7 +200,7 @@ def index(
198200 network ,
199201 target ,
200202 source_ir ,
201- name + "index_sum_intermediate " ,
203+ name + f"_index_sum_intermediate_ { i } " ,
202204 trt .ElementWiseOperation .SUM ,
203205 cum_adv_index ,
204206 adv_index ,
@@ -207,7 +209,7 @@ def index(
207209 network ,
208210 target ,
209211 source_ir ,
210- name + "index_intermediate " ,
212+ name + f"_index_intermediate_xj_ { i } " ,
211213 trt .ElementWiseOperation .PROD ,
212214 multiplier ,
213215 dim_tensor_list [adv_indx_indices [i ]],
@@ -235,7 +237,9 @@ def index(
235237 == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
236238 ):
237239 _LOGGER .debug (f"The indices are continuous in this case" )
238- concat_tensor_reshape .append (get_trt_tensor (network , - 1 , "dynamic_concat" ))
240+ concat_tensor_reshape .append (
241+ get_trt_tensor (network , - 1 , name + "_dynamic_concat" )
242+ )
239243 for i in range (0 , rank ):
240244 if i not in adv_indx_indices :
241245 curr_dim = dim_tensor_list [i ]
@@ -294,7 +298,7 @@ def index(
294298 set_layer_name (
295299 concat_final_shape_layer ,
296300 target ,
297- name + "_index_concat_final_shape_layer " ,
301+ name + "_index_continuous_concat_final_shape_layer " ,
298302 source_ir ,
299303 )
300304 concat_final_tensor = concat_final_shape_layer .get_output (0 )
@@ -311,17 +315,19 @@ def index(
311315 reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
312316
313317 else :
314- concat_tensor = []
318+ _LOGGER .debug (f"The indices are not continuous in this case" )
319+ concat_final_tensor = []
320+ concat_final_tensor .append (cum_adv_index_shape_tensor )
315321 for i in range (0 , rank ):
316322 if i not in adv_indx_indices :
317323 curr_dim = dim_tensor_list [i ]
318- concat_tensor .append (curr_dim )
324+ concat_final_tensor .append (curr_dim )
319325
320- concat_layer = network .add_concatenation (concat_tensor )
326+ concat_final_shape_layer = network .add_concatenation (concat_final_tensor )
321327 set_layer_name (
322- concat_layer ,
328+ concat_final_shape_layer ,
323329 target ,
324- name + "_index_concat_final_shape_layer " ,
330+ name + "_index_non_continuous_concat_final_shape_layer " ,
325331 source_ir ,
326332 )
327333 concat_final_tensor = concat_final_shape_layer .get_output (0 )
@@ -331,7 +337,7 @@ def index(
331337 set_layer_name (
332338 reshape_layer ,
333339 target ,
334- name + "_index_shuffle_final_shape_layer " ,
340+ name + "_index_non_continuous_shuffle_final_shape_layer " ,
335341 source_ir ,
336342 )
337343 reshape_output = reshape_layer .get_output (0 )
0 commit comments