diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 63c7713d01..58f0aeae61 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -267,6 +267,7 @@ auto select_registrations TORCHTRT_UNUSED = .pattern( {"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // refer to https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4627 auto in = args[0].ITensorOrFreeze(ctx); auto ts = args[1].IValue()->toListRef(); @@ -471,7 +472,7 @@ auto select_registrations TORCHTRT_UNUSED = } } auto concat_final_shape_layer = - ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size()); + ctx->net->addConcatenation(concat_final_tensors.data(), concat_final_tensors.size()); auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out); unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0)); reshape_output = unfold_advanced_shuffle_layer->getOutput(0); diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index c84446ad68..1285c24dd6 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -921,6 +921,37 @@ TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) { torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } +TEST(Converters, ATenIndexTensorIdxsNoneConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor, + %index2 : Tensor): + %5 : NoneType = prim::Constant() + %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2, %5) + %19 : Tensor = aten::index(%x.1, %18) + return (%19))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {4, 8, 8, 4}, {at::kCUDA}); + auto index0 = at::full({4, 13, 1}, 1, {at::kCUDA}).to(torch::kLong); + auto index1 = at::full({4, 13, 1}, 2, {at::kCUDA}).to(torch::kLong); + auto index2 = at::full({4, 13, 1}, 3, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + auto index2_trt = index2.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenUnbindConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor):