Skip to content

Commit 3a734a2

Browse files
committed
fix(aten::select and aten::var): Fix converters to handle negative axes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent b68d4aa commit 3a734a2

File tree

5 files changed

+67
-7
lines changed

5 files changed

+67
-7
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,23 @@ auto select_registrations TRTORCH_UNUSED =
7373
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
7474
auto in = args[0].ITensorOrFreeze(ctx);
7575
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
76-
auto axis = args[1].unwrapToInt();
77-
axis = axis < 0 ? axis + maxDim : axis;
76+
auto dim = args[1].unwrapToInt();
77+
// Handle negative axis by referring to nbDims of input Tensor
78+
dim = dim < 0 ? dim + maxDim : dim;
7879
auto ind = (int32_t)args[2].unwrapToInt();
80+
// Along the specified dimension, handle a negative index by adding the length of that dimension.
81+
ind = ind < 0 ? ind + in->getDimensions().d[dim] : ind;
82+
LOG_DEBUG("Gather input dimensions: " << in->getDimensions());
83+
LOG_DEBUG("Dimension to select: " << dim);
84+
LOG_DEBUG("Index: " << ind);
7985

8086
// index to access needs to be an at::Tensor
8187
at::Tensor indices = torch::tensor({ind}).to(torch::kI32);
8288
auto const_out = tensor_to_const(ctx, indices);
8389

8490
// IGatherLayer takes in input tensor, the indices, and the axis
8591
// of input tensor to take indices from
86-
auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
92+
auto gather_layer = ctx->net->addGather(*in, *const_out, dim);
8793
TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
8894
auto out = gather_layer->getOutput(0);
8995

@@ -93,7 +99,7 @@ auto select_registrations TRTORCH_UNUSED =
9399
// IShuffleLayer removes redundant dimensions
94100
auto shuffle_layer = ctx->net->addShuffle(*out);
95101
TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
96-
shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), axis));
102+
shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), dim));
97103
shuffle_layer->setName(util::node_info(n).c_str());
98104
out = shuffle_layer->getOutput(0);
99105
}

core/conversion/evaluators/eval_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ void recursiveStore(
199199
TRTORCH_THROW_ERROR("Found unsupported data type in arguments for aten::tensor");
200200
}
201201
} else {
202-
TRTORCH_ASSERT("Found unsupported data type in arguments for aten::tensor");
202+
TRTORCH_THROW_ERROR("Found unsupported data type in arguments for aten::tensor");
203203
}
204204
}
205205
}

core/lowering/passes/unpack_var.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
1717
%none: None = prim::Constant()
1818
%false: bool = prim::Constant[value=0]()
1919
%0: int = prim::Constant[value=0]()
20+
%dtype: int = prim::Constant[value=6]()
2021
%1: int = prim::Constant[value=1]()
2122
%sqrd: Tensor = aten::mul(%input, %input)
2223
%sqrdmean: Tensor = aten::mean(%sqrd, %dims, %keepdim, %none)
@@ -26,7 +27,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
2627
%varout : Tensor = prim::If(%unbiased)
2728
block0():
2829
%shape: int[] = aten::size(%input)
29-
%shapet: Tensor = aten::tensor(%shape, %0, %none, %false)
30+
%shapet: Tensor = aten::tensor(%shape, %dtype, %none, %false)
3031
%dim: int = prim::ListUnpack(%dims)
3132
%reduceddims: Tensor = aten::select(%shapet, %0, %dim)
3233
%numel: Tensor = aten::prod(%reduceddims, %dim, %keepdim, %none)

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,3 +473,29 @@ TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
473473
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
474474
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
475475
}
476+
477+
TEST(Converters, UnpackVarUnbiasedNegAxisLowersCorrectly) {
478+
const auto graph = R"IR(
479+
graph(%x.1 : Tensor):
480+
%37 : bool = prim::Constant[value=1]()
481+
%53 : int[] = prim::Constant[value=[-1]]()
482+
%69 : Tensor = aten::var(%x.1, %53, %37, %37)
483+
return (%69))IR";
484+
485+
auto in = at::randint(-5, 5, {2, 20, 768}, at::kCUDA).to(at::kFloat);
486+
487+
auto jit_in = at::clone(in);
488+
auto g = std::make_shared<torch::jit::Graph>();
489+
torch::jit::parseIR(graph, g.get());
490+
491+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
492+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
493+
494+
in = at::clone(in);
495+
trtorch::core::lowering::passes::UnpackVar(g);
496+
torch::jit::EliminateCommonSubexpression(g);
497+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
498+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {jit_in});
499+
500+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
501+
}

tests/core/conversion/converters/test_select.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,33 @@ TEST(Converters, ATenSelectIntDimNegativeConvertsCorrectly) {
8585
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
8686
}
8787

88+
TEST(Converters, ATenSelectIntNegIndexConvertsCorrectly) {
89+
const auto graph = R"IR(
90+
graph(%0 : Tensor):
91+
%2 : int = prim::Constant[value=0]()
92+
%3 : int = prim::Constant[value=-1]()
93+
%4 : Tensor = aten::select(%0, %3, %2)
94+
return (%4))IR";
95+
96+
auto g = std::make_shared<torch::jit::Graph>();
97+
98+
torch::jit::parseIR(graph, g.get());
99+
100+
auto in = torch::tensor({2, 20, 768}).to(at::kFloat).to(at::kCUDA);
101+
102+
auto jit_in = at::clone(in);
103+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
104+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
105+
106+
auto trt_in = at::clone(in);
107+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
108+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
109+
110+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
111+
112+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
113+
}
114+
88115
TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
89116
const auto graph = R"IR(
90117
graph(%0 : Tensor):
@@ -437,4 +464,4 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
437464
std::cout << trt_results[1].reshape_as(jit_results[0]) << std::endl;
438465

439466
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
440-
}
467+
}

0 commit comments

Comments
 (0)