Aten::Index converter #2277
Conversation
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-08-29 21:00:52.894681+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-08-29 21:03:31.334736+00:00
@@ -69,28 +69,28 @@
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
- index: Union[TRTTensor, Sequence[TRTTensor]]
+ index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
for i in len(index):
ind = index[i]
- #FIXME: check if the datatype for the indices needs to be casted to INT32
- #TRTInterpretor should take care
+ # FIXME: check if the datatype for the indices needs to be casted to INT32
+ # TRTInterpretor should take care
adv_indx_indices.append(i)
tensor_indices.append(ind)
if not tensor_indices:
identity_layer = network.add_identity(input)
identity_layer.set_output_type(0, trt.int32)
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
- elif (len(tensor_indices) == 1):
+ elif len(tensor_indices) == 1:
indices_tensor = tensor_indices[0]
gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
else:
@@ -99,7 +99,5 @@
adv_indx_count = len(adv_indx_indices)
input_shape_layer = network.add_shape(input)
set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
input_shape_tensor = input_shape_layer.get_output(0)
return input_shape_tensor.get_output(0)
-
-
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-08-29 21:00:52.894681+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-08-29 21:03:31.387359+00:00
@@ -169,11 +169,11 @@
)
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
def aten_ops_index(
-network: TRTNetwork,
+ network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
(force-pushed from 71ca151 to 302b962)
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-09-01 17:57:45.620889+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2023-09-01 18:00:04.480860+00:00
@@ -169,11 +169,11 @@
)
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
def aten_ops_index(
-network: TRTNetwork,
+ network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-01 17:57:45.620889+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-01 18:00:04.530477+00:00
@@ -71,28 +71,28 @@
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
- index: Union[TRTTensor, Sequence[TRTTensor]]
+ index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
for i in len(index):
ind = index[i]
- #FIXME: check if the datatype for the indices needs to be casted to INT32
- #TRTInterpretor should take care
+ # FIXME: check if the datatype for the indices needs to be casted to INT32
+ # TRTInterpretor should take care
adv_indx_indices.append(i)
tensor_indices.append(ind)
if not tensor_indices:
identity_layer = network.add_identity(input)
identity_layer.set_output_type(0, trt.int32)
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
- elif (len(tensor_indices) == 1):
+ elif len(tensor_indices) == 1:
indices_tensor = tensor_indices[0]
gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
else:
@@ -102,24 +102,26 @@
input_shape_layer = network.add_shape(input)
set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
input_shape_tensor = input_shape_layer.get_output(0)
dim_tensor_list = []
for i in range(rank):
- #check this
- dim_tensor_layer = network.add_gather(input_shape_tensor, i ,0)
- set_layer_name(input_shape_layer, target, name + "_index_gather_rank", source_ir)
+ # check this
+ dim_tensor_layer = network.add_gather(input_shape_tensor, i, 0)
+ set_layer_name(
+ input_shape_layer, target, name + "_index_gather_rank", source_ir
+ )
dim_tensor = dim_tensor_layer.get_output(0)
dim_tensor_list.append(dim_tensor)
- #for cases like
- #t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
- #where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
- #for ":"
- #Examples: x.shape = (10,20,30,40,50)
- #ind_1, ind_2 broadcasted to (2,3,4)
- #x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
- #x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
+ # for cases like
+ # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+ # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+ # for ":"
+ # Examples: x.shape = (10,20,30,40,50)
+ # ind_1, ind_2 broadcasted to (2,3,4)
+ # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
+ # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
transpose_layer = network.add_shuffle(input)
new_order = []
for i in range(adv_indx_count):
new_order.append(adv_indx_indices[i])
for i in range(rank):
@@ -130,166 +132,194 @@
permute_order(new_order)
transpose_layer.set_second_transpose(permute_order)
set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
transpose_tensor = transpose_layer.get_output(0)
- #Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
+ # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
transpose_tensor_shape = network.add_shape(transpose_tensor)
d0 = 1
d0 = get_trt_tensor(network, d0, "d0_initial")
for i in range(adv_indx_count):
dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
- set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir)
+ set_layer_name(
+ dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir
+ )
d0_gather = gather_layer.get_output(0)
mult_d0 = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_concatOne_shape",
- trt.ElementWisePROD,
- mult_d0,
- d0_gather,
- )
-
+ network,
+ target,
+ source_ir,
+ name + "index_concatOne_shape",
+ trt.ElementWisePROD,
+ mult_d0,
+ d0_gather,
+ )
+
d1 = 1
d1 = get_trt_tensor(network, d0, "d0_initial")
for i in range(adv_indx_count, rank):
dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
- set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir)
+ set_layer_name(
+ dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir
+ )
d1_gather = gather_layer.get_output(0)
mult_d1 = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_concatTwo_shape",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_concatTwo_shape",
+ trt.ElementWisePROD,
mult_d1,
d1_gather,
)
concat_tensor_layer = network.add_concatenation([mult_d0, mult_d1])
set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
concat_tensor = concat_tensor_layer.get_output(0)
reshape_layer = network.add_shuffle(transpose_tensor)
- #check this
+ # check this
reshape_layer.set_input(1, concat_tensor)
flatten_tensor = reshape_layer.get_output(0)
- #tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
- #// j dimension of input x.
- multiplier = get_trt_tensor(network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last")
+ # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+ # // j dimension of input x.
+ multiplier = get_trt_tensor(
+ network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+ )
cum_adv_index = tensor_indices[adv_indx_count - 1]
- for i in range(adv_indx_count-2, 0):
+ for i in range(adv_indx_count - 2, 0):
adv_index = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_intermediate",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_intermediate",
+ trt.ElementWisePROD,
multiplier,
tensor_indices[i],
)
cum_adv_index = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_sum_intermediate",
- trt.ElementWiseSUM,
+ network,
+ target,
+ source_ir,
+ name + "index_sum_intermediate",
+ trt.ElementWiseSUM,
cum_adv_index,
adv_index,
)
multiplier = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_intermediate",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_intermediate",
+ trt.ElementWisePROD,
multiplier,
dim_tensor_list[adv_indx_count[i]],
)
gather_layer_element = network.add_gather(flatten_tensor, cum_adv_index, 0)
- set_layer_name(gather_layer_element, target, name + "_index_gather_element", source_ir)
+ set_layer_name(
+ gather_layer_element, target, name + "_index_gather_element", source_ir
+ )
gather_out = gather_layer.get_output(0)
cum_adv_index_shape_tensor = cum_adv_index.add_shape(cum_adv_index_shape_tensor)
- #check if all advanced indices are consecutive
+ # check if all advanced indices are consecutive
concat_tensor_reshape = []
- if(adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1):
- #concat_tensor_reshape_initial = -1
- #concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
+ if (
+ adv_indx_count
+ == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
+ ):
+ # concat_tensor_reshape_initial = -1
+ # concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
concat_tensor_reshape.append(-1)
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor_reshape.append(curr_dim)
-
+
concat_tensor_layer = network.add_concatenation(concat_tensor_reshape)
- set_layer_name(concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir)
+ set_layer_name(
+ concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+ )
concat_tensor = concat_tensor_layer.get_output(0)
regular_index_shuffle_layer = network.add_shuffle(gather_out)
- set_layer_name(regular_index_shuffle_layer, target, name + "_index_regular_index", source_ir)
+ set_layer_name(
+ regular_index_shuffle_layer,
+ target,
+ name + "_index_regular_index",
+ source_ir,
+ )
unfold_tensor = regular_index_shuffle_layer.get_output(0)
transpose_advanced_shuffle_layer = network.add_shuffle(unfold_tensor)
new_order = []
- for i in range(1, adv_indx_count[0]+1):
+ for i in range(1, adv_indx_count[0] + 1):
new_order.append(i)
new_order.append(0)
- for i in range(adv_indx_indices[0]+1, rank - adv_indx_count):
+ for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count):
new_order.append(i)
permute_order = trt.Permutation()
permute_order(new_order)
transpose_advanced_shuffle_layer.set_second_transpose(permute_order)
- set_layer_name(transpose_advanced_shuffle_layer, target, name + "_index_advanced_shuffle_transpose", source_ir)
+ set_layer_name(
+ transpose_advanced_shuffle_layer,
+ target,
+ name + "_index_advanced_shuffle_transpose",
+ source_ir,
+ )
transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
- #unfold advanced layer
+ # unfold advanced layer
concat_final_tensor = []
for i in range(0, adv_indx_indices[0]):
current_dim = dim_tensor_list[i]
concat_final_tensor.push_back(curr_dim)
concat_final_tensor.push_back(cum_adv_index_shape_tensor)
for i in range(adv_indx_indices[0], rank):
- if(i not in (adv_indx_indices)):
+ if i not in (adv_indx_indices):
current_dim = dim_tensor_list[i]
concat_final_tensor.append(current_dim)
-
+
concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
- set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+ set_layer_name(
+ concat_final_shape_layer,
+ target,
+ name + "_index_concat_final_shape_layer",
+ source_ir,
+ )
concat_final_tensor = concat_final_shape_layer.get_output(0)
unfold_advanced_shuffle_layer = network.add_shuffle(transpose_tensor)
- #check this
+ # check this
reshape_layer.set_input(1, concat_final_tensor)
reshape_output = reshape_layer.get_output(0)
-
+
else:
- concat_tensor= []
+ concat_tensor = []
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor.append(curr_dim)
-
+
concat_layer = network.add_concatenation(concat_tensor)
- set_layer_name(concat_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+ set_layer_name(
+ concat_layer,
+ target,
+ name + "_index_concat_final_shape_layer",
+ source_ir,
+ )
concat_final_tensor = concat_final_shape_layer.get_output(0)
reshape_layer = network.add_shuffle(gather_out)
reshape_layer.setInput(1, concat_final_tensor)
- set_layer_name(reshape_layer, target, name + "_index_shuffle_final_shape_layer", source_ir)
+ set_layer_name(
+ reshape_layer,
+ target,
+ name + "_index_shuffle_final_shape_layer",
+ source_ir,
+ )
reshape_output = reshape_layer.get_output(0)
return reshape_output
-
-
-
-
-
-
-
-
-
-
(force-pushed from 4f2a738 to 42798cc)
Code conforms to C++ style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-07 03:07:48.113307+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-07 03:10:32.769440+00:00
@@ -71,28 +71,28 @@
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
- index: Union[TRTTensor, Sequence[TRTTensor]]
+ index: Union[TRTTensor, Sequence[TRTTensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
for i in len(index):
ind = index[i]
- #FIXME: check if the datatype for the indices needs to be casted to INT32
- #TRTInterpretor should take care
+ # FIXME: check if the datatype for the indices needs to be casted to INT32
+ # TRTInterpretor should take care
adv_indx_indices.append(i)
tensor_indices.append(ind)
if not tensor_indices:
identity_layer = network.add_identity(input)
identity_layer.set_output_type(0, trt.int32)
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
- elif (len(tensor_indices) == 1):
+ elif len(tensor_indices) == 1:
indices_tensor = tensor_indices[0]
gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
else:
@@ -102,24 +102,26 @@
input_shape_layer = network.add_shape(input)
set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
input_shape_tensor = input_shape_layer.get_output(0)
dim_tensor_list = []
for i in range(rank):
- #check this
- dim_tensor_layer = network.add_gather(input_shape_tensor, i ,0)
- set_layer_name(input_shape_layer, target, name + "_index_gather_rank", source_ir)
+ # check this
+ dim_tensor_layer = network.add_gather(input_shape_tensor, i, 0)
+ set_layer_name(
+ input_shape_layer, target, name + "_index_gather_rank", source_ir
+ )
dim_tensor = dim_tensor_layer.get_output(0)
dim_tensor_list.append(dim_tensor)
- #for cases like
- #t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
- #where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
- #for ":"
- #Examples: x.shape = (10,20,30,40,50)
- #ind_1, ind_2 broadcasted to (2,3,4)
- #x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
- #x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
+ # for cases like
+ # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+ # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+ # for ":"
+ # Examples: x.shape = (10,20,30,40,50)
+ # ind_1, ind_2 broadcasted to (2,3,4)
+ # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
+ # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
transpose_layer = network.add_shuffle(input)
new_order = []
for i in range(adv_indx_count):
new_order.append(adv_indx_indices[i])
for i in range(rank):
@@ -130,166 +132,194 @@
permute_order(new_order)
transpose_layer.set_second_transpose(permute_order)
set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
transpose_tensor = transpose_layer.get_output(0)
- #Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
+ # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
transpose_tensor_shape = network.add_shape(transpose_tensor)
d0 = 1
d0 = get_trt_tensor(network, d0, "d0_initial")
for i in range(adv_indx_count):
dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
- set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir)
+ set_layer_name(
+ dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir
+ )
d0_gather = gather_layer.get_output(0)
mult_d0 = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_concatOne_shape",
- trt.ElementWisePROD,
- mult_d0,
- d0_gather,
- )
-
+ network,
+ target,
+ source_ir,
+ name + "index_concatOne_shape",
+ trt.ElementWisePROD,
+ mult_d0,
+ d0_gather,
+ )
+
d1 = 1
d1 = get_trt_tensor(network, d0, "d0_initial")
for i in range(adv_indx_count, rank):
dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
- set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir)
+ set_layer_name(
+ dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir
+ )
d1_gather = gather_layer.get_output(0)
mult_d1 = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_concatTwo_shape",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_concatTwo_shape",
+ trt.ElementWisePROD,
mult_d1,
d1_gather,
)
concat_tensor_layer = network.add_concatenation([mult_d0, mult_d1])
set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
concat_tensor = concat_tensor_layer.get_output(0)
reshape_layer = network.add_shuffle(transpose_tensor)
- #check this
+ # check this
reshape_layer.set_input(1, concat_tensor)
flatten_tensor = reshape_layer.get_output(0)
- #tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
- #// j dimension of input x.
- multiplier = get_trt_tensor(network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last")
+ # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+ # // j dimension of input x.
+ multiplier = get_trt_tensor(
+ network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+ )
cum_adv_index = tensor_indices[adv_indx_count - 1]
- for i in range(adv_indx_count-2, 0):
+ for i in range(adv_indx_count - 2, 0):
adv_index = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_intermediate",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_intermediate",
+ trt.ElementWisePROD,
multiplier,
tensor_indices[i],
)
cum_adv_index = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_sum_intermediate",
- trt.ElementWiseSUM,
+ network,
+ target,
+ source_ir,
+ name + "index_sum_intermediate",
+ trt.ElementWiseSUM,
cum_adv_index,
adv_index,
)
multiplier = convert_binary_elementwise(
- network,
- target,
- source_ir,
- name + "index_intermediate",
- trt.ElementWisePROD,
+ network,
+ target,
+ source_ir,
+ name + "index_intermediate",
+ trt.ElementWisePROD,
multiplier,
dim_tensor_list[adv_indx_count[i]],
)
gather_layer_element = network.add_gather(flatten_tensor, cum_adv_index, 0)
- set_layer_name(gather_layer_element, target, name + "_index_gather_element", source_ir)
+ set_layer_name(
+ gather_layer_element, target, name + "_index_gather_element", source_ir
+ )
gather_out = gather_layer.get_output(0)
cum_adv_index_shape_tensor = cum_adv_index.add_shape(cum_adv_index_shape_tensor)
- #check if all advanced indices are consecutive
+ # check if all advanced indices are consecutive
concat_tensor_reshape = []
- if(adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1):
- #concat_tensor_reshape_initial = -1
- #concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
+ if (
+ adv_indx_count
+ == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
+ ):
+ # concat_tensor_reshape_initial = -1
+ # concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
concat_tensor_reshape.append(-1)
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor_reshape.append(curr_dim)
-
+
concat_tensor_layer = network.add_concatenation(concat_tensor_reshape)
- set_layer_name(concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir)
+ set_layer_name(
+ concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+ )
concat_tensor = concat_tensor_layer.get_output(0)
regular_index_shuffle_layer = network.add_shuffle(gather_out)
- set_layer_name(regular_index_shuffle_layer, target, name + "_index_regular_index", source_ir)
+ set_layer_name(
+ regular_index_shuffle_layer,
+ target,
+ name + "_index_regular_index",
+ source_ir,
+ )
unfold_tensor = regular_index_shuffle_layer.get_output(0)
transpose_advanced_shuffle_layer = network.add_shuffle(unfold_tensor)
new_order = []
- for i in range(1, adv_indx_count[0]+1):
+ for i in range(1, adv_indx_count[0] + 1):
new_order.append(i)
new_order.append(0)
- for i in range(adv_indx_indices[0]+1, rank - adv_indx_count):
+ for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count):
new_order.append(i)
permute_order = trt.Permutation()
permute_order(new_order)
transpose_advanced_shuffle_layer.set_second_transpose(permute_order)
- set_layer_name(transpose_advanced_shuffle_layer, target, name + "_index_advanced_shuffle_transpose", source_ir)
+ set_layer_name(
+ transpose_advanced_shuffle_layer,
+ target,
+ name + "_index_advanced_shuffle_transpose",
+ source_ir,
+ )
transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
- #unfold advanced layer
+ # unfold advanced layer
concat_final_tensor = []
for i in range(0, adv_indx_indices[0]):
current_dim = dim_tensor_list[i]
concat_final_tensor.push_back(curr_dim)
concat_final_tensor.push_back(cum_adv_index_shape_tensor)
for i in range(adv_indx_indices[0], rank):
- if(i not in (adv_indx_indices)):
+ if i not in (adv_indx_indices):
current_dim = dim_tensor_list[i]
concat_final_tensor.append(current_dim)
-
+
concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
- set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+ set_layer_name(
+ concat_final_shape_layer,
+ target,
+ name + "_index_concat_final_shape_layer",
+ source_ir,
+ )
concat_final_tensor = concat_final_shape_layer.get_output(0)
unfold_advanced_shuffle_layer = network.add_shuffle(transpose_tensor)
- #check this
+ # check this
reshape_layer.set_input(1, concat_final_tensor)
reshape_output = reshape_layer.get_output(0)
-
+
else:
- concat_tensor= []
+ concat_tensor = []
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor.append(curr_dim)
-
+
concat_layer = network.add_concatenation(concat_tensor)
- set_layer_name(concat_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+ set_layer_name(
+ concat_layer,
+ target,
+ name + "_index_concat_final_shape_layer",
+ source_ir,
+ )
concat_final_tensor = concat_final_shape_layer.get_output(0)
reshape_layer = network.add_shuffle(gather_out)
reshape_layer.setInput(1, concat_final_tensor)
- set_layer_name(reshape_layer, target, name + "_index_shuffle_final_shape_layer", source_ir)
+ set_layer_name(
+ reshape_layer,
+ target,
+ name + "_index_shuffle_final_shape_layer",
+ source_ir,
+ )
reshape_output = reshape_layer.get_output(0)
return reshape_output
-
-
-
-
-
-
-
-
-
-
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-09-07 03:07:48.137309+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-09-07 03:10:36.917639+00:00
@@ -2,10 +2,11 @@
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from .harness import DispatchTestCase
+
class TestIndexConverter(DispatchTestCase):
def test_index(self):
class TestModule(nn.Module):
def forward(self, x):
@@ -13,6 +14,6 @@
index0 = torch.randint(0, 16, (1, 16))
index1 = torch.randint(0, 16, (1, 16))
out = torch.ops.aten.index(None, None, index0, index1)
inputs = [torch.randn(1, 10)]
- self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})
\ No newline at end of file
+ self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})
(force-pushed from 42798cc to 6b186b6)
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
gs-olive left a comment
Overall looks good! Added some style/naming comments, and will run tests on a model which uses this layer to verify.
for i in range(0, len(index)):
    ind = index[i]
Consider rewriting as: for i, ind in enumerate(index)
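A minimal sketch of that rewrite, reusing the loop body from the diff above:

for i, ind in enumerate(index):
    # FIXME: check if the datatype for the indices needs to be casted to INT32
    adv_indx_indices.append(i)
    tensor_indices.append(ind)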
permute_order = trt.Permutation(new_order)
transpose_layer.second_transpose = permute_order
Can shorten to transpose_layer.second_transpose = tuple(new_order)
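A sketch of the shortened form, assuming second_transpose accepts a plain tuple of ints as the comment implies:

# assign the permutation directly; no intermediate trt.Permutation object is needed
transpose_layer.second_transpose = tuple(new_order)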
target,
source_ir,
name + "index_intermediate",
trt.ElementWisePROD,
Should be trt.ElementWiseOperation.PROD
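A sketch of the corrected call with the enum spelled out (argument names taken from the diff above):

mult_d0 = convert_binary_elementwise(
    network,
    target,
    source_ir,
    name + "index_concatOne_shape",
    trt.ElementWiseOperation.PROD,  # element-wise ops are selected via the ElementWiseOperation enum
    mult_d0,
    d0_gather,
)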
target,
source_ir,
name + "index_sum_intermediate",
trt.ElementWiseSUM,
trt.ElementWiseOperation.SUM
target,
source_ir,
name + "index_intermediate",
trt.ElementWisePROD,
trt.ElementWiseOperation.PROD
permute_order = trt.Permutation(new_order)
transpose_advanced_shuffle_layer.second_transpose = permute_order
permute_order could be replaced with tuple(new_order)
) -> TRTTensor:
    adv_indx_indices = []
    tensor_indices = []
    _LOGGER.debug(f"The index shape is", index.shape)
index.shape is not valid, since index could be a python list
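One possible way to address this, sketched here with the _LOGGER used elsewhere in this file: branch on the index type before logging, since index may be a single TRTTensor or a Python sequence of tensors.

if isinstance(index, (list, tuple)):
    # a sequence of index tensors has no single .shape; log its length instead
    _LOGGER.debug(f"The index contains {len(index)} tensors")
else:
    _LOGGER.debug(f"The index shape is {index.shape}")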
Code conforms to Python style guidelines
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-25 20:21:10.951934+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2023-09-25 20:24:07.893237+00:00
@@ -94,11 +94,13 @@
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")
adv_indx_indices.append(i)
# torch.nn.parameter.Parameter=> torch.Tensor
ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
if last_index is not None:
- assert broadcastable(ind, last_index), "The indices should be broadcastable!"
+ assert broadcastable(
+ ind, last_index
+ ), "The indices should be broadcastable!"
last_index = ind
tensor_indices.append(ind)
if not tensor_indices:
identity_layer = network.add_identity(input)
@@ -177,11 +179,13 @@
_LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
# // j dimension of input x.
multiplier = get_trt_tensor(
- network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], name + "dim_last"
+ network,
+ dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+ name + "dim_last",
)
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = convert_binary_elementwise(
network,
@@ -231,11 +235,13 @@
if (
adv_indx_count
== adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
):
_LOGGER.debug(f"The indices are continuous in this case")
- concat_tensor_reshape.append(get_trt_tensor(network, -1, name + "dynamic_concat"))
+ concat_tensor_reshape.append(
+ get_trt_tensor(network, -1, name + "dynamic_concat")
+ )
for i in range(0, rank):
if i not in adv_indx_indices:
curr_dim = dim_tensor_list[i]
concat_tensor_reshape.append(curr_dim)
(force-pushed from 7d215af to fcbd767)
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
import torch
import torch.nn as nn
from harness import DispatchTestCase
Switch to .harness
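That is, the package-relative form already used in the test diff earlier in this thread:

from .harness import DispatchTestCase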
…adcast and broadcasting cases
… for non continuous indices
(force-pushed from 742c7c2 to 709d626)
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
gs-olive left a comment
Added a few suggestions for input data types and imports
from torch_tensorrt.fx.converters.converter_utils import (
    get_positive_dim,
    has_dynamic_shape,
    set_layer_name,
    to_numpy,
)
Switch get_positive_dim and to_numpy to the torch_tensorrt.dynamo.conversion.converter_utils version
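A sketch of how the import block might look after that change; the exact symbols exported by torch_tensorrt.dynamo.conversion.converter_utils are assumed from the comment above:

from torch_tensorrt.dynamo.conversion.converter_utils import (
    get_positive_dim,
    to_numpy,
)
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
)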
)


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
Consider adding @enforce_tensor_types({0: (TRTTensor,)}) to ensure the input is a TRTTensor.
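A sketch of the suggested decorator stack on the converter (signature copied from the diff above; enforce_tensor_types is assumed to be importable from the dynamo conversion utilities):

@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
@enforce_tensor_types({0: (TRTTensor,)})  # require argument 0 (the input tensor) to be a TRTTensor
def aten_ops_index(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    ...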
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
… to_numpy to dynamo converter_utils
gs-olive left a comment
Looks good to me! Approved, pending CI
Aten::index converter #2231