diff --git a/tests/util/util.cpp b/tests/util/util.cpp index f58291d2a8..13d0d18566 100644 --- a/tests/util/util.cpp +++ b/tests/util/util.cpp @@ -5,25 +5,30 @@ namespace torch_tensorrt { namespace tests { namespace util { -bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol = 1e-8, float rtol = 1e-5) { - LOG_GRAPH(a << std::endl << b << std::endl); - auto a_float = a.toType(at::kFloat); - auto b_float = b.toType(at::kFloat); +bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5) { + std::ostringstream ss; + ss << computed_tensor << std::endl << gt_tensor << std::endl; + ss << " atol: " << atol << " rtol: " << rtol << std::endl; - auto diff = a_float - b_float; - auto result = diff.abs().max().item() - (atol + rtol * b.abs().max().item()); + LOG_GRAPH(ss.str()); + auto computed_tensor_float = computed_tensor.toType(at::kFloat); + auto gt_tensor_float = gt_tensor.toType(at::kFloat); - std::cout << "Max Difference: " << result << std::endl; - std::cout << "Acceptable Threshold: " << threshold << std::endl; + auto diff = computed_tensor_float - gt_tensor_float; + auto result = diff.abs().max().item(); + auto threshold = atol + (rtol * gt_tensor.abs().max().item()); + + LOG_GRAPH(std::string("Max Difference: ") + std::to_string(result)); + LOG_GRAPH(std::string("Acceptable Threshold: ") + std::to_string(threshold)); return result <= threshold; } -bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { - LOG_GRAPH(a << std::endl << b << std::endl); - std::cout << "Max Difference: " << (a - b).abs().max().item() << std::endl; +bool exactlyEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor) { + LOG_GRAPH(computed_tensor << std::endl << gt_tensor << std::endl); + std::cout << "Max Difference: " << (computed_tensor - gt_tensor).abs().max().item() << std::endl; - return (a - b).abs().max().item() == 0.f; + return (computed_tensor - gt_tensor).abs().max().item() == 0.f; } } // namespace util diff --git a/tests/util/util.h b/tests/util/util.h index b795667fa4..f39e2a5766 100644 --- a/tests/util/util.h +++ b/tests/util/util.h @@ -11,7 +11,7 @@ namespace torch_tensorrt { namespace tests { namespace util { -bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol = 1e-8, float rtol = 1e-5); +bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5); bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);