Skip to content

Commit 3a33d33

Browse files
committed
feat: support aten::format evaluator
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent da15fa5 commit 3a33d33

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,23 @@ auto aten_registrations TORCHTRT_UNUSED =
706706
},
707707
EvalOptions().validSchemas({
708708
R"SIG(aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!)))SIG",
709-
})});
709+
})})
710+
.evaluator({c10::Symbol::fromQualString("aten::format"),
711+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
712+
int64_t input_num = n->inputs().size();
713+
std::vector<torch::jit::IValue> stack;
714+
for (auto v : n->inputs()) {
715+
stack.push_back(*args.at(v).IValue());
716+
}
717+
stack.push_back(input_num);
718+
auto& ops = torch::jit::getAllOperatorsFor(c10::Symbol::fromQualString("aten::format"));
719+
auto& aten_format = ops.front();
720+
aten_format->getOperation()(stack);
721+
std::string output;
722+
torch::jit::pop(stack, output);
723+
return output;
724+
},
725+
EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})});
710726
} // namespace
711727
} // namespace evaluators
712728
} // namespace conversion

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,4 +579,38 @@ TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
579579
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
580580

581581
ASSERT_TRUE(jit_results[0] == trt_results[0]);
582+
}
583+
584+
TEST(Evaluators, AtenFormatEvaluatesCorrectly) {
585+
const auto graph = R"IR(
586+
graph(%x_1 : Tensor, %x_2 : Tensor):
587+
%0 : int = prim::Constant[value=1]()
588+
%1 : str = prim::Constant[value="res{}_{}_"]()
589+
%2 : int = prim::Constant[value=5]()
590+
%2.1 : int = prim::Constant[value=2]()
591+
%3 : str = prim::Constant[value="res5_2_"]()
592+
%4 : str = aten::format(%1, %2, %2.1)
593+
%5 : bool = aten::eq(%3, %4)
594+
%y : Tensor = prim::If(%5)
595+
block0():
596+
%194 : Tensor = aten::add(%x_1, %x_2, %0)
597+
-> (%194)
598+
block1():
599+
%195 : Tensor = aten::sub(%x_1, %x_2, %0)
600+
-> (%195)
601+
return (%y))IR";
602+
auto g = std::make_shared<torch::jit::Graph>();
603+
torch::jit::parseIR(graph, &*g);
604+
605+
auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
606+
auto in1 = in0.clone();
607+
608+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
609+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
610+
611+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
612+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
613+
614+
ASSERT_TRUE(
615+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
582616
}

0 commit comments

Comments
 (0)