
[MLIR] Add f8E3M4 IEEE 754 type #101230


Merged
merged 1 commit on Aug 2, 2024
Conversation

apivovarov
Member

This PR adds the f8E3M4 type to MLIR.

The f8E3M4 type follows IEEE 754 conventions:

f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 - 3 = -2
- Precision (total number of bits in the significand/mantissa, including the
    implicit leading integer bit): 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details (see the decoding sketch after this list):
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
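
For illustration only (not part of the patch), here is a minimal pure-Python sketch that decodes an f8E3M4 bit pattern S.EEE.MMMM with bias 3 according to the layout above, and checks the extreme values listed:

    def decode_f8e3m4(byte: int) -> float:
        """Decode one f8E3M4 value (S.EEE.MMMM, exponent bias 3) to a Python float."""
        sign = -1.0 if (byte >> 7) & 0x1 else 1.0
        exp = (byte >> 4) & 0x7   # 3 exponent bits
        mant = byte & 0xF         # 4 mantissa bits
        if exp == 0b111:          # all-ones exponent encodes inf (mantissa 0) or NaN
            return sign * float("inf") if mant == 0 else float("nan")
        if exp == 0:              # subnormal: no implicit leading 1, exponent is 1 - bias
            return sign * (mant / 16.0) * 2.0 ** (1 - 3)
        return sign * (1.0 + mant / 16.0) * 2.0 ** (exp - 3)

    assert decode_f8e3m4(0b0_110_1111) == 15.5             # max normal
    assert decode_f8e3m4(0b0_001_0000) == 2.0 ** -2        # min normal
    assert decode_f8e3m4(0b0_000_1111) == 15 * 2.0 ** -6   # max subnormal
    assert decode_f8e3m4(0b0_000_0001) == 2.0 ** -6        # min subnormal
    assert decode_f8e3m4(0b1_111_0000) == float("-inf")    # negative infinity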

Related PRs:

  • PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type
  • PR-97118 [MLIR] Add f8E4M3 IEEE 754 type

@llvmbot
Member

llvmbot commented Jul 30, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Alexander Pivovarov (apivovarov)

Patch is 20.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101230.diff

24 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+10)
  • (modified) mlir/include/mlir/IR/Builders.h (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+10-4)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+19)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+2)
  • (modified) mlir/include/mlir/IR/Types.h (+1)
  • (modified) mlir/lib/AsmParser/TokenKinds.def (+1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+4)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+21)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+12)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+1)
  • (modified) mlir/lib/IR/Builders.cpp (+4)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+4-2)
  • (modified) mlir/lib/IR/MLIRContext.cpp (+5)
  • (modified) mlir/lib/IR/Types.cpp (+1)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+14)
  • (modified) mlir/python/mlir/extras/types.py (+2)
  • (modified) mlir/test/IR/attribute.mlir (+4)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3)
  • (modified) mlir/test/python/ir/builtin_types.py (+9)
  • (modified) mlir/utils/lldb-scripts/mlirDataFormatters.py (+1)
  • (modified) mlir/utils/tree-sitter-mlir/grammar.js (+2-2)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 2212087b9898f..d698bf4764568 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -139,6 +139,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
 
+/// Returns the typeID of an Float8E3M4 type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void);
+
+/// Checks whether the given type is an f8E3M4 type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
+
+/// Creates an f8E3M4 type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
+
 /// Returns the typeID of an BFloat16 type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1c4d329fbf0d8..b5962f3783924 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -66,6 +66,7 @@ class Builder {
   FloatType getFloat8E5M2FNUZType();
   FloatType getFloat8E4M3FNUZType();
   FloatType getFloat8E4M3B11FNUZType();
+  FloatType getFloat8E3M4Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 4250be90ba7fb..d12522ba55c96 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -66,6 +66,7 @@ class FloatType : public Type {
   static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
   static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
+  static FloatType getFloat8E3M4(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -411,10 +412,11 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
 }
 
 inline bool FloatType::classof(Type type) {
-  return llvm::isa<
-      Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
-      Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, Float16Type,
-      FloatTF32Type, Float32Type, Float64Type, Float80Type, Float128Type>(type);
+  return llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+                   Float8E5M2FNUZType, Float8E4M3FNUZType,
+                   Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
+                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
+                   Float80Type, Float128Type>(type);
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -441,6 +443,10 @@ inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) {
   return Float8E4M3B11FNUZType::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
+  return Float8E3M4Type::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 176a167a3ca31..365edcf68d8b9 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -213,6 +213,25 @@ def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ", "f8E4M3B1
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E3M4Type
+
+def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
+  let summary = "8-bit floating point with 3 bits exponent and 4 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 3 bits exponent and 4 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E3M4
+      * exponent bias: 3
+      * infinities: supported with exponent set to all 1s and mantissa 0s
+      * NaNs: supported with exponent bits set to all 1s and mantissa values of
+        {0,1}⁴ except S.111.0000
+      * denormals when exponent is 0
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index c23d2d87e080f..5b6ec167fa242 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -342,6 +342,8 @@ def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ t
                  BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
 def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
                  BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
+def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
+             BuildableType<"$_builder.getFloat8E3M4Type()">;
 
 def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                       "complex-type", "::mlir::ComplexType">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index a32de33114e40..60dc8fee0f4a9 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -131,6 +131,7 @@ class Type {
   bool isFloat8E5M2FNUZ() const;
   bool isFloat8E4M3FNUZ() const;
   bool isFloat8E4M3B11FNUZ() const;
+  bool isFloat8E3M4() const;
   bool isBF16() const;
   bool isF16() const;
   bool isTF32() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index eb3154c6da42e..4c1c1c21031c8 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -100,6 +100,7 @@ TOK_KEYWORD(f8E4M3FN)
 TOK_KEYWORD(f8E5M2FNUZ)
 TOK_KEYWORD(f8E4M3FNUZ)
 TOK_KEYWORD(f8E4M3B11FNUZ)
+TOK_KEYWORD(f8E3M4)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 467b5f0844ab3..542eaeefe57f1 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -45,6 +45,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
   case Token::kw_f8E5M2FNUZ:
   case Token::kw_f8E4M3FNUZ:
   case Token::kw_f8E4M3B11FNUZ:
+  case Token::kw_f8E3M4:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_tf32:
@@ -320,6 +321,9 @@ Type Parser::parseNonFunctionType() {
   case Token::kw_f8E4M3B11FNUZ:
     consumeToken(Token::kw_f8E4M3B11FNUZ);
     return builder.getFloat8E4M3B11FNUZType();
+  case Token::kw_f8E3M4:
+    consumeToken(Token::kw_f8E3M4);
+    return builder.getFloat8E3M4Type();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 5e0aebc03e2c1..c3d42c0ef8e3c 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -246,6 +246,26 @@ class PyFloat8E5M2FNUZType
   }
 };
 
+/// Floating Point Type subclass - Float8E3M4Type.
+class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E3M4TypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E3M4Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E3M4TypeGet(context->get());
+          return PyFloat8E3M4Type(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create a float8_e3m4 type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
 public:
@@ -864,6 +884,7 @@ void mlir::python::populateIRTypes(py::module &m) {
   PyFloat8E4M3FNUZType::bind(m);
   PyFloat8E4M3B11FNUZType::bind(m);
   PyFloat8E5M2FNUZType::bind(m);
+  PyFloat8E3M4Type::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyTF32Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index d507027357c26..2aa2e922f2abc 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -157,6 +157,18 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
   return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
 }
 
+MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
+  return wrap(Float8E3M4Type::getTypeID());
+}
+
+bool mlirTypeIsAFloat8E3M4(MlirType type) {
+  return unwrap(type).isFloat8E3M4();
+}
+
+MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
+}
+
 MlirTypeID mlirBFloat16TypeGetTypeID() {
   return wrap(BFloat16Type::getTypeID());
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index d5df960928afb..784deaac5ee65 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -249,7 +249,7 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
-      type.isFloat8E4M3B11FNUZ())
+      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a362c8500aa5b..51f229ef937c4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -60,6 +60,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
       .Case("f8E4M3FN", b.getFloat8E4M3FNType())
       .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
       .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
+      .Case("f8E3M4", b.getFloat8E3M4Type())
       .Case("bf16", b.getBF16Type())
       .Case("f16", b.getF16Type())
       .Case("f32", b.getF32Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e5b1291afce2b..02acc8c3f4659 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2581,6 +2581,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
+      .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d0eb2d8fbae9d..e3d6d71fb61df 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -58,6 +58,10 @@ FloatType Builder::getFloat8E4M3B11FNUZType() {
   return FloatType::getFloat8E4M3B11FNUZ(context);
 }
 
+FloatType Builder::getFloat8E3M4Type() {
+  return FloatType::getFloat8E3M4(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index faa944937e007..a3f5ece8c1736 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,8 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 
 unsigned FloatType::getWidth() {
   if (llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-                Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType>(
-          *this))
+                Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
+                Float8E3M4Type>(*this))
     return 8;
   if (llvm::isa<Float16Type, BFloat16Type>(*this))
     return 16;
@@ -118,6 +118,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E4M3FNUZ();
   if (llvm::isa<Float8E4M3B11FNUZType>(*this))
     return APFloat::Float8E4M3B11FNUZ();
+  if (llvm::isa<Float8E3M4Type>(*this))
+    return APFloat::Float8E3M4();
   if (llvm::isa<BFloat16Type>(*this))
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 12336701c9ca0..5c93747438ecd 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -227,6 +227,7 @@ class MLIRContextImpl {
   Float8E5M2FNUZType f8E5M2FNUZTy;
   Float8E4M3FNUZType f8E4M3FNUZTy;
   Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
+  Float8E3M4Type f8E3M4Ty;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   FloatTF32Type tf32Ty;
@@ -318,6 +319,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
   impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
   impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
+  impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1029,6 +1031,9 @@ Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
 Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
   return context->getImpl().f8E4M3B11FNUZTy;
 }
+Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
+  return context->getImpl().f8E3M4Ty;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index e8cd28bf9e85d..2bc26388b6218 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -46,6 +46,7 @@ bool Type::isFloat8E4M3FNUZ() const {
 bool Type::isFloat8E4M3B11FNUZ() const {
   return llvm::isa<Float8E4M3B11FNUZType>(*this);
 }
+bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 224e77a3f46be..e3599d3c84ffe 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float8E3M4Type",
     "Float8E4M3B11FNUZType",
     "Float8E4M3FNType",
     "Float8E4M3FNUZType",
@@ -1537,6 +1538,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float8E3M4Type(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float8E3M4Type:
+        """
+        Create a float8_e3m4 type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float8E4M3B11FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index fde9909a8f9d6..fe7c3e25d1690 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float8E3M4Type,
     Float8E4M3B11FNUZType,
     Float8E4M3FNType,
     Float8E4M3Type,
@@ -72,6 +73,7 @@ def ui(width):
 f8E4M3 = lambda: Float8E4M3Type.get()
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
+f8E3M4 = lambda: Float8E3M4Type.get()
 
 none = lambda: NoneType.get()
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 362e98134ee4a..ac0aec113add1 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -60,6 +60,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
     float_attr = 2. : f8E4M3B11FNUZ
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E3M4
+    float_attr = 2. : f8E3M4
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index db54d131299c6..82256f753abdd 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -39,6 +39,9 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
 // CHECK: @int_global_undef = internal global i64 undef
 llvm.mlir.global internal @int_global_undef() : i64
 
+// CHECK: @f8E3M4_global_as_i8 = internal global i8 56
+llvm.mlir.global internal @f8E3M4_global_as_i8(1.5 : f8E3M4) : i8
+
 // CHECK: @f8E4M3_global_as_i8 = internal global i8 60
 llvm.mlir.global internal @f8E4M3_global_as_i8(1.5 : f8E4M3) : i8
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 3178f58cf2e74..2161f110ac31e 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
@@ -231,6 +233,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f8E3M4
+        print("float:", Float8E3M4Type.get())
         # CHECK: float: f8E4M3
         print("float:", Float8E4M3Type.get())
         # CHECK: float: f8E4M3FN
@@ -605,6 +609,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float8E3M4Type, Float8E3M4Type.get()),
             (Float8E4M3Type, Float8E4M3Type.get()),
             (Float8E4M3FNType, Float8E4M3FNType.get()),
             (Float8E5M2Type, Float8E5M2Type.get()),
@@ -629,6 +634,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float8E3M4Type(f8E3M4)
         # CHECK: Float8E4M3Type(f8E4M3)
         # CHECK: Float8E4M3FNType(f8E4M3FN)
         # CHECK: Float8E5M2Type(f8E5M2)
@@ -707,6 +713,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float8E3M4Type
+        # CHECK: Float8E3M4Type(f8E3M4)
+        print_downcasted(Float8E3M4Type.get())
         # CHECK: Float8E4M3B11FNUZType
         # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
         print_downcasted(Float8E4M3B11FNUZType.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index ed0ee431fd7d8..e7c526842439b 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -56,6 +56,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
     "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
     "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
+    "mlir::Float8E3M4Type": '"f8E3M4"',
     "mlir::BFloat16Type": '"bf16"',
     "mlir::Float16Type": '"f16"',
     "mlir::FloatTF32Type": '"tf32"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index a657874f894b7..b5926d75da4f2 100644
--- a/mlir/utils/tree-sitter-mlir/grammar.js
+++ b/mlir/utils/tree-sitter-mlir/grammar.js
@@ -230,8 +230,8 @@ const common = {
   integer_type : $ =>
       token...
[truncated]
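
As a usage illustration only (not from the PR itself), the new type becomes reachable from the MLIR Python bindings added in the patch above; a minimal sketch, assuming a Python `mlir` package built with this change:

    from mlir.ir import Context, FloatType, Float8E3M4Type, Type

    with Context() as ctx:
        # Construct the type directly through the new binding ...
        t = Float8E3M4Type.get()
        print(t)                                  # -> f8E3M4
        # ... or round-trip it through the parser, as the added tests do.
        parsed = Type.parse("f8E3M4", ctx)
        print(isinstance(parsed, FloatType))      # -> True
        print(Float8E3M4Type.isinstance(parsed))  # -> True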

@llvmbot
Member

llvmbot commented Jul 30, 2024

@llvm/pr-subscribers-mlir-ods

@llvmbot
Member

llvmbot commented Jul 30, 2024

@llvm/pr-subscribers-mlir-core

+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -249,7 +249,7 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
-      type.isFloat8E4M3B11FNUZ())
+      type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4())
     return IntegerType::get(&getContext(), type.getWidth());
   return type;
 }
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a362c8500aa5b..51f229ef937c4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -60,6 +60,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
       .Case("f8E4M3FN", b.getFloat8E4M3FNType())
       .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
       .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
+      .Case("f8E3M4", b.getFloat8E3M4Type())
       .Case("bf16", b.getBF16Type())
       .Case("f16", b.getF16Type())
       .Case("f32", b.getF32Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e5b1291afce2b..02acc8c3f4659 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2581,6 +2581,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
       .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
       .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
+      .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d0eb2d8fbae9d..e3d6d71fb61df 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -58,6 +58,10 @@ FloatType Builder::getFloat8E4M3B11FNUZType() {
   return FloatType::getFloat8E4M3B11FNUZ(context);
 }
 
+FloatType Builder::getFloat8E3M4Type() {
+  return FloatType::getFloat8E3M4(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index faa944937e007..a3f5ece8c1736 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,8 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 
 unsigned FloatType::getWidth() {
   if (llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-                Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType>(
-          *this))
+                Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
+                Float8E3M4Type>(*this))
     return 8;
   if (llvm::isa<Float16Type, BFloat16Type>(*this))
     return 16;
@@ -118,6 +118,8 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
     return APFloat::Float8E4M3FNUZ();
   if (llvm::isa<Float8E4M3B11FNUZType>(*this))
     return APFloat::Float8E4M3B11FNUZ();
+  if (llvm::isa<Float8E3M4Type>(*this))
+    return APFloat::Float8E3M4();
   if (llvm::isa<BFloat16Type>(*this))
     return APFloat::BFloat();
   if (llvm::isa<Float16Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 12336701c9ca0..5c93747438ecd 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -227,6 +227,7 @@ class MLIRContextImpl {
   Float8E5M2FNUZType f8E5M2FNUZTy;
   Float8E4M3FNUZType f8E4M3FNUZTy;
   Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
+  Float8E3M4Type f8E3M4Ty;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   FloatTF32Type tf32Ty;
@@ -318,6 +319,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
   impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
   impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
   impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
+  impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
@@ -1029,6 +1031,9 @@ Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
 Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
   return context->getImpl().f8E4M3B11FNUZTy;
 }
+Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
+  return context->getImpl().f8E3M4Ty;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index e8cd28bf9e85d..2bc26388b6218 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -46,6 +46,7 @@ bool Type::isFloat8E4M3FNUZ() const {
 bool Type::isFloat8E4M3B11FNUZ() const {
   return llvm::isa<Float8E4M3B11FNUZType>(*this);
 }
+bool Type::isFloat8E3M4() const { return llvm::isa<Float8E3M4Type>(*this); }
 bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
 bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
 bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 224e77a3f46be..e3599d3c84ffe 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
     "F32Type",
     "F64Type",
     "FlatSymbolRefAttr",
+    "Float8E3M4Type",
     "Float8E4M3B11FNUZType",
     "Float8E4M3FNType",
     "Float8E4M3FNUZType",
@@ -1537,6 +1538,19 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
+class Float8E3M4Type(FloatType):
+    static_typeid: ClassVar[TypeID]
+    @staticmethod
+    def get(context: Optional[Context] = None) -> Float8E3M4Type:
+        """
+        Create a float8_e3m4 type.
+        """
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def typeid(self) -> TypeID: ...
+
 class Float8E4M3B11FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]
     @staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index fde9909a8f9d6..fe7c3e25d1690 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
     F16Type,
     F32Type,
     F64Type,
+    Float8E3M4Type,
     Float8E4M3B11FNUZType,
     Float8E4M3FNType,
     Float8E4M3Type,
@@ -72,6 +73,7 @@ def ui(width):
 f8E4M3 = lambda: Float8E4M3Type.get()
 f8E4M3FN = lambda: Float8E4M3FNType.get()
 f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
+f8E3M4 = lambda: Float8E3M4Type.get()
 
 none = lambda: NoneType.get()
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 362e98134ee4a..ac0aec113add1 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -60,6 +60,10 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
     float_attr = 2. : f8E4M3B11FNUZ
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E3M4
+    float_attr = 2. : f8E3M4
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index db54d131299c6..82256f753abdd 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -39,6 +39,9 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
 // CHECK: @int_global_undef = internal global i64 undef
 llvm.mlir.global internal @int_global_undef() : i64
 
+// CHECK: @f8E3M4_global_as_i8 = internal global i8 56
+llvm.mlir.global internal @f8E3M4_global_as_i8(1.5 : f8E3M4) : i8
+
 // CHECK: @f8E4M3_global_as_i8 = internal global i8 60
 llvm.mlir.global internal @f8E4M3_global_as_i8(1.5 : f8E4M3) : i8
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 3178f58cf2e74..2161f110ac31e 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
 def testFloatTypeSubclasses():
     ctx = Context()
     # CHECK: True
+    print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
+    # CHECK: True
     print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
     # CHECK: True
     print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
@@ -231,6 +233,8 @@ def testIndexType():
 @run
 def testFloatType():
     with Context():
+        # CHECK: float: f8E3M4
+        print("float:", Float8E3M4Type.get())
         # CHECK: float: f8E4M3
         print("float:", Float8E4M3Type.get())
         # CHECK: float: f8E4M3FN
@@ -605,6 +609,7 @@ def testTypeIDs():
         types = [
             (IntegerType, IntegerType.get_signless(16)),
             (IndexType, IndexType.get()),
+            (Float8E3M4Type, Float8E3M4Type.get()),
             (Float8E4M3Type, Float8E4M3Type.get()),
             (Float8E4M3FNType, Float8E4M3FNType.get()),
             (Float8E5M2Type, Float8E5M2Type.get()),
@@ -629,6 +634,7 @@ def testTypeIDs():
 
         # CHECK: IntegerType(i16)
         # CHECK: IndexType(index)
+        # CHECK: Float8E3M4Type(f8E3M4)
         # CHECK: Float8E4M3Type(f8E4M3)
         # CHECK: Float8E4M3FNType(f8E4M3FN)
         # CHECK: Float8E5M2Type(f8E5M2)
@@ -707,6 +713,9 @@ def print_downcasted(typ):
         # CHECK: F64Type
         # CHECK: F64Type(f64)
         print_downcasted(F64Type.get())
+        # CHECK: Float8E3M4Type
+        # CHECK: Float8E3M4Type(f8E3M4)
+        print_downcasted(Float8E3M4Type.get())
         # CHECK: Float8E4M3B11FNUZType
         # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
         print_downcasted(Float8E4M3B11FNUZType.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index ed0ee431fd7d8..e7c526842439b 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -56,6 +56,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
     "mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
     "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
     "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"',
+    "mlir::Float8E3M4Type": '"f8E3M4"',
     "mlir::BFloat16Type": '"bf16"',
     "mlir::Float16Type": '"f16"',
     "mlir::FloatTF32Type": '"tf32"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index a657874f894b7..b5926d75da4f2 100644
--- a/mlir/utils/tree-sitter-mlir/grammar.js
+++ b/mlir/utils/tree-sitter-mlir/grammar.js
@@ -230,8 +230,8 @@ const common = {
   integer_type : $ =>
       token...
[truncated]
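
Taken together, the hunks above thread the new `f8E3M4` keyword through the AsmParser, AsmPrinter, builders, C API, and Python bindings. As a rough sketch of how those pieces fit once an MLIR build contains this patch (an illustration, not part of the diff), the type can be constructed and parsed from Python:

```python
# Sketch only: assumes an MLIR build that includes this patch and exposes the
# Python bindings as `mlir.ir`.
from mlir.ir import Context, Type, FloatType, Float8E3M4Type

with Context():
    t = Float8E3M4Type.get()          # binding added in IRTypes.cpp above
    parsed = Type.parse("f8E3M4")     # keyword added in TokenKinds.def above
    assert t == parsed
    assert isinstance(parsed, FloatType)  # FloatType::classof now accepts it
    print(t)                          # prints f8E3M4 via the new AsmPrinter case
```

On the C++ side the equivalent entry points added here are `FloatType::getFloat8E3M4(ctx)` and `Builder::getFloat8E3M4Type()`, with `mlirFloat8E3M4TypeGet(ctx)` in the C API.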


@apivovarov
Member Author

Hi Maksim, whenever you have time, could you please review this pull request? It's essentially a clone of the previous one - #97118 [MLIR] Add f8E4M3 IEEE 754 type.
@makslevental

Contributor

@makslevental makslevental left a comment

It's too bad this stuff isn't more generative/generated? Like what a tedium updating all these files every time you want to add another float type.

@apivovarov
Member Author

apivovarov commented Aug 2, 2024

I planned to add f8E4M3 and f8E3M4 types

@apivovarov apivovarov closed this Aug 2, 2024
@apivovarov apivovarov reopened this Aug 2, 2024
@apivovarov apivovarov merged commit eef1d7e into llvm:main Aug 2, 2024
19 checks passed
@apivovarov apivovarov deleted the f8E3M4_mlir branch August 2, 2024 07:22
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 3, 2024
### Summary
This is a proposal to add `Float8E4M3` and `Float8E3M4` floating point
types to StableHLO.
Feedback welcome, see [RFC: Float8E4M3 and
Float8E3M4](https://github.com/apivovarov/stablehlo/blob/rfc_f8E4M3_f8E3M4/rfcs/20240808-f8E4M3_f8E3M4.md)
for more details.

### References and Links
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- [RFC: FP8 in
StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and
Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
- StableHLO [PR-2482](#2482)
Add f8E4M3 and f8E3M4 types support
- [Amazon EC2 Trn1
Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 4, 2024
This PR adds f8E4M3 and f8E3M4 types support.

f8E4M3 and f8E3M4 types follow IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```
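
The ranges listed above can be sanity-checked by decoding raw bit patterns directly from the stated layouts. The helper below is an illustrative sketch only; the field widths and biases are the values given in this description, not taken from any library:

```python
# Decode a raw 8-bit pattern under an IEEE-754-style S.E.M layout.
# Illustrative only; exponent widths and biases follow the tables above.
def decode_fp8(bits: int, exp_bits: int, man_bits: int, bias: int) -> float:
    sign = -1.0 if (bits >> (exp_bits + man_bits)) & 1 else 1.0
    exp = (bits >> man_bits) & ((1 << exp_bits) - 1)
    man = bits & ((1 << man_bits) - 1)
    if exp == (1 << exp_bits) - 1:                    # all-ones exponent
        return sign * float("inf") if man == 0 else float("nan")
    if exp == 0:                                      # subnormals
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

def decode_e3m4(b): return decode_fp8(b, exp_bits=3, man_bits=4, bias=3)
def decode_e4m3(b): return decode_fp8(b, exp_bits=4, man_bits=3, bias=7)

assert decode_e3m4(0b0_110_1111) == 15.5        # max normal f8E3M4
assert decode_e3m4(0b0_001_0000) == 0.25        # min normal f8E3M4 = 2^-2
assert decode_e3m4(0b0_000_0001) == 2.0 ** -6   # min subnormal f8E3M4
assert decode_e3m4(0b0_011_1000) == 1.5         # 0x38 = 56, cf. the llvmir.mlir test above
assert decode_e4m3(0b0_1110_111) == 240.0       # max normal f8E4M3
```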

Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](#2486)
[RFC] Add f8E4M3 and f8E3M4 types support
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
copybara-service bot pushed a commit to google/tsl that referenced this pull request Sep 30, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
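
Outside of XLA itself, the ml_dtypes PRs listed above expose the same two formats as NumPy extension dtypes, so the ranges can also be checked from Python. A small sketch, assuming an ml_dtypes release that already contains `float8_e4m3` and `float8_e3m4`:

```python
# Assumes an ml_dtypes release that includes the float8_e4m3 / float8_e3m4
# dtypes added in the PRs referenced above.
import numpy as np
import ml_dtypes

print(np.finfo(ml_dtypes.float8_e3m4).max)   # expected: 15.5, per the table above
print(np.finfo(ml_dtypes.float8_e4m3).max)   # expected: 240.0
x = np.array([0.25, 1.5, 15.5], dtype=np.float32)
print(x.astype(ml_dtypes.float8_e3m4))       # all three values are exactly representable
```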
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Sep 30, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Sep 30, 2024
copybara-service bot pushed a commit to google/tsl that referenced this pull request Oct 1, 2024
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 1, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 1, 2024
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 1, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 1, 2024
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 1, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 1, 2024
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to google/tsl that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to google/tsl that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 2, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024