@@ -81,30 +81,21 @@ def index(
8181 source_ir : Optional [SourceIR ],
8282 name : str ,
8383 input : TRTTensor ,
84- index : Union [
85- TRTTensor ,
86- Sequence [TRTTensor ],
87- np .ndarray ,
88- Sequence [np .ndarray ],
89- torch .Tensor ,
90- Sequence [torch .Tensor ],
91- ],
84+ index : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
9285) -> TRTTensor :
9386 adv_indx_indices = []
9487 tensor_indices = []
95- # _LOGGER.debug(f"The index shape is {index.shape}")
9688 # check if the input is dynamic
9789 dynamic_shape = has_dynamic_shape (input .shape )
9890 # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
9991 # If any is not this flag will be set to False
10092 is_numpy = True
101- _LOGGER .debug (f"Checking for the is_numpy flag" )
102- for i , ind in enumerate (index ):
103- if ind is None :
104- continue
105- if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
106- is_numpy = False
107- break
93+ _LOGGER .debug (
94+ f"Determining whether aten.index constant-index optimization can be invoked"
95+ )
96+ is_numpy = all (
97+ isinstance (ind , (torch .Tensor , np .ndarray )) for ind in index if ind is not None
98+ )
10899 # here we need to check if all the index are broadcastable
109100 # if no, then we need to broadcast
110101 last_index = None
@@ -117,7 +108,6 @@ def index(
117108 # other cases are kept as TRTTensor
118109 if is_numpy :
119110 ind = to_numpy (ind )
120- is_numpy = True
121111 else :
122112 ind = get_trt_tensor (ctx , ind , name + f"_parameter_to_fp32_tensor_{ i } " )
123113 if last_index is not None :
@@ -156,9 +146,7 @@ def index(
156146 for i in range (rank ):
157147 dim = input_shape [i ]
158148 dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
159- # dim_tensor_list is a list of tensors or numpy
160- if is_numpy :
161- dim_list .append (dim )
149+ # dim_tensor_list is a list of tensors
162150 dim_tensor_list .append (dim_tensor )
163151
164152 # for cases like
@@ -211,12 +199,12 @@ def index(
211199 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
212200 # // j dimension of input x.
213201 if is_numpy :
214- multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
202+ multiplier = input_shape [adv_indx_indices [adv_indx_count - 1 ]]
215203 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
216204 for i in range (adv_indx_count - 2 , - 1 , - 1 ):
217205 adv_index = multiplier * tensor_indices [i ]
218206 cum_adv_index = cum_adv_index + adv_index
219- multiplier = multiplier * dim_list [adv_indx_indices [i ]]
207+ multiplier = multiplier * input_shape [adv_indx_indices [i ]]
220208 cum_adv_index = get_trt_tensor (
221209 ctx , cum_adv_index , name + f"_index_sum_intermediate"
222210 )
0 commit comments