Skip to content

[libc] Add mpfr tests for fmul. #97376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from Jul 2, 2024
Merged

[libc] Add mpfr tests for fmul. #97376

merged 7 commits into from Jul 2, 2024

Conversation

ghost
Copy link

@ghost ghost commented Jul 2, 2024

No description provided.

@llvmbot llvmbot added the libc label Jul 2, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2024

@llvm/pr-subscribers-libc

Author: Job Henandez Lara (Jobhdez)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/97376.diff

5 Files Affected:

  • (modified) libc/test/src/math/CMakeLists.txt (+12)
  • (added) libc/test/src/math/FMulTest.h (+121)
  • (added) libc/test/src/math/fmul_test.cpp (+13)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.cpp (+19-42)
  • (modified) libc/utils/MPFRWrapper/MPFRUtils.h (+28-12)
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index c07c6d77fa233..9eda5db1ea2fc 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -1823,6 +1823,18 @@ add_fp_unittest(
     libc.src.__support.FPUtil.fp_bits
 )
 
+add_fp_unittest(
+  fmul_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    fmul_test.cpp
+  HDRS
+    FMulTest.h
+  DEPENDS
+    libc.src.math.fmul
+)
 add_fp_unittest(
   asinhf_test
   NEED_MPFR
diff --git a/libc/test/src/math/FMulTest.h b/libc/test/src/math/FMulTest.h
new file mode 100644
index 0000000000000..864910c29d83f
--- /dev/null
+++ b/libc/test/src/math/FMulTest.h
@@ -0,0 +1,121 @@
+//===-- Utility class to test fmul[f|l] ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
+#define LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
+
+#include "src/__support/FPUtil/FPBits.h"
+#include "test/UnitTest/FEnvSafeTest.h"
+#include "test/UnitTest/FPMatcher.h"
+#include "test/UnitTest/Test.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+template <typename OutType, typename InType>
+class FmulMPFRTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
+
+  DECLARE_SPECIAL_CONSTANTS(InType)
+
+public:
+  typedef OutType (*FMulFunc)(InType, InType);
+
+  void testFMulMPFR(FMulFunc func) {
+    constexpr int N = 10;
+    mpfr::BinaryInput<InType> INPUTS[N] = {
+        {3.0, 5.0},
+        {0x1.0p1, 0x1.0p-131},
+        {0x1.0p2, 0x1.0p-129},
+        {1.0, 1.0},
+        {-0.0, -0.0},
+        {-0.0, 0.0},
+        {0.0, -0.0},
+        {0x1.0p100, 0x1.0p100},
+        {1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150},
+        {1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150}};
+
+    for (int i = 0; i < N; ++i) {
+      InType x = INPUTS[i].x;
+      InType y = INPUTS[i].y;
+      ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, INPUTS[i],
+                                     func(x, y), 0.5);
+    }
+  }
+
+  void testSpecialInputsMPFR(FMulFunc func) {
+    constexpr int N = 27;
+    mpfr::BinaryInput<InType> INPUTS[N] = {{inf, 0x1.0p-129},
+                                           {0x1.0p-129, inf},
+                                           {inf, 2.0},
+                                           {3.0, inf},
+                                           {0.0, 0.0},
+                                           {neg_inf, aNaN},
+                                           {aNaN, neg_inf},
+                                           {neg_inf, neg_inf},
+                                           {0.0, neg_inf},
+                                           {neg_inf, 0.0},
+                                           {neg_inf, 1.0},
+                                           {1.0, neg_inf},
+                                           {neg_inf, 0x1.0p-129},
+                                           {0x1.0p-129, neg_inf},
+                                           {0.0, 0x1.0p-129},
+                                           {inf, 0.0},
+                                           {0.0, inf},
+                                           {0.0, aNaN},
+                                           {2.0, aNaN},
+                                           {0x1.0p-129, aNaN},
+                                           {inf, aNaN},
+                                           {aNaN, aNaN},
+                                           {0.0, sNaN},
+                                           {2.0, sNaN},
+                                           {0x1.0p-129, sNaN},
+                                           {inf, sNaN},
+                                           {sNaN, sNaN}};
+
+    for (int i = 0; i < N; ++i) {
+      InType x = INPUTS[i].x;
+      InType y = INPUTS[i].y;
+      ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, INPUTS[i],
+                                     func(x, y), 0.5);
+    }
+  }
+
+  void testNormalRange(FMulFunc func) {
+    using FPBits = LIBC_NAMESPACE::fputil::FPBits<InType>;
+    using StorageType = typename FPBits::StorageType;
+    static constexpr StorageType MAX_NORMAL = FPBits::max_normal().uintval();
+    static constexpr StorageType MIN_NORMAL = FPBits::min_normal().uintval();
+
+    constexpr StorageType COUNT = 10'001;
+    constexpr StorageType STEP = (MAX_NORMAL - MIN_NORMAL) / COUNT;
+    for (int signs = 0; signs < 4; ++signs) {
+      for (StorageType v = MIN_NORMAL, w = MAX_NORMAL;
+           v <= MAX_NORMAL && w >= MIN_NORMAL; v += STEP, w -= STEP) {
+        InType x = FPBits(v).get_val(), y = FPBits(w).get_val();
+        if (signs % 2 == 1) {
+          x = -x;
+        }
+        if (signs >= 2) {
+          y = -y;
+        }
+
+        mpfr::BinaryInput<InType> input{x, y};
+        ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, input, func(x, y),
+                                       0.5);
+      }
+    }
+  }
+};
+
+#define LIST_FMUL_MPFR_TESTS(OutType, InType, func)                            \
+  using LlvmLibcFmulTest = FmulMPFRTest<OutType, InType>;                      \
+  TEST_F(LlvmLibcFmulTest, MulMpfr) { testFMulMPFR(&func); }                   \
+  TEST_F(LlvmLibcFmulTest, NanInfMpfr) { testSpecialInputsMPFR(&func); }       \
+  TEST_F(LlvmLibcFmulTest, NormalRange) { testNormalRange(&func); }
+
+#endif // LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
diff --git a/libc/test/src/math/fmul_test.cpp b/libc/test/src/math/fmul_test.cpp
new file mode 100644
index 0000000000000..16eaa1a818daf
--- /dev/null
+++ b/libc/test/src/math/fmul_test.cpp
@@ -0,0 +1,13 @@
+//===-- Unittests for fmul-------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "FMulTest.h"
+
+#include "src/math/fmul.h"
+
+LIST_FMUL_MPFR_TESTS(float, double, LIBC_NAMESPACE::fmul)
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp
index 379a631a356a3..cf4ded5fb66b4 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp
@@ -487,6 +487,12 @@ class MPFRNumber {
     return result;
   }
 
+  MPFRNumber fmul(const MPFRNumber &b) {
+    MPFRNumber result(*this);
+    mpfr_mul(result.value, value, b.value, mpfr_rounding);
+    return result;
+  }
+
   cpp::string str() const {
     // 200 bytes should be more than sufficient to hold a 100-digit number
     // plus additional bytes for the decimal point, '-' sign etc.
@@ -738,6 +744,8 @@ binary_operation_one_output(Operation op, InputType x, InputType y,
     return inputX.hypot(inputY);
   case Operation::Pow:
     return inputX.pow(inputY);
+  case Operation::Fmul:
+    return inputX.fmul(inputY);
   default:
     __builtin_unreachable();
   }
@@ -947,21 +955,9 @@ explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
                                           float, double, RoundingMode);
 template void explain_binary_operation_one_output_error(
     Operation, const BinaryInput<double> &, double, double, RoundingMode);
-template void
-explain_binary_operation_one_output_error(Operation,
-                                          const BinaryInput<long double> &,
-                                          long double, double, RoundingMode);
-#ifdef LIBC_TYPES_HAS_FLOAT16
-template void explain_binary_operation_one_output_error(
-    Operation, const BinaryInput<float16> &, float16, double, RoundingMode);
-template void
-explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
-                                          float16, double, RoundingMode);
-template void explain_binary_operation_one_output_error(
-    Operation, const BinaryInput<double> &, float16, double, RoundingMode);
-template void explain_binary_operation_one_output_error(
-    Operation, const BinaryInput<long double> &, float16, double, RoundingMode);
-#endif
+template void explain_binary_operation_one_output_error<long double>(
+    Operation, const BinaryInput<long double> &, long double, double,
+    RoundingMode);
 
 template <typename InputType, typename OutputType>
 void explain_ternary_operation_one_output_error(
@@ -1109,7 +1105,7 @@ bool compare_binary_operation_one_output(Operation op,
                                          OutputType libc_result,
                                          double ulp_tolerance,
                                          RoundingMode rounding) {
-  unsigned int precision = get_precision<InputType>(ulp_tolerance);
+  unsigned int precision = get_precision<T>(ulp_tolerance);
   MPFRNumber mpfr_result =
       binary_operation_one_output(op, input.x, input.y, precision, rounding);
   double ulp = mpfr_result.ulp(libc_result);
@@ -1117,32 +1113,13 @@ bool compare_binary_operation_one_output(Operation op,
   return (ulp <= ulp_tolerance);
 }
 
-template bool compare_binary_operation_one_output(Operation,
-                                                  const BinaryInput<float> &,
-                                                  float, double, RoundingMode);
-template bool compare_binary_operation_one_output(Operation,
-                                                  const BinaryInput<double> &,
-                                                  double, double, RoundingMode);
-template bool
-compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
-                                    long double, double, RoundingMode);
-#ifdef LIBC_TYPES_HAS_FLOAT16
-template bool compare_binary_operation_one_output(Operation,
-                                                  const BinaryInput<float16> &,
-                                                  float16, double,
-                                                  RoundingMode);
-template bool compare_binary_operation_one_output(Operation,
-                                                  const BinaryInput<float> &,
-                                                  float16, double,
-                                                  RoundingMode);
-template bool compare_binary_operation_one_output(Operation,
-                                                  const BinaryInput<double> &,
-                                                  float16, double,
-                                                  RoundingMode);
-template bool
-compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
-                                    float16, double, RoundingMode);
-#endif
+template bool compare_binary_operation_one_output<float>(
+    Operation, const BinaryInput<float> &, float, double, RoundingMode);
+template bool compare_binary_operation_one_output<double>(
+    Operation, const BinaryInput<double> &, double, double, RoundingMode);
+template bool compare_binary_operation_one_output<long double>(
+    Operation, const BinaryInput<long double> &, long double, double,
+    RoundingMode);
 
 template <typename InputType, typename OutputType>
 bool compare_ternary_operation_one_output(Operation op,
diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h
index 11e323bf6881d..7621866e6d730 100644
--- a/libc/utils/MPFRWrapper/MPFRUtils.h
+++ b/libc/utils/MPFRWrapper/MPFRUtils.h
@@ -41,6 +41,7 @@ enum class Operation : int {
   Exp10,
   Expm1,
   Floor,
+  Fmul,
   Log,
   Log2,
   Log10,
@@ -147,6 +148,14 @@ template <typename T> struct IsTernaryInput<TernaryInput<T>> {
   static constexpr bool VALUE = true;
 };
 
+template <typename T> struct IsBinaryInput {
+  static constexpr bool VALUE = false;
+};
+
+template <typename T> struct IsBinaryInput<BinaryInput<T>> {
+  static constexpr bool VALUE = true;
+};
+
 template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
 
 template <typename T>
@@ -237,12 +246,14 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
   bool is_silent() const override { return silent; }
 
 private:
-  template <typename T, typename U> bool match(T in, U out) {
+  template <typename InType, typename OutType>
+  bool match(InType in, OutType out) {
     return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
                                                  rounding);
   }
 
-  template <typename T> bool match(T in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  bool match(InType in, const BinaryOutput<InType> &out) {
     return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -253,30 +264,33 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                                rounding);
   }
 
-  template <typename T>
-  bool match(BinaryInput<T> in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  bool match(BinaryInput<InType> in, const BinaryOutput<InType> &out) {
     return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
                                                 rounding);
   }
 
-  template <typename T, typename U>
-  bool match(const TernaryInput<T> &in, U out) {
+  template <typename InType, typename OutType>
+  bool match(const TernaryInput<InType> &in, OutType out) {
     return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
                                                 rounding);
   }
 
-  template <typename T, typename U> void explain_error(T in, U out) {
+  template <typename InType, typename OutType>
+  void explain_error(InType in, OutType out) {
     explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
                                                 rounding);
   }
 
-  template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  void explain_error(InType in, const BinaryOutput<InType> &out) {
     explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                               rounding);
   }
 
-  template <typename T>
-  void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out) {
+  template <typename InType>
+  void explain_error(const BinaryInput<InType> &in,
+                     const BinaryOutput<InType> &out) {
     explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -287,8 +301,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
                                               rounding);
   }
 
-  template <typename T, typename U>
-  void explain_error(const TernaryInput<T> &in, U out) {
+  template <typename InType, typename OutType>
+  void explain_error(const TernaryInput<InType> &in, OutType out) {
     explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
                                                rounding);
   }
@@ -311,6 +325,8 @@ constexpr bool is_valid_operation() {
       (op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
        cpp::is_floating_point_v<
            typename internal::MakeScalarInput<InputType>::type> &&
+       cpp::is_floating_point_v<OutputType>) ||
+      (op == Operation::Fmul && internal::IsBinaryInput<InputType>::VALUE &&
        cpp::is_floating_point_v<OutputType>);
   if (IS_NARROWING_OP)
     return true;

@lntue lntue changed the title [libc] Add mpfr tests [libc] Add mpfr tests for fmul. Jul 2, 2024
@lntue lntue self-requested a review July 2, 2024 02:45
Copy link

github-actions bot commented Jul 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@lntue lntue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@lntue lntue merged commit 6f60d2b into llvm:main Jul 2, 2024
4 of 5 checks passed
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants