Skip to content

Commit bb3046a

Browse files
committed
feat: support aten::div.Tensor_mode
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent c77def0 commit bb3046a

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,43 @@ auto element_wise_registrations TORCHTRT_UNUSED =
323323
div->setName(util::node_info(n).c_str());
324324
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
325325

326+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
327+
return true;
328+
}})
329+
.pattern({"aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)",
330+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
331+
// Should implement self / other
332+
auto self = args[0].ITensorOrFreeze(ctx);
333+
auto other = args[1].ITensorOrFreeze(ctx);
334+
std::string rounding_mode = "default";
335+
if (args[2].isIValue() && args[2].IValue()->isString()) {
336+
rounding_mode = args[2].unwrapToString();
337+
}
338+
nvinfer1::ILayer* div = nullptr;
339+
if (rounding_mode == "floor") {
340+
div = add_elementwise(
341+
ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
342+
} else if (rounding_mode == "trunc") {
343+
// trunc = floor(abs(div)) * sign(div)
344+
auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
345+
auto abs = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kABS);
346+
auto floor = ctx->net->addUnary(*abs->getOutput(0), nvinfer1::UnaryOperation::kFLOOR);
347+
auto sign = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kSIGN);
348+
div = add_elementwise(
349+
ctx,
350+
nvinfer1::ElementWiseOperation::kPROD,
351+
floor->getOutput(0),
352+
sign->getOutput(0),
353+
util::node_info(n));
354+
} else {
355+
div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
356+
}
357+
358+
TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
359+
360+
div->setName(util::node_info(n).c_str());
361+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
362+
326363
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
327364
return true;
328365
}})

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,19 @@ void pointwise_test_helper(
99
bool singleInput,
1010
bool dynamicInput = false,
1111
std::vector<int64_t> shape1 = {5},
12-
std::vector<int64_t> shape2 = {5}) {
12+
std::vector<int64_t> shape2 = {5},
13+
bool negative_input = false) {
1314
auto g = std::make_shared<torch::jit::Graph>();
1415
torch::jit::parseIR(graph_ir, g.get());
1516

1617
// singleInput case is enabled when elementwise operation is performed
1718
// with an input and a constant embedded in graph
1819
std::vector<at::Tensor> torch_inputs;
19-
torch_inputs.push_back(at::randint(1, 5, shape1, {at::kCUDA}));
20+
if (negative_input) {
21+
torch_inputs.push_back(at::randint(-5, 5, shape1, {at::kCUDA}));
22+
} else {
23+
torch_inputs.push_back(at::randint(1, 5, shape1, {at::kCUDA}));
24+
}
2025
if (!singleInput) {
2126
torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
2227
}
@@ -141,6 +146,45 @@ TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
141146
pointwise_test_helper(graph, true);
142147
}
143148

149+
TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
150+
const auto graph = R"IR(
151+
graph(%0 : Tensor, %1 : Tensor):
152+
%3 : str = prim::Constant[value="floor"]()
153+
%2 : Tensor = aten::div(%0, %1, %3)
154+
return (%2))IR";
155+
pointwise_test_helper(graph, false, false, {5}, {5}, true);
156+
pointwise_test_helper(graph, false, false, {3, 4}, {4}, true);
157+
pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
158+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
159+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
160+
}
161+
162+
TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
163+
const auto graph = R"IR(
164+
graph(%0 : Tensor, %1 : Tensor):
165+
%3 : str = prim::Constant[value="trunc"]()
166+
%2 : Tensor = aten::div(%0, %1, %3)
167+
return (%2))IR";
168+
pointwise_test_helper(graph, false, false, {5}, {5}, true);
169+
pointwise_test_helper(graph, false, false, {3, 4}, {4}, true);
170+
pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
171+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
172+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
173+
}
174+
175+
TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
176+
const auto graph = R"IR(
177+
graph(%0 : Tensor, %1 : Tensor):
178+
%3 : None = prim::Constant()
179+
%2 : Tensor = aten::div(%0, %1, %3)
180+
return (%2))IR";
181+
pointwise_test_helper(graph, false, false, {5}, {5}, true);
182+
pointwise_test_helper(graph, false, false, {3, 4}, {4}, true);
183+
pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
184+
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
185+
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
186+
}
187+
144188
TEST(Converters, ATenPowTensorConvertsCorrectly) {
145189
const auto graph = R"IR(
146190
graph(%x.1 : Tensor, %x2.1 : Tensor):

0 commit comments

Comments
 (0)