Skip to content

Commit 33c523d

Browse files
ruoqianguonarendasan
authored andcommitted
feat: support aten::extend evaluator
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent b798c7f commit 33c523d

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,26 @@ auto aten_registrations TORCHTRT_UNUSED =
285285
EvalOptions().validSchemas({
286286
"aten::append.t(t[](a!) self, t(c -> *) el) -> (t[](a!))",
287287
})})
288+
.evaluator({c10::Symbol::fromQualString("aten::extend"),
289+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
290+
if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
291+
auto self = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
292+
auto other = args.at(n->input(1)).IValue()->to<c10::List<c10::IValue>>();
293+
const int64_t other_size = other.size();
294+
295+
for (int64_t i = 0; i < other_size; i++) {
296+
self.push_back(other.get(i));
297+
}
298+
} else {
299+
TORCHTRT_THROW_ERROR(
300+
"Unimplemented data type for aten::extend.t evaluator: "
301+
<< args.at(n->input(0)).IValue()->type()->str() << ", "
302+
<< args.at(n->input(1)).IValue()->type()->str());
303+
}
304+
},
305+
EvalOptions().validSchemas({
306+
"aten::extend.t(t[](a!) self, t[] other) -> ()",
307+
})})
288308
.evaluator({c10::Symbol::fromQualString("aten::neg"),
289309
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
290310
auto el = args.at(n->input(0)).unwrapToInt();

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,32 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
303303
ASSERT_TRUE(jit_results[0] == trt_results[0]);
304304
}
305305

306+
TEST(Evaluators, ATenExtendEvaluatesCorrectly) {
307+
const auto graph = R"IR(
308+
graph(%0 : Tensor, %1 : Tensor):
309+
%2 : int = prim::Constant[value=0]()
310+
%3 : Tensor[] = prim::ListConstruct(%0)
311+
%4 : Tensor[] = prim::ListConstruct(%1)
312+
aten::extend(%3, %4)
313+
%5 : Tensor = aten::cat(%3, %2)
314+
return (%5))IR";
315+
316+
auto g = std::make_shared<torch::jit::Graph>();
317+
torch::jit::parseIR(graph, &*g);
318+
319+
auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
320+
auto in1 = at::randint(1, 10, {5, 4}, {at::kCUDA});
321+
322+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
323+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
324+
325+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
326+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
327+
328+
ASSERT_TRUE(
329+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
330+
}
331+
306332
TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
307333
const auto graph = R"IR(
308334
graph(%0 : Tensor, %1 : Tensor):

0 commit comments

Comments
 (0)