Skip to content

[MLIR] Add f8E4M3 IEEE 754 type #97118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir-c/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);

/// Returns the typeID of an Float8E4M3 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void);

/// Checks whether the given type is an f8E4M3 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type);

/// Creates an f8E4M3 type in the given context. The type is owned by the
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx);

/// Returns the typeID of an Float8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Builder {

// Types.
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3Type();
FloatType getFloat8E4M3FNType();
FloatType getFloat8E5M2FNUZType();
FloatType getFloat8E4M3FNUZType();
Expand Down
13 changes: 9 additions & 4 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class FloatType : public Type {
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
Expand Down Expand Up @@ -405,16 +406,20 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}

inline bool FloatType::classof(Type type) {
return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
Float16Type, FloatTF32Type, Float32Type, Float64Type,
Float80Type, Float128Type>(type);
return llvm::isa<
Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, Float16Type,
FloatTF32Type, Float32Type, Float64Type, Float80Type, Float128Type>(type);
}

inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
return Float8E5M2Type::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3(MLIRContext *ctx) {
return Float8E4M3Type::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
return Float8E4M3FNType::get(ctx);
}
Expand Down
19 changes: 19 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,25 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> {
}];
}

//===----------------------------------------------------------------------===//
// Float8E4M3Type

def Builtin_Float8E4M3 : Builtin_FloatType<"Float8E4M3", "f8E4M3"> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
mantissa. This is not a standard type as defined by IEEE-754, but it
follows similar conventions with the following characteristics:

* bit encoding: S1E4M3
* exponent bias: 7
* infinities: supported with exponent set to all 1s and mantissa 0s
* NaNs: supported with exponent bits set to all 1s and mantissa of
(001, 010, 011, 100, 101, 110, 111)
* denormals when exponent is 0
}];
}

//===----------------------------------------------------------------------===//
// Float8E4M3FNType

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
BuildableType<"$_builder.getFloat8E4M3FNType()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
BuildableType<"$_builder.getFloat8E5M2Type()">;
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
BuildableType<"$_builder.getFloat8E4M3Type()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class Type {
// derived types should use isa/dyn_cast.
bool isIndex() const;
bool isFloat8E5M2() const;
bool isFloat8E4M3() const;
bool isFloat8E4M3FN() const;
bool isFloat8E5M2FNUZ() const;
bool isFloat8E4M3FNUZ() const;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/AsmParser/TokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ TOK_KEYWORD(f32)
TOK_KEYWORD(f64)
TOK_KEYWORD(f80)
TOK_KEYWORD(f8E5M2)
TOK_KEYWORD(f8E4M3)
TOK_KEYWORD(f8E4M3FN)
TOK_KEYWORD(f8E5M2FNUZ)
TOK_KEYWORD(f8E4M3FNUZ)
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_vector:
case Token::inttype:
case Token::kw_f8E5M2:
case Token::kw_f8E4M3:
case Token::kw_f8E4M3FN:
case Token::kw_f8E5M2FNUZ:
case Token::kw_f8E4M3FNUZ:
Expand Down Expand Up @@ -304,6 +305,9 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
return builder.getFloat8E5M2Type();
case Token::kw_f8E4M3:
consumeToken(Token::kw_f8E4M3);
return builder.getFloat8E4M3Type();
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
return builder.getFloat8E4M3FNType();
Expand Down
23 changes: 22 additions & 1 deletion mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class PyFloat8E4M3FNType
}
};

/// Floating Point Type subclass - Float8M5E2Type.
/// Floating Point Type subclass - Float8E5M2Type.
class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
Expand All @@ -163,6 +163,26 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
}
};

/// Floating Point Type subclass - Float8E4M3Type.
class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3TypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3Type";
using PyConcreteType::PyConcreteType;

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat8E4M3TypeGet(context->get());
return PyFloat8E4M3Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3 type.");
}
};

/// Floating Point Type subclass - Float8E4M3FNUZ.
class PyFloat8E4M3FNUZType
: public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
Expand Down Expand Up @@ -840,6 +860,7 @@ void mlir::python::populateIRTypes(py::module &m) {
PyIndexType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
PyFloat8E4M3Type::bind(m);
PyFloat8E4M3FNUZType::bind(m);
PyFloat8E4M3B11FNUZType::bind(m);
PyFloat8E5M2FNUZType::bind(m);
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
}

MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
return wrap(Float8E4M3Type::getTypeID());
}

bool mlirTypeIsAFloat8E4M3(MlirType type) {
return unwrap(type).isFloat8E4M3();
}

MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3(unwrap(ctx)));
}

MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
return wrap(Float8E4M3FNType::getTypeID());
}
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
}

Type LLVMTypeConverter::convertFloatType(FloatType type) const {
if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ())
return IntegerType::get(&getContext(), type.getWidth());
return type;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2576,6 +2576,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
})
.Case<IndexType>([&](Type) { os << "index"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ FloatType Builder::getFloat8E5M2Type() {
return FloatType::getFloat8E5M2(context);
}

FloatType Builder::getFloat8E4M3Type() {
return FloatType::getFloat8E4M3(context);
}

FloatType Builder::getFloat8E4M3FNType() {
return FloatType::getFloat8E4M3FN(context);
}
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
if (llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
Float8E5M2FNUZType, Float8E4M3FNUZType, Float8E4M3B11FNUZType>(
*this))
return 8;
if (llvm::isa<Float16Type, BFloat16Type>(*this))
return 16;
Expand All @@ -107,6 +108,8 @@ unsigned FloatType::getWidth() {
const llvm::fltSemantics &FloatType::getFloatSemantics() {
if (llvm::isa<Float8E5M2Type>(*this))
return APFloat::Float8E5M2();
if (llvm::isa<Float8E4M3Type>(*this))
return APFloat::Float8E4M3();
if (llvm::isa<Float8E4M3FNType>(*this))
return APFloat::Float8E4M3FN();
if (llvm::isa<Float8E5M2FNUZType>(*this))
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class MLIRContextImpl {

/// Cached Type Instances.
Float8E5M2Type f8E5M2Ty;
Float8E4M3Type f8E4M3Ty;
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
Expand Down Expand Up @@ -312,6 +313,7 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
//// Types.
/// Floating-point Types.
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
Expand Down Expand Up @@ -1012,6 +1014,9 @@ StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
return context->getImpl().f8E5M2Ty;
}
Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
return context->getImpl().f8E4M3Ty;
}
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNTy;
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
MLIRContext *Type::getContext() const { return getDialect().getContext(); }

bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
bool Type::isFloat8E4M3() const { return llvm::isa<Float8E4M3Type>(*this); }
bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
bool Type::isFloat8E5M2FNUZ() const {
return llvm::isa<Float8E5M2FNUZType>(*this);
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ __all__ = [
"Float8E4M3B11FNUZType",
"Float8E4M3FNType",
"Float8E4M3FNUZType",
"Float8E4M3Type",
"Float8E5M2FNUZType",
"Float8E5M2Type",
"FloatAttr",
Expand Down Expand Up @@ -1575,6 +1576,19 @@ class Float8E4M3FNUZType(FloatType):
@property
def typeid(self) -> TypeID: ...

class Float8E4M3Type(FloatType):
static_typeid: ClassVar[TypeID]
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3Type:
"""
Create a float8_e4m3 type.
"""
@staticmethod
def isinstance(other: Type) -> bool: ...
def __init__(self, cast_from_type: Type) -> None: ...
@property
def typeid(self) -> TypeID: ...

class Float8E5M2FNUZType(FloatType):
static_typeid: ClassVar[TypeID]
@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions mlir/python/mlir/extras/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
F64Type,
Float8E4M3B11FNUZType,
Float8E4M3FNType,
Float8E4M3Type,
Float8E5M2Type,
FunctionType,
IndexType,
Expand Down Expand Up @@ -68,6 +69,7 @@ def ui(width):
bf16 = lambda: BF16Type.get()

f8E5M2 = lambda: Float8E5M2Type.get()
f8E4M3 = lambda: Float8E4M3Type.get()
f8E4M3FN = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/IR/attribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f8E5M2
float_attr = 2. : f8E5M2
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3
float_attr = 2. : f8E4M3
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3FN
float_attr = 2. : f8E4M3FN
Expand Down
3 changes: 3 additions & 0 deletions mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ llvm.mlir.global internal constant @string_const("foobar") : !llvm.array<6 x i8>
// CHECK: @int_global_undef = internal global i64 undef
llvm.mlir.global internal @int_global_undef() : i64

// CHECK: @f8E4M3_global_as_i8 = internal global i8 60
llvm.mlir.global internal @f8E4M3_global_as_i8(1.5 : f8E4M3) : i8

// CHECK: @f8E4M3FN_global_as_i8 = internal global i8 60
llvm.mlir.global internal @f8E4M3FN_global_as_i8(1.5 : f8E4M3FN) : i8

Expand Down
9 changes: 9 additions & 0 deletions mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def testTypeIsInstance():
def testFloatTypeSubclasses():
ctx = Context()
# CHECK: True
print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
Expand Down Expand Up @@ -229,6 +231,8 @@ def testIndexType():
@run
def testFloatType():
with Context():
# CHECK: float: f8E4M3
print("float:", Float8E4M3Type.get())
# CHECK: float: f8E4M3FN
print("float:", Float8E4M3FNType.get())
# CHECK: float: f8E5M2
Expand Down Expand Up @@ -601,6 +605,7 @@ def testTypeIDs():
types = [
(IntegerType, IntegerType.get_signless(16)),
(IndexType, IndexType.get()),
(Float8E4M3Type, Float8E4M3Type.get()),
(Float8E4M3FNType, Float8E4M3FNType.get()),
(Float8E5M2Type, Float8E5M2Type.get()),
(Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
Expand All @@ -624,6 +629,7 @@ def testTypeIDs():

# CHECK: IntegerType(i16)
# CHECK: IndexType(index)
# CHECK: Float8E4M3Type(f8E4M3)
# CHECK: Float8E4M3FNType(f8E4M3FN)
# CHECK: Float8E5M2Type(f8E5M2)
# CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
Expand Down Expand Up @@ -704,6 +710,9 @@ def print_downcasted(typ):
# CHECK: Float8E4M3B11FNUZType
# CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
print_downcasted(Float8E4M3B11FNUZType.get())
# CHECK: Float8E4M3Type
# CHECK: Float8E4M3Type(f8E4M3)
print_downcasted(Float8E4M3Type.get())
# CHECK: Float8E4M3FNType
# CHECK: Float8E4M3FNType(f8E4M3FN)
print_downcasted(Float8E4M3FNType.get())
Expand Down
Loading
Loading