Skip to content

Commit 5643972

Browse files
committed
feat: support aten::eq.str evaluator
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent f99a6ca commit 5643972

File tree

5 files changed

+49
-0
lines changed

5 files changed

+49
-0
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
3333
"aten::eq.bool(bool a, bool b) -> (bool)",
3434
"aten::eq.int(int a, int b) -> (bool)",
3535
"aten::eq.float(float a, float b) -> (bool)",
36+
"aten::eq.str(str a, str b) -> (bool)",
3637
"aten::eq.int_float(int a, float b) -> (bool)",
3738
"aten::eq.float_int(float a, int b) -> (bool)",
3839
}));

core/conversion/evaluators/eval_macros.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@
5757
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
5858
return {}; \
5959
} \
60+
} else if (args.at(n->input(0)).IValue()->isString()) { \
61+
auto a = args.at(n->input(0)).unwrapToString(); \
62+
if (args.at(n->input(1)).IValue()->isString()) { \
63+
auto b = args.at(n->input(1)).unwrapToString(); \
64+
return operation; \
65+
} else { \
66+
TRTORCH_THROW_ERROR( \
67+
"Unimplemented data type for " \
68+
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
69+
return {}; \
70+
} \
6071
} else { \
6172
TRTORCH_THROW_ERROR( \
6273
"Unimplemented data type for " \

core/conversion/var/Var.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class Var : torch::CustomClassHolder {
3636
double unwrapToDouble();
3737
bool unwrapToBool(bool default_val);
3838
bool unwrapToBool();
39+
std::string unwrapToString(std::string default_val);
40+
std::string unwrapToString();
3941
c10::Scalar unwrapToScalar(c10::Scalar default_val);
4042
c10::Scalar unwrapToScalar();
4143
c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);

core/conversion/var/Var_inl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ DEFINE_UNWRAP_TO(at::Tensor, Tensor)
3838
DEFINE_UNWRAP_TO(int64_t, Int)
3939
DEFINE_UNWRAP_TO(double, Double)
4040
DEFINE_UNWRAP_TO(bool, Bool)
41+
DEFINE_UNWRAP_TO(std::string, String)
4142
DEFINE_UNWRAP_TO(c10::Scalar, Scalar)
4243
DEFINE_UNWRAP_TO(c10::List<int64_t>, IntList)
4344
DEFINE_UNWRAP_TO(c10::List<double>, DoubleList)

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,5 +506,39 @@ TEST(Evaluators, ATenIsFloatingPointEvaluatesFalseCorrectly) {
506506
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
507507
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in_trt});
508508

509+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
510+
}
511+
512+
TEST(Evaluators, EqStrResultIsTrueEvaluatesCorrectly) {
513+
const auto graph = R"IR(
514+
graph():
515+
%1 : str = prim::Constant[value="res3"]()
516+
%2 : str = prim::Constant[value="res3"]()
517+
%3 : bool = aten::eq(%1, %2)
518+
return (%3))IR";
519+
520+
auto g = std::make_shared<torch::jit::Graph>();
521+
torch::jit::parseIR(graph, g.get());
522+
523+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
524+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
525+
526+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
527+
}
528+
529+
TEST(Evaluators, EqStrResultIsFalseEvaluatesCorrectly) {
530+
const auto graph = R"IR(
531+
graph():
532+
%1 : str = prim::Constant[value="res3"]()
533+
%2 : str = prim::Constant[value="res4"]()
534+
%3 : bool = aten::eq(%1, %2)
535+
return (%3))IR";
536+
537+
auto g = std::make_shared<torch::jit::Graph>();
538+
torch::jit::parseIR(graph, g.get());
539+
540+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
541+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
542+
509543
ASSERT_TRUE(jit_results[0] == trt_results[0]);
510544
}

0 commit comments

Comments
 (0)