Skip to content

Commit 1e145f9

Browse files
sergey-kozub and Google-ML-Automation
authored and committed
PR #19096: Add F4E2M1FN and F8E8M0FNU types
Imported from GitHub PR #19096

This PR adds the F4E2M1FN primitive type (4-bit float with 2 exponent bits and 1 mantissa bit) and the F8E8M0FNU primitive type (8-bit float with 8 exponent bits, no mantissa and no sign), and enables loads/stores for them in the same way as is done for the S4/U4 types. This will enable using microscaling (MX) formats ([RFC](#18085)), such as MXFP4.

```c
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 - 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 - 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- openxla/stablehlo#2582
- jax-ml/ml_dtypes#181
- llvm/llvm-project#95392
- llvm/llvm-project#108877
- jax-ml/ml_dtypes#166
- llvm/llvm-project#107127
- llvm/llvm-project#111028

The PR is split into multiple commits just to make the review easier; some tests may fail if only some (i.e. not all) of these commits are applied.
Copybara import of the project:

-- fa539fb by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: import mxfloat.h

-- 2c01403 by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: primitive type

-- e919ed5 by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: literal support

-- ca16839 by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: conversion codegen

-- eedc079 by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: python interface

-- 8e0305c by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: FFI

-- aabe9c6 by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: HLO evaluator

-- 87da2eb by Sergey Kozub <[email protected]>:
Add F4E2M1FN type: add tests

-- e0ee48c by Sergey Kozub <[email protected]>:
Add F8E8M0FNU type

-- be2e457 by Sergey Kozub <[email protected]>:
Addressing PR #19096 review comments

Merging this change closes #19096

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19096 from openxla:skozub/e2m1 be2e457
PiperOrigin-RevId: 702273510
1 parent 2500111 commit 1e145f9

File tree

79 files changed

+1767
-351
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1767
-351
lines changed

third_party/tsl/tsl/platform/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,7 @@ cc_library(
10661066
deps = [
10671067
"@ml_dtypes//:float8",
10681068
"@ml_dtypes//:intn",
1069+
"@ml_dtypes//:mxfloat",
10691070
],
10701071
)
10711072

third_party/tsl/tsl/platform/ml_dtypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@ limitations under the License.
1818

1919
#include "ml_dtypes/include/float8.h" // from @ml_dtypes
2020
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
21+
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes
2122

2223
namespace tsl {
24+
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
2325
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
2426
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
2527
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
2628
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
2729
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
2830
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
2931
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
32+
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;
3033

3134
using int2 = ::ml_dtypes::int2;
3235
using uint2 = ::ml_dtypes::uint2;

xla/array2d_test.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) {
219219
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
220220
}
221221

222+
TEST(Array2dTest, LinspaceF4E2M1FN) {
223+
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);
224+
225+
EXPECT_EQ(arr->n1(), 3);
226+
EXPECT_EQ(arr->n2(), 2);
227+
228+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
229+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
230+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
231+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
232+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
233+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
234+
}
235+
236+
TEST(Array2dTest, LinspaceF8E8M0FNU) {
237+
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);
238+
239+
EXPECT_EQ(arr->n1(), 3);
240+
EXPECT_EQ(arr->n2(), 2);
241+
242+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
243+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
244+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
245+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
246+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
247+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
248+
}
249+
222250
TEST(Array2dTest, Stringification) {
223251
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
224252
const std::string expected = R"([[1, 1.5],

xla/comparison_util.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,13 @@ class Comparison {
193193
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
194194
// Reference:
195195
// https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
196-
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
197-
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
196+
if constexpr (std::numeric_limits<T>::is_signed) {
197+
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
198+
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
199+
} else {
200+
using R = UnsignedIntegerTypeForSizeType<sizeof(T)>;
201+
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
202+
}
198203
}
199204
}
200205
// Applies the comparison from this Comparison's direction and ordering.

xla/ffi/api/api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os,
131131
return os << "C128";
132132
case XLA_FFI_DataType_TOKEN:
133133
return os << "TOKEN";
134+
case XLA_FFI_DataType_F4E2M1FN:
135+
return os << "F4E2M1FN";
134136
case XLA_FFI_DataType_F8E5M2:
135137
return os << "F8E5M2";
136138
case XLA_FFI_DataType_F8E3M4:
@@ -145,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os,
145147
return os << "F8E5M2FNUZ";
146148
case XLA_FFI_DataType_F8E4M3FNUZ:
147149
return os << "F8E4M3FNUZ";
150+
case XLA_FFI_DataType_F8E8M0FNU:
151+
return os << "F8E8M0FNU";
148152
}
149153
}
150154

xla/ffi/api/c_api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ typedef enum {
201201
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
202202
XLA_FFI_DataType_F8E5M2FNUZ = 24,
203203
XLA_FFI_DataType_F8E4M3FNUZ = 25,
204+
XLA_FFI_DataType_F4E2M1FN = 30,
205+
XLA_FFI_DataType_F8E8M0FNU = 31,
204206
} XLA_FFI_DataType;
205207
// LINT.ThenChange(ffi_test.cc)
206208

xla/ffi/api/ffi.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ enum class DataType : uint8_t {
7979
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
8080
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
8181
F8E3M4 = XLA_FFI_DataType_F8E3M4,
82+
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
83+
F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU,
8284
};
8385

8486
// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
@@ -106,6 +108,8 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
106108
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
107109
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
108110
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
111+
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
112+
inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU;
109113

110114
inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
111115
return os << static_cast<XLA_FFI_DataType>(dtype);
@@ -127,6 +131,8 @@ constexpr size_t ByteWidth(DataType dtype) {
127131
case DataType::F8E5M2FNUZ:
128132
case DataType::F8E4M3FNUZ:
129133
case DataType::F8E3M4:
134+
case DataType::F4E2M1FN:
135+
case DataType::F8E8M0FNU:
130136
return 1;
131137
case DataType::S16:
132138
case DataType::U16:

xla/ffi/api/ffi_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) {
129129

130130
EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));
131131

132+
EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN));
132133
EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
133134
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
134135
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
@@ -137,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) {
137138
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
138139
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
139140
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
141+
EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU));
140142
}
141143

142144
TEST(FfiTest, DataTypeByteWidth) {
@@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {
179181
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128),
180182
ByteWidth(DataType::C128));
181183

184+
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN),
185+
ByteWidth(DataType::F4E2M1FN));
182186
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
183187
ByteWidth(DataType::F8E5M2));
184188
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
@@ -193,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) {
193197
ByteWidth(DataType::F8E4M3FNUZ));
194198
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
195199
ByteWidth(DataType::F8E3M4));
200+
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU),
201+
ByteWidth(DataType::F8E8M0FNU));
196202
}
197203

198204
TEST(FfiTest, ErrorEnumValue) {

xla/ffi/call_frame.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,15 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
264264
case PrimitiveType::C64:
265265
case PrimitiveType::C128:
266266
case PrimitiveType::TOKEN:
267+
case PrimitiveType::F4E2M1FN:
267268
case PrimitiveType::F8E5M2:
268269
case PrimitiveType::F8E4M3:
269270
case PrimitiveType::F8E4M3FN:
270271
case PrimitiveType::F8E4M3B11FNUZ:
271272
case PrimitiveType::F8E5M2FNUZ:
272273
case PrimitiveType::F8E4M3FNUZ:
273274
case PrimitiveType::F8E3M4:
275+
case PrimitiveType::F8E8M0FNU:
274276
return static_cast<XLA_FFI_DataType>(primitive_type);
275277
default:
276278
DCHECK(false) << "Unsupported primitive type "

xla/fp_util_test.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,76 @@ class FP8E4M3DistanceTest : public ::testing::Test {};
119119
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
120120
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);
121121

122+
TEST(FPDistanceTest, F4E2M1FNDistance) {
123+
// a & b are equal
124+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
125+
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)),
126+
0);
127+
128+
// a & b have the same exponents
129+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
130+
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)),
131+
1);
132+
133+
// a & b have different exponents
134+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
135+
tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)),
136+
2);
137+
138+
// 1 from 0 in the positive direction
139+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
140+
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
141+
tsl::float4_e2m1fn(0)),
142+
1);
143+
144+
// 1 from 0 in the negative direction
145+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
146+
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
147+
tsl::float4_e2m1fn(0)),
148+
1);
149+
150+
// a & b have different signs
151+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
152+
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
153+
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min()),
154+
2);
155+
156+
// 1 non denorm from 0 in the positive direction
157+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
158+
std::numeric_limits<tsl::float4_e2m1fn>::min(),
159+
tsl::float4_e2m1fn(0)),
160+
2);
161+
162+
// 1 non denorm from 0 in the negative direction
163+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
164+
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
165+
tsl::float4_e2m1fn(0)),
166+
2);
167+
168+
// a & b have different signs
169+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
170+
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
171+
std::numeric_limits<tsl::float4_e2m1fn>::min()),
172+
4);
173+
}
174+
175+
TEST(FPDistanceTest, F8E8M0FNUDistance) {
176+
// a & b are equal
177+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e8m0fnu>(
178+
tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(1.0)),
179+
0);
180+
181+
// one step apart
182+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e8m0fnu>(
183+
tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(2.0)),
184+
1);
185+
186+
// two steps apart
187+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e8m0fnu>(
188+
tsl::float8_e8m0fnu(0.5), tsl::float8_e8m0fnu(2.0)),
189+
2);
190+
}
191+
122192
TEST(FPDistanceTest, F8E3M4Distance) {
123193
// a & b are equal
124194
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),

xla/hlo/builder/lib/math.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) {
184184
case F32:
185185
return Eq(BitcastConvertType(operand, U32),
186186
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
187+
case F4E2M1FN:
187188
case F8E3M4:
188189
case F8E4M3:
189190
case F8E5M2:
@@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
971972
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
972973
PrimitiveType a_x_type = a_shape.element_type();
973974
bool needs_upcast = false;
974-
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
975-
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
975+
for (PrimitiveType type :
976+
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
977+
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
976978
if (a_shape.element_type() == type) {
977979
needs_upcast = true;
978980
break;
@@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
10241026
}
10251027
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
10261028
bool needs_upcast = false;
1027-
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
1028-
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
1029+
for (PrimitiveType type :
1030+
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
1031+
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
10291032
if (a_shape.element_type() == type) {
10301033
needs_upcast = true;
10311034
break;

xla/hlo/builder/lib/math_test.cc

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,22 @@ class MathTypedTest : public MathTest {
9595
Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});
9696

9797
bool has_inf = std::numeric_limits<T>::has_infinity;
98+
bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
99+
bool has_finite = !has_inf && !has_nan;
100+
bool has_nan_only = !has_inf && has_nan;
101+
98102
auto expected = LiteralUtil::MakeTupleOwned(
99-
LiteralUtil::CreateR1<bool>(
100-
{true, true, true, true, true, false, false, false, false}),
103+
LiteralUtil::CreateR1<bool>({true, true, true, true, true, has_finite,
104+
has_finite, has_finite, has_finite}),
101105
LiteralUtil::CreateR1<bool>({false, false, false, false, false, has_inf,
102106
has_inf, false, false}),
103107
LiteralUtil::CreateR1<bool>(
104108
{false, false, false, false, false, has_inf, false, false, false}),
105109
LiteralUtil::CreateR1<bool>(
106110
{false, false, false, false, false, false, has_inf, false, false}),
107111
LiteralUtil::CreateR1<bool>({false, false, false, false, false,
108-
!has_inf, !has_inf, true, true}));
112+
has_nan_only, has_nan_only, has_nan,
113+
has_nan}));
109114
ComputeAndCompareLiteral(&b, expected, {});
110115
}
111116

@@ -118,10 +123,11 @@ class MathTypedTest : public MathTest {
118123
LiteralUtil::CreateR1<T>({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}),
119124
&b));
120125

126+
bool is_mx = std::is_same_v<T, tsl::float4_e2m1fn>;
121127
ComputeAndCompareLiteral(
122128
&b,
123129
LiteralUtil::CreateR1<bool>(
124-
{has_negative_zero_v<T>, false, false, false, false, false, false}),
130+
{has_negative_zero_v<T>, false, false, false, false, false, is_mx}),
125131
{}, error_spec_);
126132
}
127133

@@ -136,6 +142,9 @@ class MathTypedTest : public MathTest {
136142
// For good measure, we also check pow with an exponent other than 0.5.
137143
void TestSqrtPowInequivalence() {
138144
SetFastMathDisabled(true);
145+
if (std::is_same_v<T, tsl::float4_e2m1fn>) {
146+
GTEST_SKIP() << "Skipping due to low precision";
147+
}
139148

140149
// Tests disable constant folding by default, but this test needs it
141150
// enabled, otherwise we don't tickle the bug we're trying to catch.
@@ -181,19 +190,24 @@ class MathTypedTest : public MathTest {
181190
&b);
182191
Erf(x);
183192

184-
bool has_inf = std::numeric_limits<T>::has_infinity;
185-
std::vector<T> expected = {
186-
has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)};
193+
bool inf_as_nan = !std::numeric_limits<T>::has_infinity &&
194+
std::numeric_limits<T>::has_quiet_NaN;
195+
std::vector<T> expected = {inf_as_nan ? nan : T(-1),
196+
inf_as_nan ? nan : T(1),
197+
T(-0),
198+
T(0),
199+
T(-1),
200+
T(1)};
187201

188202
ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
189203
}
190204
};
191205

192206
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
193207
using TestTypes =
194-
::testing::Types<tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fnuz,
195-
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2,
196-
tsl::float8_e5m2fnuz,
208+
::testing::Types<tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3,
209+
tsl::float8_e4m3fnuz, tsl::float8_e4m3b11fnuz,
210+
tsl::float8_e5m2, tsl::float8_e5m2fnuz,
197211
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
198212
Eigen::half,
199213
#endif

xla/hlo/evaluator/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cc_library(
3636
"hlo_evaluator_typed_visitor_int4.cc",
3737
"hlo_evaluator_typed_visitor_int64.cc",
3838
"hlo_evaluator_typed_visitor_int8.cc",
39+
"hlo_evaluator_typed_visitor_mxfloat.cc",
3940
"hlo_evaluator_typed_visitor_uint16.cc",
4041
"hlo_evaluator_typed_visitor_uint32.cc",
4142
"hlo_evaluator_typed_visitor_uint64.cc",

xla/hlo/evaluator/hlo_evaluator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3722,7 +3722,7 @@ absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
37223722
const Shape& result_shape) {
37233723
std::function<ResultT(Fp, Uint)> stochastic_convert_op =
37243724
[](Fp operand, Uint random) -> ResultT {
3725-
bool is_negative = static_cast<bool>(Eigen::numext::signbit(operand));
3725+
bool is_negative = static_cast<bool>(SignAndMagnitude(operand).first);
37263726
if (Eigen::numext::isinf(operand)) {
37273727
return is_negative ? std::numeric_limits<ResultT>::min()
37283728
: std::numeric_limits<ResultT>::max();

0 commit comments

Comments
 (0)