Skip to content

Commit 587b3a1

Browse files
committed
chore: Increase threshold for activation and lstm testcases
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 3aa8e21 commit 587b3a1

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

tests/core/conversion/converters/test_activation.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) {
4141
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
4242
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
4343

44-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
44+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 4e-6));
4545
}
4646

4747
TEST(Converters, ATenTanhConvertsCorrectly) {
@@ -61,7 +61,7 @@ TEST(Converters, ATenTanhConvertsCorrectly) {
6161
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
6262
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
6363

64-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
64+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 7e-6));
6565
}
6666

6767
// TODO: Seems like the IR parser is not handling negative numbers well, need to
@@ -221,5 +221,5 @@ TEST(Converters, ATenGELUConvertsCorrectly) {
221221
// The official tensorrt plugin applies the Gelu activation x * Phi(x), where Phi is the Gaussian cdf, approximated
222222
// by: 0.5 * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * x^3))) and the pytorch uses c10::cuda::compat::normcdf to
223223
// compute Phi(x). So there's a difference here.
224-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-4));
224+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 4e-4));
225225
}

tests/core/conversion/converters/test_lstm_cell.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckHidden) {
5252
auto trt_results = trtorch::tests::util::RunGraphEngine(
5353
g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
5454

55-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
55+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 1e-5));
5656
}
5757

5858
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckCell) {
@@ -103,7 +103,7 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckCell) {
103103
auto trt_results = trtorch::tests::util::RunGraphEngine(
104104
g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
105105

106-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
106+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 1e-5));
107107
}
108108

109109
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckHidden) {
@@ -146,7 +146,7 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckHidden) {
146146
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
147147
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
148148

149-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
149+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 1e-5));
150150
}
151151

152152
TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckCell) {
@@ -189,5 +189,5 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckCell) {
189189
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
190190
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
191191

192-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
193-
}
192+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 1e-5));
193+
}

0 commit comments

Comments
 (0)