diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 8fb572b850..f5d9e78c2d 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -285,6 +285,31 @@ auto aten_registrations TORCHTRT_UNUSED =
                     EvalOptions().validSchemas({
                         "aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))",
                     })})
+        .evaluator({c10::Symbol::fromQualString("aten::extend"),
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
+                        c10::IValue* self_ptr = args.at(n->input(0)).IValueMut();
+                        auto self = self_ptr->to<c10::List<c10::IValue>>();
+                        auto other = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
+                        const int64_t other_size = other.size();
+
+                        // Modify value in place
+                        for (int64_t i = 0; i < other_size; i++) {
+                          self.push_back(other.get(i));
+                        }
+
+                        *self_ptr = c10::IValue(self);
+                        return {};
+                      } else {
+                        TORCHTRT_THROW_ERROR(
+                            "Unimplemented data type for aten::extend.t evaluator: "
+                            << args.at(n->input(0)).IValue()->type()->str() << ", "
+                            << args.at(n->input(1)).IValue()->type()->str());
+                      }
+                    },
+                    EvalOptions().validSchemas({
+                        "aten::extend.t(t[](a!) self, t[] other) -> ()",
+                    })})
         .evaluator({c10::Symbol::fromQualString("aten::neg"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto el = args.at(n->input(0)).unwrapToInt();
diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp
index 71dbe63e51..027f8f1cb8 100644
--- a/core/conversion/var/Var.cpp
+++ b/core/conversion/var/Var.cpp
@@ -13,7 +13,7 @@ Var::Var() {
   type_ = Type::kNone;
 }
 
-Var::Var(const torch::jit::IValue* p) : type_(Type::kIValue) {
+Var::Var(torch::jit::IValue* p) : type_(Type::kIValue) {
   ptr_.ivalue = p;
 }
 
@@ -56,7 +56,7 @@ Var& Var::operator=(const Var& a) {
   return (*this);
 }
 
-Var& Var::operator=(const torch::jit::IValue* in) {
+Var& Var::operator=(torch::jit::IValue* in) {
   ptr_.ivalue = in;
   type_ = Type::kIValue;
   return (*this);
@@ -116,6 +116,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
 }
 
 const torch::jit::IValue* Var::IValue() const {
+  return IValueMut();
+}
+
+torch::jit::IValue* Var::IValueMut() const {
   TORCHTRT_CHECK(isIValue(), "Requested IValue from Var, however Var type is " << type_name());
   if (type_ == Type::kIValue) {
     return ptr_.ivalue;
diff --git a/core/conversion/var/Var.h b/core/conversion/var/Var.h
index 75f1d7e558..4a81bbedd1 100644
--- a/core/conversion/var/Var.h
+++ b/core/conversion/var/Var.h
@@ -17,13 +17,14 @@ class Var : torch::CustomClassHolder {
   enum Type { kITensor, kIValue, kNone };
   Var();
-  Var(const torch::jit::IValue* p);
+  Var(torch::jit::IValue* p);
   Var(nvinfer1::ITensor* p);
   Var(const Var& a);
 
   Var& operator=(const Var& a);
-  Var& operator=(const torch::jit::IValue* in);
+  Var& operator=(torch::jit::IValue* in);
   Var& operator=(nvinfer1::ITensor* in);
 
   const torch::jit::IValue* IValue() const;
+  torch::jit::IValue* IValueMut() const;
   nvinfer1::ITensor* ITensor() const;
   // TODO: Can we consolidate this in a way that prevents requesting invalid
@@ -63,7 +64,7 @@ private:
   union VarContainer {
-    const torch::jit::IValue* ivalue;
+    torch::jit::IValue* ivalue;
     nvinfer1::ITensor* tensor;
     void* none;
   };
diff --git a/tests/core/conversion/evaluators/test_aten_evaluators.cpp b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
index ef8282a2ea..6d46568d4b 100644
--- a/tests/core/conversion/evaluators/test_aten_evaluators.cpp
+++ b/tests/core/conversion/evaluators/test_aten_evaluators.cpp
@@ -303,6 +303,32 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
 }
 
+TEST(Evaluators, ATenExtendEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : int = prim::Constant[value=0]()
+        %3 : Tensor[] = prim::ListConstruct(%0)
+        %4 : Tensor[] = prim::ListConstruct(%1)
+        aten::extend(%3, %4)
+        %5 : Tensor = aten::cat(%3, %2)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
+  auto in1 = at::randint(1, 10, {5, 4}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):