Skip to content

Commit 2c01403

Browse files
committed
Add F4E2M1FN type: primitive type
1 parent fa539fb commit 2c01403

File tree

11 files changed

+170
-47
lines changed

11 files changed

+170
-47
lines changed

xla/primitive_util.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ bool HasInfinity(PrimitiveType type) {
9393
return false;
9494
}
9595

96+
bool HasNaN(PrimitiveType type) {
97+
if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) {
98+
return FloatingPointTypeSwitch<bool>(
99+
[&](auto constant_type) -> bool {
100+
return std::numeric_limits<
101+
NativeTypeOf<constant_type>>::has_quiet_NaN;
102+
},
103+
type);
104+
}
105+
return false;
106+
}
107+
96108
bool HasNegativeZero(PrimitiveType type) {
97109
if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) {
98110
return FloatingPointTypeSwitch<bool>(

xla/primitive_util.h

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ int ExponentBias(PrimitiveType type);
6969
// Returns whether the type has a value for infinity.
7070
bool HasInfinity(PrimitiveType type);
7171

72+
// Returns whether the type has a value for NaN.
73+
bool HasNaN(PrimitiveType type);
74+
7275
// Returns whether the type has a value for negative zero.
7376
bool HasNegativeZero(PrimitiveType type);
7477

@@ -175,6 +178,11 @@ constexpr PrimitiveType NativeToPrimitiveType<bfloat16>() {
175178
return BF16;
176179
}
177180

181+
template <>
182+
constexpr PrimitiveType NativeToPrimitiveType<tsl::float4_e2m1fn>() {
183+
return F4E2M1FN;
184+
}
185+
178186
template <>
179187
constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e5m2>() {
180188
return F8E5M2;
@@ -314,6 +322,11 @@ struct PrimitiveTypeToNative<BF16> {
314322
using type = bfloat16;
315323
};
316324

325+
template <>
326+
struct PrimitiveTypeToNative<F4E2M1FN> {
327+
using type = tsl::float4_e2m1fn;
328+
};
329+
317330
template <>
318331
struct PrimitiveTypeToNative<F8E5M2> {
319332
using type = tsl::float8_e5m2;
@@ -381,6 +394,8 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) {
381394
primitive_type < PrimitiveType_ARRAYSIZE;
382395
}
383396

397+
constexpr bool IsMXType(PrimitiveType type) { return type == F4E2M1FN; }
398+
384399
constexpr bool IsF8Type(PrimitiveType type) {
385400
return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN ||
386401
type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ ||
@@ -389,7 +404,7 @@ constexpr bool IsF8Type(PrimitiveType type) {
389404

390405
constexpr bool IsFloatingPointType(PrimitiveType type) {
391406
return type == F16 || type == F32 || type == F64 || type == BF16 ||
392-
IsF8Type(type);
407+
IsF8Type(type) || IsMXType(type);
393408
}
394409

395410
constexpr bool IsComplexType(PrimitiveType type) {
@@ -449,6 +464,9 @@ template <typename R, typename F>
449464
constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) {
450465
if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) {
451466
switch (type) {
467+
case F4E2M1FN:
468+
return std::forward<F>(f)(
469+
PrimitiveTypeConstant<PrimitiveType::F4E2M1FN>());
452470
case F8E3M4:
453471
return std::forward<F>(f)(
454472
PrimitiveTypeConstant<PrimitiveType::F8E3M4>());
@@ -553,6 +571,9 @@ inline constexpr int PrimitiveTypeBitWidth() {
553571
if constexpr (primitive_type == PRED) {
554572
return std::numeric_limits<NativeT>::digits;
555573
}
574+
if constexpr (IsMXType(primitive_type)) {
575+
return NativeT::kBits;
576+
}
556577
if constexpr (IsFloatingPointType(primitive_type)) {
557578
return sizeof(NativeT) * std::numeric_limits<uint8_t>::digits;
558579
}
@@ -711,21 +732,33 @@ inline bool CastPreservesValues(PrimitiveType from_type,
711732
return false;
712733
}
713734
// F -> F is safe if the exponent/significand are preserved and `to_type`
714-
// preserves infinities in `from_type.
735+
// preserves infinities/NaNs/negative zero in `from_type`.
715736
if (primitive_util::IsFloatingPointType(from_type) &&
716737
primitive_util::IsFloatingPointType(to_type)) {
717-
return (!primitive_util::HasInfinity(from_type) ||
718-
primitive_util::HasInfinity(to_type)) &&
719-
primitive_util::SignificandWidth(from_type) <=
720-
primitive_util::SignificandWidth(to_type) &&
721-
primitive_util::ExponentWidth(from_type) <=
722-
primitive_util::ExponentWidth(to_type) &&
723-
(primitive_util::UnderflowExponent(from_type) -
724-
primitive_util::SignificandWidth(from_type)) >=
725-
(primitive_util::UnderflowExponent(to_type) -
726-
primitive_util::SignificandWidth(to_type)) &&
727-
primitive_util::OverflowExponent(from_type) <=
728-
primitive_util::OverflowExponent(to_type);
738+
return
739+
// Target mantissa should be large enough.
740+
primitive_util::SignificandWidth(from_type) <=
741+
primitive_util::SignificandWidth(to_type) &&
742+
// Target exponent should be large enough.
743+
primitive_util::ExponentWidth(from_type) <=
744+
primitive_util::ExponentWidth(to_type) &&
745+
// HasInfinity check.
746+
(!primitive_util::HasInfinity(from_type) ||
747+
primitive_util::HasInfinity(to_type)) &&
748+
// HasNaN check.
749+
(!primitive_util::HasNaN(from_type) ||
750+
primitive_util::HasNaN(to_type)) &&
751+
// HasNegativeZero check.
752+
(!primitive_util::HasNegativeZero(from_type) ||
753+
primitive_util::HasNegativeZero(to_type)) &&
754+
// Minimum denormal should be representable by target type.
755+
(primitive_util::UnderflowExponent(from_type) -
756+
primitive_util::SignificandWidth(from_type)) >=
757+
(primitive_util::UnderflowExponent(to_type) -
758+
primitive_util::SignificandWidth(to_type)) &&
759+
// Maximum exponent may be larger with custom bias (e.g. F8E4M3B11FNUZ).
760+
primitive_util::OverflowExponent(from_type) <=
761+
primitive_util::OverflowExponent(to_type);
729762
}
730763
// F -> I is not safe because it drops fractional numbers.
731764
if (!primitive_util::IsIntegralType(from_type)) {

xla/primitive_util_test.cc

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
7575
expecteds[PRED][C64] = true;
7676
expecteds[PRED][BF16] = true;
7777
expecteds[PRED][C128] = true;
78+
expecteds[PRED][F4E2M1FN] = true;
7879
expecteds[PRED][F8E5M2] = true;
7980
expecteds[PRED][F8E4M3] = true;
8081
expecteds[PRED][F8E4M3FN] = true;
@@ -101,6 +102,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
101102
expecteds[S2][C64] = true;
102103
expecteds[S2][BF16] = true;
103104
expecteds[S2][C128] = true;
105+
expecteds[S2][F4E2M1FN] = true;
104106
expecteds[S2][F8E5M2] = true;
105107
expecteds[S2][F8E4M3] = true;
106108
expecteds[S2][F8E4M3FN] = true;
@@ -127,6 +129,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
127129
expecteds[S4][C64] = true;
128130
expecteds[S4][BF16] = true;
129131
expecteds[S4][C128] = true;
132+
expecteds[S4][F4E2M1FN] = false;
130133
expecteds[S4][F8E5M2] = true;
131134
expecteds[S4][F8E4M3] = true;
132135
expecteds[S4][F8E4M3FN] = true;
@@ -153,6 +156,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
153156
expecteds[S8][C64] = true;
154157
expecteds[S8][BF16] = true;
155158
expecteds[S8][C128] = true;
159+
expecteds[S8][F4E2M1FN] = false;
156160
expecteds[S8][F8E5M2] = false;
157161
expecteds[S8][F8E4M3] = false;
158162
expecteds[S8][F8E4M3FN] = false;
@@ -179,6 +183,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
179183
expecteds[S16][C64] = true;
180184
expecteds[S16][BF16] = false;
181185
expecteds[S16][C128] = true;
186+
expecteds[S16][F4E2M1FN] = false;
182187
expecteds[S16][F8E5M2] = false;
183188
expecteds[S16][F8E4M3] = false;
184189
expecteds[S16][F8E4M3FN] = false;
@@ -205,6 +210,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
205210
expecteds[S32][C64] = false;
206211
expecteds[S32][BF16] = false;
207212
expecteds[S32][C128] = true;
213+
expecteds[S32][F4E2M1FN] = false;
208214
expecteds[S32][F8E5M2] = false;
209215
expecteds[S32][F8E4M3] = false;
210216
expecteds[S32][F8E4M3FN] = false;
@@ -231,6 +237,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
231237
expecteds[S64][C64] = false;
232238
expecteds[S64][BF16] = false;
233239
expecteds[S64][C128] = false;
240+
expecteds[S64][F4E2M1FN] = false;
234241
expecteds[S64][F8E5M2] = false;
235242
expecteds[S64][F8E4M3] = false;
236243
expecteds[S64][F8E4M3FN] = false;
@@ -257,8 +264,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
257264
expecteds[U2][C64] = true;
258265
expecteds[U2][BF16] = true;
259266
expecteds[U2][C128] = true;
260-
expecteds[U2][BF16] = true;
261-
expecteds[U2][C128] = true;
267+
expecteds[U2][F4E2M1FN] = true;
262268
expecteds[U2][F8E5M2] = true;
263269
expecteds[U2][F8E4M3] = true;
264270
expecteds[U2][F8E4M3FN] = true;
@@ -285,8 +291,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
285291
expecteds[U4][C64] = true;
286292
expecteds[U4][BF16] = true;
287293
expecteds[U4][C128] = true;
288-
expecteds[U4][BF16] = true;
289-
expecteds[U4][C128] = true;
294+
expecteds[U4][F4E2M1FN] = false;
290295
expecteds[U4][F8E5M2] = false;
291296
expecteds[U4][F8E4M3] = true;
292297
expecteds[U4][F8E4M3FN] = true;
@@ -313,8 +318,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
313318
expecteds[U8][C64] = true;
314319
expecteds[U8][BF16] = true;
315320
expecteds[U8][C128] = true;
316-
expecteds[U8][BF16] = true;
317-
expecteds[U8][C128] = true;
321+
expecteds[U8][F4E2M1FN] = false;
318322
expecteds[U8][F8E5M2] = false;
319323
expecteds[U8][F8E4M3] = false;
320324
expecteds[U8][F8E4M3FN] = false;
@@ -341,6 +345,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
341345
expecteds[U16][C64] = true;
342346
expecteds[U16][BF16] = false;
343347
expecteds[U16][C128] = true;
348+
expecteds[U16][F4E2M1FN] = false;
344349
expecteds[U16][F8E5M2] = false;
345350
expecteds[U16][F8E4M3] = false;
346351
expecteds[U16][F8E4M3FN] = false;
@@ -367,6 +372,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
367372
expecteds[U32][C64] = false;
368373
expecteds[U32][BF16] = false;
369374
expecteds[U32][C128] = true;
375+
expecteds[U32][F4E2M1FN] = false;
370376
expecteds[U32][F8E5M2] = false;
371377
expecteds[U32][F8E4M3] = false;
372378
expecteds[U32][F8E4M3FN] = false;
@@ -393,6 +399,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
393399
expecteds[U64][C64] = false;
394400
expecteds[U64][BF16] = false;
395401
expecteds[U64][C128] = false;
402+
expecteds[U64][F4E2M1FN] = false;
396403
expecteds[U64][F8E5M2] = false;
397404
expecteds[U64][F8E4M3] = false;
398405
expecteds[U64][F8E4M3FN] = false;
@@ -419,6 +426,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
419426
expecteds[F16][C64] = true;
420427
expecteds[F16][BF16] = false;
421428
expecteds[F16][C128] = true;
429+
expecteds[F16][F4E2M1FN] = false;
422430
expecteds[F16][F8E5M2] = false;
423431
expecteds[F16][F8E4M3] = false;
424432
expecteds[F16][F8E4M3FN] = false;
@@ -445,6 +453,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
445453
expecteds[F32][C64] = true;
446454
expecteds[F32][BF16] = false;
447455
expecteds[F32][C128] = true;
456+
expecteds[F32][F4E2M1FN] = false;
448457
expecteds[F32][F8E5M2] = false;
449458
expecteds[F32][F8E4M3] = false;
450459
expecteds[F32][F8E4M3FN] = false;
@@ -471,6 +480,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
471480
expecteds[F64][C64] = false;
472481
expecteds[F64][BF16] = false;
473482
expecteds[F64][C128] = true;
483+
expecteds[F64][F4E2M1FN] = false;
474484
expecteds[F64][F8E5M2] = false;
475485
expecteds[F64][F8E4M3] = false;
476486
expecteds[F64][F8E4M3FN] = false;
@@ -497,6 +507,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
497507
expecteds[C64][C64] = true;
498508
expecteds[C64][BF16] = false;
499509
expecteds[C64][C128] = true;
510+
expecteds[C64][F4E2M1FN] = false;
500511
expecteds[C64][F8E5M2] = false;
501512
expecteds[C64][F8E4M3] = false;
502513
expecteds[C64][F8E4M3FN] = false;
@@ -523,6 +534,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
523534
expecteds[BF16][C64] = true;
524535
expecteds[BF16][BF16] = true;
525536
expecteds[BF16][C128] = true;
537+
expecteds[BF16][F4E2M1FN] = false;
526538
expecteds[BF16][F8E5M2] = false;
527539
expecteds[BF16][F8E4M3] = false;
528540
expecteds[BF16][F8E4M3FN] = false;
@@ -549,13 +561,41 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
549561
expecteds[C128][C64] = false;
550562
expecteds[C128][BF16] = false;
551563
expecteds[C128][C128] = true;
564+
expecteds[C128][F4E2M1FN] = false;
552565
expecteds[C128][F8E5M2] = false;
553566
expecteds[C128][F8E4M3] = false;
554567
expecteds[C128][F8E4M3FN] = false;
555568
expecteds[C128][F8E4M3B11FNUZ] = false;
556569
expecteds[C128][F8E5M2FNUZ] = false;
557570
expecteds[C128][F8E4M3FNUZ] = false;
558571
expecteds[C128][F8E3M4] = false;
572+
expecteds[F4E2M1FN][PRED] = false;
573+
expecteds[F4E2M1FN][S2] = false;
574+
expecteds[F4E2M1FN][S4] = false;
575+
expecteds[F4E2M1FN][S8] = false;
576+
expecteds[F4E2M1FN][S16] = false;
577+
expecteds[F4E2M1FN][S32] = false;
578+
expecteds[F4E2M1FN][S64] = false;
579+
expecteds[F4E2M1FN][U2] = false;
580+
expecteds[F4E2M1FN][U4] = false;
581+
expecteds[F4E2M1FN][U8] = false;
582+
expecteds[F4E2M1FN][U16] = false;
583+
expecteds[F4E2M1FN][U32] = false;
584+
expecteds[F4E2M1FN][U64] = false;
585+
expecteds[F4E2M1FN][F16] = true;
586+
expecteds[F4E2M1FN][F32] = true;
587+
expecteds[F4E2M1FN][F64] = true;
588+
expecteds[F4E2M1FN][C64] = true;
589+
expecteds[F4E2M1FN][BF16] = true;
590+
expecteds[F4E2M1FN][C128] = true;
591+
expecteds[F4E2M1FN][F4E2M1FN] = true;
592+
expecteds[F4E2M1FN][F8E5M2] = true;
593+
expecteds[F4E2M1FN][F8E4M3] = true;
594+
expecteds[F4E2M1FN][F8E4M3FN] = true;
595+
expecteds[F4E2M1FN][F8E4M3B11FNUZ] = false;
596+
expecteds[F4E2M1FN][F8E4M3FNUZ] = false;
597+
expecteds[F4E2M1FN][F8E5M2FNUZ] = false;
598+
expecteds[F4E2M1FN][F8E3M4] = true;
559599
expecteds[F8E5M2][PRED] = false;
560600
expecteds[F8E5M2][S2] = false;
561601
expecteds[F8E5M2][S4] = false;
@@ -575,6 +615,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
575615
expecteds[F8E5M2][C64] = true;
576616
expecteds[F8E5M2][BF16] = true;
577617
expecteds[F8E5M2][C128] = true;
618+
expecteds[F8E5M2][F4E2M1FN] = false;
578619
expecteds[F8E5M2][F8E5M2] = true;
579620
expecteds[F8E5M2][F8E4M3] = false;
580621
expecteds[F8E5M2][F8E4M3FN] = false;
@@ -601,6 +642,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
601642
expecteds[F8E4M3][C64] = true;
602643
expecteds[F8E4M3][BF16] = true;
603644
expecteds[F8E4M3][C128] = true;
645+
expecteds[F8E4M3][F4E2M1FN] = false;
604646
expecteds[F8E4M3][F8E5M2] = false;
605647
expecteds[F8E4M3][F8E5M2FNUZ] = false;
606648
expecteds[F8E4M3][F8E4M3] = true;
@@ -627,6 +669,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
627669
expecteds[F8E4M3FN][C64] = true;
628670
expecteds[F8E4M3FN][BF16] = true;
629671
expecteds[F8E4M3FN][C128] = true;
672+
expecteds[F8E4M3FN][F4E2M1FN] = false;
630673
expecteds[F8E4M3FN][F8E5M2] = false;
631674
expecteds[F8E4M3FN][F8E5M2FNUZ] = false;
632675
expecteds[F8E4M3FN][F8E4M3] = false;
@@ -653,6 +696,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
653696
expecteds[F8E4M3B11FNUZ][C64] = true;
654697
expecteds[F8E4M3B11FNUZ][BF16] = true;
655698
expecteds[F8E4M3B11FNUZ][C128] = true;
699+
expecteds[F8E4M3B11FNUZ][F4E2M1FN] = false;
656700
expecteds[F8E4M3B11FNUZ][F8E5M2] = false;
657701
expecteds[F8E4M3B11FNUZ][F8E4M3] = false;
658702
expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false;
@@ -679,6 +723,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
679723
expecteds[F8E5M2FNUZ][C64] = true;
680724
expecteds[F8E5M2FNUZ][BF16] = true;
681725
expecteds[F8E5M2FNUZ][C128] = true;
726+
expecteds[F8E5M2FNUZ][F4E2M1FN] = false;
682727
expecteds[F8E5M2FNUZ][F8E5M2] = false;
683728
expecteds[F8E5M2FNUZ][F8E4M3] = false;
684729
expecteds[F8E5M2FNUZ][F8E4M3FN] = false;
@@ -705,6 +750,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
705750
expecteds[F8E4M3FNUZ][C64] = true;
706751
expecteds[F8E4M3FNUZ][BF16] = true;
707752
expecteds[F8E4M3FNUZ][C128] = true;
753+
expecteds[F8E4M3FNUZ][F4E2M1FN] = false;
708754
expecteds[F8E4M3FNUZ][F8E5M2] = false;
709755
expecteds[F8E4M3FNUZ][F8E4M3] = false;
710756
expecteds[F8E4M3FNUZ][F8E4M3FN] = false;
@@ -731,6 +777,7 @@ TEST(PrimitiveUtilTest, CastPreservesValues) {
731777
expecteds[F8E3M4][C64] = true;
732778
expecteds[F8E3M4][BF16] = true;
733779
expecteds[F8E3M4][C128] = true;
780+
expecteds[F8E3M4][F4E2M1FN] = false;
734781
expecteds[F8E3M4][F8E5M2] = false;
735782
expecteds[F8E3M4][F8E5M2FNUZ] = false;
736783
expecteds[F8E3M4][F8E4M3] = false;

xla/python/py_values.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ limitations under the License.
3434
#include "absl/strings/str_join.h"
3535
#include "absl/types/span.h"
3636
#include "nanobind/nanobind.h"
37-
#include "nanobind/stl/complex.h" // IWYU pragma: keep
37+
#include "nanobind/stl/complex.h" // IWYU pragma: keep
3838
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
3939
#include "xla/primitive_util.h"
4040
#include "xla/python/ifrt/array.h"

0 commit comments

Comments
 (0)