Commit c4fdfcb
feat(//core/conversion/evaluators): aten::pow support
Adds support for the following aten::pow variants in the evaluator library:

```
"aten::pow.int(int a, int b) -> (float)",
"aten::pow.float(float a, float b) -> (float)",
"aten::pow.int_float(int a, float b) -> (float)",
"aten::pow.float_int(float a, int b) -> (float)",
```

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent c5c5c47 commit c4fdfcb

File tree

3 files changed: +129 -1 lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 14 additions & 1 deletion

```diff
@@ -1,3 +1,5 @@
+#include <math.h>
+
 #include "ATen/core/List.h"
 #include "ATen/core/functional.h"
 #include "ATen/core/ivalue.h"
@@ -98,10 +100,21 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
         "aten::ge.float_int(float a, int b) -> (bool)",
     }));
 
+DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
+    pow,
+    "aten::pow",
+    pow(a, b),
+    std::set<std::string>({
+        "aten::pow.int(int a, int b) -> (float)",
+        "aten::pow.float(float a, float b) -> (float)",
+        "aten::pow.int_float(int a, float b) -> (float)",
+        "aten::pow.float_int(float a, int b) -> (float)",
+    }));
+
 DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
     and,
     "aten::__and__",
-    a&& b,
+    a && b,
     bool,
     std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
 DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
```
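The invocation above passes the raw C++ expression `pow(a, b)` as the operation, which is why the new `#include <math.h>` is added at the top of the file. As a quick standalone sketch (illustrative only, not part of the commit) of why a single expression covers all four schemas: the C++ `pow` overloads promote every int/float operand combination to `double`, matching the `-> (float)` return type each schema declares.

```cpp
// Illustrative check (hypothetical snippet, not from the commit): std::pow yields a
// double for every int/float operand mix, so one expression serves all four schemas.
#include <cmath>
#include <cstdio>

int main() {
  std::printf("%f\n", std::pow(9, 4));     // int,   int   -> 6561.000000
  std::printf("%f\n", std::pow(9.5, 4.5)); // float, float
  std::printf("%f\n", std::pow(9, 4.5));   // int,   float
  std::printf("%f\n", std::pow(9.5, 4));   // float, int
  return 0;
}
```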

core/conversion/evaluators/eval_macros.h

Lines changed: 47 additions & 0 deletions

```diff
@@ -77,6 +77,53 @@
       }, \
       EvalOptions().validSchemas(schemas)});
 
+#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
+  auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
+      {c10::Symbol::fromQualString(node_kind), \
+       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+         if (args.at(n->input(0)).IValue()->isInt()) { \
+           auto a = args.at(n->input(0)).unwrapToInt(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else if (args.at(n->input(0)).IValue()->isDouble()) { \
+           auto a = args.at(n->input(0)).unwrapToDouble(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else { \
+           TORCHTRT_THROW_ERROR( \
+               "Unimplemented data type for " \
+               << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
+           return {}; \
+         } \
+       }, \
+       EvalOptions().validSchemas(schemas)});
+
 #define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
   auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
       {c10::Symbol::fromQualString(node_name), \
```
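For readers who prefer to see the control flow without the macro plumbing, below is a condensed, hypothetical sketch of the dispatch the macro performs for `aten::pow`: it branches on the concrete type of each operand `IValue`, unwraps it to a native C++ scalar, and lets the operation expression (here `pow(a, b)`) run with those types. The function name `eval_pow_sketch` is invented for illustration; the real code is the macro above, which also covers `bool` operands and raises `TORCHTRT_THROW_ERROR` on unsupported types.

```cpp
// Condensed sketch (hypothetical, for illustration) of the nested type dispatch in
// DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR, specialized to pow and trimmed to the
// int/double branches. The real macro also handles bool and reports errors via
// TORCHTRT_THROW_ERROR instead of silently returning an empty optional.
#include <cmath>
#include <ATen/core/ivalue.h>
#include <c10/util/Optional.h>

c10::optional<c10::IValue> eval_pow_sketch(const c10::IValue& lhs, const c10::IValue& rhs) {
  // Once the left-hand scalar type is known, branch on the right-hand operand so
  // std::pow is called with the native (int64_t/double) argument types.
  auto with_rhs = [&](auto a) -> c10::optional<c10::IValue> {
    if (rhs.isInt()) {
      return c10::IValue(std::pow(a, rhs.toInt()));
    } else if (rhs.isDouble()) {
      return c10::IValue(std::pow(a, rhs.toDouble()));
    }
    return {}; // unsupported right-hand type
  };
  if (lhs.isInt()) {
    return with_rhs(lhs.toInt());
  } else if (lhs.isDouble()) {
    return with_rhs(lhs.toDouble());
  }
  return {}; // unsupported left-hand type
}
```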

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 68 additions & 0 deletions

```diff
@@ -726,5 +726,73 @@ TEST(Evaluators, RangeLengthNegEvaluatesCorrectly) {
   auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
   auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
 
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PowIntEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=9]()
+        %2 : int = prim::Constant[value=4]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PowFloatEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : float = prim::Constant[value=9.5]()
+        %2 : float = prim::Constant[value=4.5]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PowIntFloatEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=9]()
+        %2 : float = prim::Constant[value=4.5]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PowFloatIntEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : float = prim::Constant[value=9.5]()
+        %2 : int = prim::Constant[value=4]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
 }
```
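Each test builds a constant-only graph, evaluates it with both the JIT interpreter (`EvaluateGraphJIT`) and the Torch-TensorRT evaluator path (`EvaluateGraph`), and asserts the results match. A further case such as a negative exponent could follow the same pattern; a hypothetical sketch (not part of this commit):

```cpp
// Hypothetical additional test following the same pattern as the commit's tests:
// exercises aten::pow.int with a negative exponent, which still yields a float.
TEST(Evaluators, PowIntNegExponentEvaluatesCorrectly) {
  const auto graph = R"IR(
      graph():
        %1 : int = prim::Constant[value=2]()
        %2 : int = prim::Constant[value=-2]()
        %3 : float = aten::pow(%1, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
```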

0 commit comments