Skip to content

Commit eef1d7e

Browse files
authored
[MLIR] Add f8E3M4 IEEE 754 type (#101230)
This PR adds `f8E3M4` type to mlir. `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` Related PRs: - [PR-99698](#99698) [APFloat] Add support for f8E3M4 IEEE 754 type - [PR-97118](#97118) [MLIR] Add f8E4M3 IEEE 754 type
1 parent e9c20b9 commit eef1d7e

File tree

24 files changed

+133
-9
lines changed

24 files changed

+133
-9
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
139139
/// context.
140140
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
141141

142+
/// Returns the typeID of an Float8E3M4 type.
143+
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void);
144+
145+
/// Checks whether the given type is an f8E3M4 type.
146+
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
147+
148+
/// Creates an f8E3M4 type in the given context. The type is owned by the
149+
/// context.
150+
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
151+
142152
/// Returns the typeID of an BFloat16 type.
143153
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
144154

mlir/include/mlir/IR/Builders.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class Builder {
6666
FloatType getFloat8E5M2FNUZType();
6767
FloatType getFloat8E4M3FNUZType();
6868
FloatType getFloat8E4M3B11FNUZType();
69+
FloatType getFloat8E3M4Type();
6970
FloatType getBF16Type();
7071
FloatType getF16Type();
7172
FloatType getTF32Type();

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class FloatType : public Type {
6666
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
6767
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
6868
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
69+
static FloatType getFloat8E3M4(MLIRContext *ctx);
6970

7071
/// Methods for support type inquiry through isa, cast, and dyn_cast.
7172
static bool classof(Type type);
@@ -411,10 +412,11 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
411412
}
412413

413414
inline bool FloatType::classof(Type type) {
414-
return llvm::isa<
415-
Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
416-
Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, Float16Type,
417-
FloatTF32Type, Float32Type, Float64Type, Float80Type, Float128Type>(type);
415+
return llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
416+
Float8E5M2FNUZType, Float8E4M3FNUZType,
417+
Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
418+
Float16Type, FloatTF32Type, Float32Type, Float64Type,
419+
Float80Type, Float128Type>(type);
418420
}
419421

420422
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -441,6 +443,10 @@ inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) {
441443
return Float8E4M3B11FNUZType::get(ctx);
442444
}
443445

446+
inline FloatType FloatType::getFloat8E3M4(MLIRContext *ctx) {
447+
return Float8E3M4Type::get(ctx);
448+
}
449+
444450
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
445451
return BFloat16Type::get(ctx);
446452
}

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,25 @@ def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ", "f8E4M3B1
213213
}];
214214
}
215215

216+
//===----------------------------------------------------------------------===//
217+
// Float8E3M4Type
218+
219+
def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
220+
let summary = "8-bit floating point with 3 bits exponent and 4 bit mantissa";
221+
let description = [{
222+
An 8-bit floating point type with 1 sign bit, 3 bits exponent and 4 bits
223+
mantissa. This is not a standard type as defined by IEEE-754, but it
224+
follows similar conventions with the following characteristics:
225+
226+
* bit encoding: S1E3M4
227+
* exponent bias: 3
228+
* infinities: supported with exponent set to all 1s and mantissa 0s
229+
* NaNs: supported with exponent bits set to all 1s and mantissa values of
230+
{0,1}⁴ except S.111.0000
231+
* denormals when exponent is 0
232+
}];
233+
}
234+
216235
//===----------------------------------------------------------------------===//
217236
// BFloat16Type
218237

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ t
342342
BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
343343
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
344344
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
345+
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
346+
BuildableType<"$_builder.getFloat8E3M4Type()">;
345347

346348
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
347349
"complex-type", "::mlir::ComplexType">;

mlir/include/mlir/IR/Types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Type {
131131
bool isFloat8E5M2FNUZ() const;
132132
bool isFloat8E4M3FNUZ() const;
133133
bool isFloat8E4M3B11FNUZ() const;
134+
bool isFloat8E3M4() const;
134135
bool isBF16() const;
135136
bool isF16() const;
136137
bool isTF32() const;

mlir/lib/AsmParser/TokenKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ TOK_KEYWORD(f8E4M3FN)
100100
TOK_KEYWORD(f8E5M2FNUZ)
101101
TOK_KEYWORD(f8E4M3FNUZ)
102102
TOK_KEYWORD(f8E4M3B11FNUZ)
103+
TOK_KEYWORD(f8E3M4)
103104
TOK_KEYWORD(f128)
104105
TOK_KEYWORD(false)
105106
TOK_KEYWORD(floordiv)

mlir/lib/AsmParser/TypeParser.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
4545
case Token::kw_f8E5M2FNUZ:
4646
case Token::kw_f8E4M3FNUZ:
4747
case Token::kw_f8E4M3B11FNUZ:
48+
case Token::kw_f8E3M4:
4849
case Token::kw_bf16:
4950
case Token::kw_f16:
5051
case Token::kw_tf32:
@@ -320,6 +321,9 @@ Type Parser::parseNonFunctionType() {
320321
case Token::kw_f8E4M3B11FNUZ:
321322
consumeToken(Token::kw_f8E4M3B11FNUZ);
322323
return builder.getFloat8E4M3B11FNUZType();
324+
case Token::kw_f8E3M4:
325+
consumeToken(Token::kw_f8E3M4);
326+
return builder.getFloat8E3M4Type();
323327
case Token::kw_bf16:
324328
consumeToken(Token::kw_bf16);
325329
return builder.getBF16Type();

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,26 @@ class PyFloat8E5M2FNUZType
246246
}
247247
};
248248

249+
/// Floating Point Type subclass - Float8E3M4Type.
250+
class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
251+
public:
252+
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
253+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
254+
mlirFloat8E3M4TypeGetTypeID;
255+
static constexpr const char *pyClassName = "Float8E3M4Type";
256+
using PyConcreteType::PyConcreteType;
257+
258+
static void bindDerived(ClassTy &c) {
259+
c.def_static(
260+
"get",
261+
[](DefaultingPyMlirContext context) {
262+
MlirType t = mlirFloat8E3M4TypeGet(context->get());
263+
return PyFloat8E3M4Type(context->getRef(), t);
264+
},
265+
py::arg("context") = py::none(), "Create a float8_e3m4 type.");
266+
}
267+
};
268+
249269
/// Floating Point Type subclass - BF16Type.
250270
class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
251271
public:
@@ -864,6 +884,7 @@ void mlir::python::populateIRTypes(py::module &m) {
864884
PyFloat8E4M3FNUZType::bind(m);
865885
PyFloat8E4M3B11FNUZType::bind(m);
866886
PyFloat8E5M2FNUZType::bind(m);
887+
PyFloat8E3M4Type::bind(m);
867888
PyBF16Type::bind(m);
868889
PyF16Type::bind(m);
869890
PyTF32Type::bind(m);

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
157157
return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx)));
158158
}
159159

160+
MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
161+
return wrap(Float8E3M4Type::getTypeID());
162+
}
163+
164+
bool mlirTypeIsAFloat8E3M4(MlirType type) {
165+
return unwrap(type).isFloat8E3M4();
166+
}
167+
168+
MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
169+
return wrap(FloatType::getFloat8E3M4(unwrap(ctx)));
170+
}
171+
160172
MlirTypeID mlirBFloat16TypeGetTypeID() {
161173
return wrap(BFloat16Type::getTypeID());
162174
}

0 commit comments

Comments
 (0)