@@ -52,7 +52,7 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckHidden) {
52
52
auto trt_results = trtorch::tests::util::RunGraphEngine (
53
53
g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
54
54
55
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
55
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 1e-5 ));
56
56
}
57
57
58
58
TEST (Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckCell) {
@@ -103,7 +103,7 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithBiasCheckCell) {
103
103
auto trt_results = trtorch::tests::util::RunGraphEngine (
104
104
g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
105
105
106
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
106
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 1e-5 ));
107
107
}
108
108
109
109
TEST (Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckHidden) {
@@ -146,7 +146,7 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckHidden) {
146
146
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
147
147
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
148
148
149
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
149
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 1e-5 ));
150
150
}
151
151
152
152
TEST (Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckCell) {
@@ -189,5 +189,5 @@ TEST(Converters, ATenLSTMCellConvertsCorrectlyWithoutBiasCheckCell) {
189
189
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
190
190
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
191
191
192
- ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
193
- }
192
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 1e-5 ));
193
+ }
0 commit comments