@@ -53,6 +53,32 @@ TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in, trt_in_add});
 
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenSqueezeNeedIdentityConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %2 : int = prim::Constant[value=1]()
+      %3 : Tensor = aten::squeeze(%0, %2)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 10, {2, 3, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }