Skip to content

Commit 078977c

Browse files
committed
fix: fix the renaming error in squeeze converter
Signed-off-by: Bo Wang <[email protected]>
1 parent 1a22204 commit 078977c

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

core/conversion/converters/impl/squeeze.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ auto squeeze_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pa
2626
}
2727

2828
if (selfDim[dim] != 1) {
29-
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);
29+
auto id_layer = ctx->net->addIdentity(*self);
30+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], id_layer->getOutput(0));
3031

3132
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
3233

tests/core/conversion/converters/test_squeeze.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,32 @@ TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
5353
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
5454
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in, trt_in_add});
5555

56+
ASSERT_TRUE(
57+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
58+
}
59+
60+
TEST(Converters, ATenSqueezeNeedIdentityConvertsCorrectly) {
61+
const auto graph = R"IR(
62+
graph(%0 : Tensor):
63+
%2 : int = prim::Constant[value=1]()
64+
%3 : Tensor = aten::squeeze(%0, %2)
65+
return (%3))IR";
66+
67+
auto g = std::make_shared<torch::jit::Graph>();
68+
torch::jit::parseIR(graph, &*g);
69+
70+
auto in = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
71+
72+
auto jit_in = at::clone(in);
73+
74+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
75+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
76+
77+
auto trt_in = at::clone(jit_in);
78+
79+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
80+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
81+
5682
ASSERT_TRUE(
5783
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
5884
}

0 commit comments

Comments
 (0)