Skip to content

[libc] Add mpfr tests for fmul. #97376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions libc/test/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,18 @@ add_fp_unittest(
libc.src.__support.FPUtil.fp_bits
)

add_fp_unittest(
fmul_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
fmul_test.cpp
HDRS
FMulTest.h
DEPENDS
libc.src.math.fmul
)
add_fp_unittest(
asinhf_test
NEED_MPFR
Expand Down
121 changes: 121 additions & 0 deletions libc/test/src/math/FMulTest.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
//===-- Utility class to test fmul[f|l] -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
#define LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H

#include "src/__support/FPUtil/FPBits.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

namespace mpfr = LIBC_NAMESPACE::testing::mpfr;

template <typename OutType, typename InType>
class FmulMPFRTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {

DECLARE_SPECIAL_CONSTANTS(InType)

public:
typedef OutType (*FMulFunc)(InType, InType);

void testFMulMPFR(FMulFunc func) {
constexpr int N = 10;
mpfr::BinaryInput<InType> INPUTS[N] = {
{3.0, 5.0},
{0x1.0p1, 0x1.0p-131},
{0x1.0p2, 0x1.0p-129},
{1.0, 1.0},
{-0.0, -0.0},
{-0.0, 0.0},
{0.0, -0.0},
{0x1.0p100, 0x1.0p100},
{1.0, 1.0 + 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150},
{1.0, 0x1.0p-128 + 0x1.0p-149 + 0x1.0p-150}};

for (int i = 0; i < N; ++i) {
InType x = INPUTS[i].x;
InType y = INPUTS[i].y;
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, INPUTS[i],
func(x, y), 0.5);
}
}

void testSpecialInputsMPFR(FMulFunc func) {
constexpr int N = 27;
mpfr::BinaryInput<InType> INPUTS[N] = {{inf, 0x1.0p-129},
{0x1.0p-129, inf},
{inf, 2.0},
{3.0, inf},
{0.0, 0.0},
{neg_inf, aNaN},
{aNaN, neg_inf},
{neg_inf, neg_inf},
{0.0, neg_inf},
{neg_inf, 0.0},
{neg_inf, 1.0},
{1.0, neg_inf},
{neg_inf, 0x1.0p-129},
{0x1.0p-129, neg_inf},
{0.0, 0x1.0p-129},
{inf, 0.0},
{0.0, inf},
{0.0, aNaN},
{2.0, aNaN},
{0x1.0p-129, aNaN},
{inf, aNaN},
{aNaN, aNaN},
{0.0, sNaN},
{2.0, sNaN},
{0x1.0p-129, sNaN},
{inf, sNaN},
{sNaN, sNaN}};

for (int i = 0; i < N; ++i) {
InType x = INPUTS[i].x;
InType y = INPUTS[i].y;
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, INPUTS[i],
func(x, y), 0.5);
}
}

void testNormalRange(FMulFunc func) {
using FPBits = LIBC_NAMESPACE::fputil::FPBits<InType>;
using StorageType = typename FPBits::StorageType;
static constexpr StorageType MAX_NORMAL = FPBits::max_normal().uintval();
static constexpr StorageType MIN_NORMAL = FPBits::min_normal().uintval();

constexpr StorageType COUNT = 10'001;
constexpr StorageType STEP = (MAX_NORMAL - MIN_NORMAL) / COUNT;
for (int signs = 0; signs < 4; ++signs) {
for (StorageType v = MIN_NORMAL, w = MAX_NORMAL;
v <= MAX_NORMAL && w >= MIN_NORMAL; v += STEP, w -= STEP) {
InType x = FPBits(v).get_val(), y = FPBits(w).get_val();
if (signs % 2 == 1) {
x = -x;
}
if (signs >= 2) {
y = -y;
}

mpfr::BinaryInput<InType> input{x, y};
ASSERT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fmul, input, func(x, y),
0.5);
}
}
}
};

#define LIST_FMUL_MPFR_TESTS(OutType, InType, func) \
using LlvmLibcFmulTest = FmulMPFRTest<OutType, InType>; \
TEST_F(LlvmLibcFmulTest, MulMpfr) { testFMulMPFR(&func); } \
TEST_F(LlvmLibcFmulTest, NanInfMpfr) { testSpecialInputsMPFR(&func); } \
TEST_F(LlvmLibcFmulTest, NormalRange) { testNormalRange(&func); }

#endif // LLVM_LIBC_TEST_SRC_MATH_FMULTEST_H
13 changes: 13 additions & 0 deletions libc/test/src/math/fmul_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===-- Unittests for fmul-------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//

#include "FMulTest.h"

#include "src/math/fmul.h"

LIST_FMUL_MPFR_TESTS(float, double, LIBC_NAMESPACE::fmul)
15 changes: 15 additions & 0 deletions libc/utils/MPFRWrapper/MPFRUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ class MPFRNumber {
return result;
}

MPFRNumber fmul(const MPFRNumber &b) {
MPFRNumber result(*this);
mpfr_mul(result.value, value, b.value, mpfr_rounding);
return result;
}

cpp::string str() const {
// 200 bytes should be more than sufficient to hold a 100-digit number
// plus additional bytes for the decimal point, '-' sign etc.
Expand Down Expand Up @@ -738,6 +744,8 @@ binary_operation_one_output(Operation op, InputType x, InputType y,
return inputX.hypot(inputY);
case Operation::Pow:
return inputX.pow(inputY);
case Operation::Fmul:
return inputX.fmul(inputY);
default:
__builtin_unreachable();
}
Expand Down Expand Up @@ -951,6 +959,9 @@ template void
explain_binary_operation_one_output_error(Operation,
const BinaryInput<long double> &,
long double, double, RoundingMode);

template void explain_binary_operation_one_output_error(
Operation, const BinaryInput<double> &, float, double, RoundingMode);
#ifdef LIBC_TYPES_HAS_FLOAT16
template void explain_binary_operation_one_output_error(
Operation, const BinaryInput<float16> &, float16, double, RoundingMode);
Expand Down Expand Up @@ -1126,6 +1137,10 @@ template bool compare_binary_operation_one_output(Operation,
template bool
compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
long double, double, RoundingMode);

template bool compare_binary_operation_one_output(Operation,
const BinaryInput<double> &,
float, double, RoundingMode);
#ifdef LIBC_TYPES_HAS_FLOAT16
template bool compare_binary_operation_one_output(Operation,
const BinaryInput<float16> &,
Expand Down
18 changes: 11 additions & 7 deletions libc/utils/MPFRWrapper/MPFRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ enum class Operation : int {
Fmod,
Hypot,
Pow,
Fmul,
EndBinaryOperationsSingleOutput,

// Operations which take two floating point numbers of the same type as
Expand Down Expand Up @@ -237,7 +238,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
bool is_silent() const override { return silent; }

private:
template <typename T, typename U> bool match(T in, U out) {
template <typename InType, typename OutType>
bool match(InType in, OutType out) {
return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
rounding);
}
Expand All @@ -259,13 +261,14 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
rounding);
}

template <typename T, typename U>
bool match(const TernaryInput<T> &in, U out) {
template <typename InType, typename OutType>
bool match(const TernaryInput<InType> &in, OutType out) {
return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
rounding);
}

template <typename T, typename U> void explain_error(T in, U out) {
template <typename InType, typename OutType>
void explain_error(InType in, OutType out) {
explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
rounding);
}
Expand All @@ -287,8 +290,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
rounding);
}

template <typename T, typename U>
void explain_error(const TernaryInput<T> &in, U out) {
template <typename InType, typename OutType>
void explain_error(const TernaryInput<InType> &in, OutType out) {
explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
rounding);
}
Expand All @@ -304,7 +307,8 @@ constexpr bool is_valid_operation() {
(op == Operation::Sqrt && cpp::is_floating_point_v<InputType> &&
cpp::is_floating_point_v<OutputType> &&
sizeof(OutputType) <= sizeof(InputType)) ||
(op == Operation::Div && internal::IsBinaryInput<InputType>::VALUE &&
((op == Operation::Div || op == Operation::Fmul) &&
internal::IsBinaryInput<InputType>::VALUE &&
cpp::is_floating_point_v<
typename internal::MakeScalarInput<InputType>::type> &&
cpp::is_floating_point_v<OutputType>) ||
Expand Down
Loading