diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index bebcf794c8..ec7fdc79fb 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -398,7 +398,7 @@ void ConvertBlockToNetDef(
       EvaluateConditionalBlock(ctx, n);
     } else if (to_eval) {
       auto eval = EvaluateNode(ctx, n);
-      if (eval) {
+      if (eval && n->outputs().size() > 0) {
         if (n->outputs().size() > 1) { // For ListUnpack scenario
           if (eval.value().isTuple()) {
             auto eval_list = eval.value().toTuple();
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 8adac76c38..e356dbf6be 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -285,6 +285,27 @@ auto aten_registrations TORCHTRT_UNUSED =
                     EvalOptions().validSchemas({
                         "aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))",
                     })})
+        .evaluator({c10::Symbol::fromQualString("aten::extend"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
+                        auto self = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
+                        auto other = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
+                        const int64_t other_size = other.size();
+
+                        for (int64_t i = 0; i < other_size; i++) {
+                          self.push_back(other.get(i));
+                        }
+                        return self;
+                      } else {
+                        TORCHTRT_THROW_ERROR(
+                            "Unimplemented data type for aten::extend.t evaluator: "
+                            << args.at(n->input(0)).IValue()->type()->str() << ", "
+                            << args.at(n->input(1)).IValue()->type()->str());
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::extend.t(t[](a!) self, t[] other) -> ()",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::neg"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto el = args.at(n->input(0)).unwrapToInt();
diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
index 41eb1e06be..c119c221f5 100644
--- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp
+++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
@@ -302,6 +302,32 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
 }
 
+TEST(Evaluators, ATenExtendEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : int = prim::Constant[value=0]()
+        %3 : Tensor[] = prim::ListConstruct(%0)
+        %4 : Tensor[] = prim::ListConstruct(%1)
+        aten::extend(%3, %4)
+        %5 : Tensor = aten::cat(%3, %2)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
+  auto in1 = at::randint(1, 10, {5, 4}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):