diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 8d81d8ec14ee7..5aaac8d8e3dc5 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
 void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand Arith f8E8M0 ops to lower-level bitcasts/shifts.
+void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index d026d494cb50c..e14b2aeee1c69 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
 def ArithExpandOpsPass : Pass<"arith-expand"> {
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let dependentDialects = ["vector::VectorDialect"];
-  let options = [
-    Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
-           "Enable the BF16 expansion patterns">,
+  let options =
+      [Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+              "Enable the BF16 expansion patterns">,
+       Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
+              "Enable the F8E8M0 expansion patterns">,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 2d627e523cde5..95546bb09e765 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
+/// Creates a shaped type with the shape of cloneFrom and element type cloneTo.
+static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
+  if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
+    return shapedTy.clone(cloneTo);
+  }
+  return cloneTo;
+}
+
 namespace {
 
 /// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
       return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
     }
 
-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-    }
+    Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
 
     Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
     Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
           op, "only applicable to default rounding mode.");
     }
 
-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    Type f32Ty = b.getF32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-      f32Ty = shapedTy.clone(f32Ty);
-    }
+    Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
 
     // Algorithm borrowed from this excellent code:
     // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Constant used to make the rounding bias.
     Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
     // Constant used to generate a quiet NaN.
-    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
     // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
     Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +311,104 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not an extf of f8E8M0FNU");
+    }
+
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // Create constants for the NaN encodings.
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // Select the f32 NaN bit pattern if the input was NaN.
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+TruncF to F8E8M0 extracts the exponent bits of the F32 type.
+Since all Infs and NaNs share the same exponent bits in the F32 type, they
+all map to NaN in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    if (operandETy.getIntOrFloatBitWidth() < 32) {
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    } else if (operandETy.getIntOrFloatBitWidth() > 32) {
+      operand = b.create<arith::TruncFOp>(f32Ty, operand);
+    }
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -353,20 +437,34 @@ struct ArithExpandOpsPass
     if (includeBf16) {
       arith::populateExpandBFloat16Patterns(patterns);
-      target.addDynamicallyLegalOp<arith::ExtFOp>(
-          [](arith::ExtFOp op) {
-            Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-            Type outETy = getElementTypeOrSelf(op.getType());
-            return !(inETy.isBF16() && outETy.isF32());
-          });
-
-      target.addDynamicallyLegalOp<arith::TruncFOp>(
-          [](arith::TruncFOp op) {
-            Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-            Type outETy = getElementTypeOrSelf(op.getType());
-            return !(inETy.isF32() && outETy.isBF16());
-          });
     }
+    if (includeF8E8M0) {
+      arith::populateExpandF8E8M0Patterns(patterns);
+    }
+
+    target.addDynamicallyLegalOp<arith::ExtFOp>(
+        [=](arith::ExtFOp op) {
+          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+          Type outETy = getElementTypeOrSelf(op.getType());
+          bool legalTypes = true;
+          if (includeBf16)
+            legalTypes &= !(inETy.isBF16() && outETy.isF32());
+          if (includeF8E8M0)
+            legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+          return legalTypes;
+        });
+
+    target.addDynamicallyLegalOp<arith::TruncFOp>(
+        [=](arith::TruncFOp op) {
+          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+          Type outETy = getElementTypeOrSelf(op.getType());
+          bool legalTypes = true;
+          if (includeBf16)
+            legalTypes &= !(inETy.isF32() && outETy.isBF16());
+          if (includeF8E8M0)
+            legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
+          return legalTypes;
+        });
 
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
@@ -389,6 +487,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index bdf022642b717..5b6badf13d763 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
 
 // Test ceil divide with signed integer
 // CHECK-LABEL: func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
 // CHECK-LABEL: @truncf_vector_f32
 // CHECK-NOT: arith.truncf
 
+// -----
+func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
+  %0 = arith.truncf %arg0 : f32 to f8E8M0FNU
+  return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
+  %0 = arith.truncf %arg0 : f16 to f8E8M0FNU
+  return %0 : f8E8M0FNU
+}
+// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU
+// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
+// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
+  %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
+  return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
+  %0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
+  return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
+  %0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
+  return %0 : vector<4xf8E8M0FNU>
+}
+
+// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
+// CHECK-NOT: arith.truncf
+
+
+// -----
+func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
+  %0 = arith.extf %arg0 : f8E8M0FNU to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
+  %0 = arith.extf %arg0 : f8E8M0FNU to f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f16
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
+// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
+// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
+// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
+// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
+// CHECK: return %[[F16_RESULT]]
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
+  %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
+  %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
+  return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
+// CHECK-NOT: arith.extf
+
+// -----
+
+func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
+  %0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
+  return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
+// CHECK-NOT: arith.extf
+
+
 // -----
 
 func.func @maxsi(%a: i32, %b: i32) -> i32 {
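As a quick cross-check of the bit manipulation the two new converters emit, here is a minimal standalone C++ sketch. It is not part of the patch, and the helper names extE8M0ToF32 and truncF32ToE8M0 are made up for illustration; it assumes the f8E8M0FNU encoding the patterns rely on: 8 exponent bits, no sign or mantissa bits, and 0xFF reserved for NaN.

// Standalone sketch of the scalar semantics of the F8E8M0 expansion patterns.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Mirrors F8E8M0ExtFOpConverter: zero-extend the 8 exponent bits to 32 bits,
// shift them into the f32 exponent field (past the 23 mantissa bits), and map
// the 0xFF encoding to an f32 NaN bit pattern.
static float extE8M0ToF32(uint8_t e8) {
  uint32_t bits = (e8 == 0xFFu) ? 0xFFFFFFFFu : (uint32_t(e8) << 23);
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Mirrors F8E8M0TruncFOpConverter: shift the f32 bits right by 23 and truncate
// to 8 bits, discarding the sign and mantissa and keeping only the exponent.
static uint8_t truncF32ToE8M0(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return uint8_t(bits >> 23);
}

int main() {
  // 64.0f = 2^6, so its f32 exponent field is 127 + 6 = 133 (0x85).
  std::printf("trunc(64.0f) = 0x%02x\n", unsigned(truncF32ToE8M0(64.0f)));
  std::printf("ext(0x85)    = %g\n", extE8M0ToF32(0x85));
  std::printf("ext(0xff) is NaN: %d\n", extE8M0ToF32(0xFF) != extE8M0ToF32(0xFF));
  return 0;
}

Running it prints trunc(64.0f) = 0x85 and ext(0x85) = 64, matching the shift-by-23 constants checked in the tests above, and confirms that the 0xFF input comes back as NaN.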