Skip to content

[APFloat] Add support for f8E4M3 IEEE 754 type #97179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions clang/include/clang/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,10 @@ class alignas(void *) Stmt {
unsigned : NumExprBits;

static_assert(
llvm::APFloat::S_MaxSemantics < 16,
"Too many Semantics enum values to fit in bitfield of size 4");
llvm::APFloat::S_MaxSemantics < 32,
"Too many Semantics enum values to fit in bitfield of size 5");
LLVM_PREFERRED_TYPE(llvm::APFloat::Semantics)
unsigned Semantics : 4; // Provides semantics for APFloat construction
unsigned Semantics : 5; // Provides semantics for APFloat construction
LLVM_PREFERRED_TYPE(bool)
unsigned IsExact : 1;
};
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/MicrosoftMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_IEEEquad: Out << 'Y'; break;
case APFloat::S_PPCDoubleDouble: Out << 'Z'; break;
case APFloat::S_Float8E5M2:
case APFloat::S_Float8E4M3:
case APFloat::S_Float8E4M3FN:
case APFloat::S_Float8E5M2FNUZ:
case APFloat::S_Float8E4M3FNUZ:
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/ADT/APFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ struct APFloatBase {
// This format's exponent bias is 16, instead of the 15 (2 ** (5 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E5M2FNUZ,
// 8-bit floating point number following IEEE-754 conventions with bit
// layout S1E4M3.
S_Float8E4M3,
// 8-bit floating point number mostly following IEEE-754 conventions with
// bit layout S1E4M3 as described in https://arxiv.org/abs/2209.05433.
// Unlike IEEE-754 types, there are no infinity values, and NaN is
Expand Down Expand Up @@ -217,6 +220,7 @@ struct APFloatBase {
static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
static const fltSemantics &Float8E5M2() LLVM_READNONE;
static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3() LLVM_READNONE;
static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
Expand Down Expand Up @@ -638,6 +642,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
APInt convertFloat8E5M2APFloatToAPInt() const;
APInt convertFloat8E5M2FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3APFloatToAPInt() const;
APInt convertFloat8E4M3FNAPFloatToAPInt() const;
APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
Expand All @@ -656,6 +661,7 @@ class IEEEFloat final : public APFloatBase {
void initFromPPCDoubleDoubleAPInt(const APInt &api);
void initFromFloat8E5M2APInt(const APInt &api);
void initFromFloat8E5M2FNUZAPInt(const APInt &api);
void initFromFloat8E4M3APInt(const APInt &api);
void initFromFloat8E4M3FNAPInt(const APInt &api);
void initFromFloat8E4M3FNUZAPInt(const APInt &api);
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/Support/APFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ static constexpr fltSemantics semIEEEquad = {16383, -16382, 113, 128};
static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
static constexpr fltSemantics semFloat8E5M2FNUZ = {
15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
// Semantics descriptor returned by APFloatBase::Float8E4M3(), i.e. the 8-bit
// S1E4M3 format that follows IEEE-754 conventions.
// NOTE(review): the initializer order is presumably {maxExponent, minExponent,
// precision, sizeInBits} -- confirm against the fltSemantics definition.
static constexpr fltSemantics semFloat8E4M3 = {7, -6, 4, 8};
static constexpr fltSemantics semFloat8E4M3FN = {
8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static constexpr fltSemantics semFloat8E4M3FNUZ = {
Expand Down Expand Up @@ -208,6 +209,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E5M2();
case S_Float8E5M2FNUZ:
return Float8E5M2FNUZ();
case S_Float8E4M3:
return Float8E4M3();
case S_Float8E4M3FN:
return Float8E4M3FN();
case S_Float8E4M3FNUZ:
Expand Down Expand Up @@ -246,6 +249,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E5M2;
else if (&Sem == &llvm::APFloat::Float8E5M2FNUZ())
return S_Float8E5M2FNUZ;
else if (&Sem == &llvm::APFloat::Float8E4M3())
return S_Float8E4M3;
else if (&Sem == &llvm::APFloat::Float8E4M3FN())
return S_Float8E4M3FN;
else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ())
Expand Down Expand Up @@ -276,6 +281,7 @@ const fltSemantics &APFloatBase::PPCDoubleDouble() {
}
// Accessors for the 8-bit floating-point formats. Each one returns a reference
// to the matching file-local fltSemantics constant.
const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
const fltSemantics &APFloatBase::Float8E5M2FNUZ() { return semFloat8E5M2FNUZ; }
// S1E4M3 layout following IEEE-754 conventions.
const fltSemantics &APFloatBase::Float8E4M3() { return semFloat8E4M3; }
const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; }
const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; }
const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
Expand Down Expand Up @@ -3617,6 +3623,11 @@ APInt IEEEFloat::convertFloat8E5M2FNUZAPFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloat8E5M2FNUZ>();
}

// Produces the APInt form of a Float8E4M3 value by delegating to the generic
// convertIEEEFloatToAPInt template instantiated with semFloat8E4M3.
// bitcastToAPInt() dispatches here when the semantics are semFloat8E4M3.
APInt IEEEFloat::convertFloat8E4M3APFloatToAPInt() const {
// Sanity check: partCount() must be 1 for this format.
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat8E4M3>();
}

APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat8E4M3FN>();
Expand Down Expand Up @@ -3681,6 +3692,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FNUZ)
return convertFloat8E5M2FNUZAPFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3)
return convertFloat8E4M3APFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN)
return convertFloat8E4M3FNAPFloatToAPInt();

Expand Down Expand Up @@ -3902,6 +3916,10 @@ void IEEEFloat::initFromFloat8E5M2FNUZAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E5M2FNUZ>(api);
}

// Initializes this value from the APInt `api` using Float8E4M3 semantics, by
// delegating to the generic initFromIEEEAPInt template. initFromAPInt()
// dispatches here when the requested semantics are semFloat8E4M3.
void IEEEFloat::initFromFloat8E4M3APInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E4M3>(api);
}

// Initializes this value from the APInt `api` using Float8E4M3FN semantics, by
// delegating to the generic initFromIEEEAPInt template. initFromAPInt()
// dispatches here when the requested semantics are semFloat8E4M3FN.
void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E4M3FN>(api);
}
Expand Down Expand Up @@ -3951,6 +3969,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E5M2APInt(api);
if (Sem == &semFloat8E5M2FNUZ)
return initFromFloat8E5M2FNUZAPInt(api);
if (Sem == &semFloat8E4M3)
return initFromFloat8E4M3APInt(api);
if (Sem == &semFloat8E4M3FN)
return initFromFloat8E4M3FNAPInt(api);
if (Sem == &semFloat8E4M3FNUZ)
Expand Down
66 changes: 66 additions & 0 deletions llvm/unittests/ADT/APFloatTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2133,6 +2133,8 @@ TEST(APFloatTest, getZero) {
{&APFloat::Float8E5M2(), true, true, {0x80ULL, 0}, 1},
{&APFloat::Float8E5M2FNUZ(), false, false, {0, 0}, 1},
{&APFloat::Float8E5M2FNUZ(), true, false, {0, 0}, 1},
{&APFloat::Float8E4M3(), false, true, {0, 0}, 1},
{&APFloat::Float8E4M3(), true, true, {0x80ULL, 0}, 1},
{&APFloat::Float8E4M3FN(), false, true, {0, 0}, 1},
{&APFloat::Float8E4M3FN(), true, true, {0x80ULL, 0}, 1},
{&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
Expand Down Expand Up @@ -6532,6 +6534,34 @@ TEST(APFloatTest, Float8E5M2ToDouble) {
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}

// Verifies that representative Float8E4M3 values convert to the expected
// double: 1.0 and 2.0, the largest finite magnitude (240), the smallest
// normalized magnitude (2^-6), the smallest denormal (2^-9), both infinities
// and a quiet NaN.
TEST(APFloatTest, Float8E4M3ToDouble) {
  APFloat One(APFloat::Float8E4M3(), "1.0");
  EXPECT_EQ(1.0, One.convertToDouble());
  APFloat Two(APFloat::Float8E4M3(), "2.0");
  EXPECT_EQ(2.0, Two.convertToDouble());

  // Expected values are written as double literals so that both sides of each
  // comparison have the same type.
  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3(), false);
  EXPECT_EQ(240.0, PosLargest.convertToDouble());
  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3(), true);
  EXPECT_EQ(-240.0, NegLargest.convertToDouble());

  APFloat PosSmallest =
      APFloat::getSmallestNormalized(APFloat::Float8E4M3(), false);
  EXPECT_EQ(0x1.p-6, PosSmallest.convertToDouble());
  APFloat NegSmallest =
      APFloat::getSmallestNormalized(APFloat::Float8E4M3(), true);
  EXPECT_EQ(-0x1.p-6, NegSmallest.convertToDouble());

  // The smallest representable magnitude must be reported as a denormal.
  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3(), false);
  EXPECT_TRUE(SmallestDenorm.isDenormal());
  EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToDouble());

  // Non-finite values: infinities compare equal to double infinities, and a
  // NaN can only be checked with isnan.
  APFloat PosInf = APFloat::getInf(APFloat::Float8E4M3());
  EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
  APFloat NegInf = APFloat::getInf(APFloat::Float8E4M3(), true);
  EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3());
  EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}

TEST(APFloatTest, Float8E4M3FNToDouble) {
APFloat One(APFloat::Float8E4M3FN(), "1.0");
EXPECT_EQ(1.0, One.convertToDouble());
Expand Down Expand Up @@ -6846,6 +6876,42 @@ TEST(APFloatTest, Float8E5M2ToFloat) {
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

// Verifies that representative Float8E4M3 values convert to the expected
// float: signed zeros, 1.0 and 2.0, the largest finite magnitude (240), the
// smallest normalized magnitude (2^-6), the smallest denormal (2^-9), both
// infinities and a quiet NaN.
TEST(APFloatTest, Float8E4M3ToFloat) {
  // Zeros are round-tripped through an APFloat built from the float result so
  // that the sign of zero can be inspected.
  APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3());
  APFloat PosZeroToFloat(PosZero.convertToFloat());
  EXPECT_TRUE(PosZeroToFloat.isPosZero());
  APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3(), true);
  APFloat NegZeroToFloat(NegZero.convertToFloat());
  EXPECT_TRUE(NegZeroToFloat.isNegZero());

  APFloat One(APFloat::Float8E4M3(), "1.0");
  EXPECT_EQ(1.0F, One.convertToFloat());
  APFloat Two(APFloat::Float8E4M3(), "2.0");
  EXPECT_EQ(2.0F, Two.convertToFloat());

  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3(), false);
  EXPECT_EQ(240.0F, PosLargest.convertToFloat());
  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3(), true);
  EXPECT_EQ(-240.0F, NegLargest.convertToFloat());

  // Expected values are written as float literals (F suffix) so that both
  // sides of each comparison have the same type.
  APFloat PosSmallest =
      APFloat::getSmallestNormalized(APFloat::Float8E4M3(), false);
  EXPECT_EQ(0x1.p-6F, PosSmallest.convertToFloat());
  APFloat NegSmallest =
      APFloat::getSmallestNormalized(APFloat::Float8E4M3(), true);
  EXPECT_EQ(-0x1.p-6F, NegSmallest.convertToFloat());

  // The smallest representable magnitude must be reported as a denormal.
  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3(), false);
  EXPECT_TRUE(SmallestDenorm.isDenormal());
  EXPECT_EQ(0x1.p-9F, SmallestDenorm.convertToFloat());

  // Non-finite values: infinities compare equal to float infinities, and a
  // NaN can only be checked with isnan.
  APFloat PosInf = APFloat::getInf(APFloat::Float8E4M3());
  EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat());
  APFloat NegInf = APFloat::getInf(APFloat::Float8E4M3(), true);
  EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat());
  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3());
  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

TEST(APFloatTest, Float8E4M3FNToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FN());
APFloat PosZeroToFloat(PosZero.convertToFloat());
Expand Down
Loading