Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 147 additions & 5 deletions mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
}

/// Helper function to look up or create the symbol for a runtime library
/// function with the given parameter types. Always returns an int64_t.
/// function with the given parameter types. Returns an int64_t, unless a
/// different result type is specified.
static FailureOr<FuncOp>
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
StringRef name, TypeRange paramTypes,
SymbolTableCollection *symbolTables = nullptr) {
auto i64Type = IntegerType::get(symTable->getContext(), 64);

SymbolTableCollection *symbolTables = nullptr,
Type resultType = {}) {
if (!resultType)
resultType = IntegerType::get(symTable->getContext(), 64);
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
FailureOr<FuncOp> func =
lookupFnDecl(symTable, funcName, funcT, symbolTables);
// Failed due to type mismatch.
Expand Down Expand Up @@ -308,6 +310,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
bool isUnsigned;
};

struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
PatternBenefit benefit = 1)
: OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}

LogicalResult matchAndRewrite(arith::CmpFOp op,
PatternRewriter &rewriter) const override {
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
auto i8Type = IntegerType::get(symTable->getContext(), 8);
auto i32Type = IntegerType::get(symTable->getContext(), 32);
auto i64Type = IntegerType::get(symTable->getContext(), 64);
FailureOr<FuncOp> fn =
lookupOrCreateApFloatFn(rewriter, symTable, "compare",
{i32Type, i64Type, i64Type}, nullptr, i8Type);
if (failed(fn))
return fn;

// Cast operands to 64-bit integers.
rewriter.setInsertionPoint(op);
Location loc = op.getLoc();
auto floatTy = cast<FloatType>(op.getLhs().getType());
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
Value lhsBits = arith::ExtUIOp::create(
rewriter, loc, i64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
Value rhsBits = arith::ExtUIOp::create(
rewriter, loc, i64Type,
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));

// Call APFloat function.
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
Value comparisonResult =
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
SymbolRefAttr::get(*fn), params)
->getResult(0);

// Generate an i1 SSA value that is "true" if the comparison result matches
// the given `val`.
auto checkResult = [&](llvm::APFloat::cmpResult val) {
return arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
arith::ConstantOp::create(
rewriter, loc, i8Type,
rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
.getResult());
};
// Generate an i1 SSA value that is "true" if the comparison result matches
// any of the given `vals`.
std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkResults =
[&](ArrayRef<llvm::APFloat::cmpResult> vals) {
Value first = checkResult(vals.front());
if (vals.size() == 1)
return first;
Value rest = checkResults(vals.drop_front());
return arith::OrIOp::create(rewriter, loc, first, rest).getResult();
};

// This switch-case statement was taken from arith::applyCmpPredicate.
Value result;
switch (op.getPredicate()) {
case arith::CmpFPredicate::AlwaysFalse:
result = arith::ConstantOp::create(rewriter, loc, i1Type,
rewriter.getIntegerAttr(i1Type, 0))
.getResult();
break;
case arith::CmpFPredicate::OEQ:
result = checkResult(llvm::APFloat::cmpEqual);
break;
case arith::CmpFPredicate::OGT:
result = checkResult(llvm::APFloat::cmpGreaterThan);
break;
case arith::CmpFPredicate::OGE:
result = checkResults(
{llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
break;
case arith::CmpFPredicate::OLT:
result = checkResult(llvm::APFloat::cmpLessThan);
break;
case arith::CmpFPredicate::OLE:
result =
checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
break;
case arith::CmpFPredicate::ONE:
// Not cmpUnordered and not cmpUnordered.
result = checkResults(
{llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
break;
case arith::CmpFPredicate::ORD:
// Not cmpUnordered.
result = checkResults({llvm::APFloat::cmpLessThan,
llvm::APFloat::cmpGreaterThan,
llvm::APFloat::cmpEqual});
break;
case arith::CmpFPredicate::UEQ:
result =
checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
break;
case arith::CmpFPredicate::UGT:
result = checkResults(
{llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
break;
case arith::CmpFPredicate::UGE:
result = checkResults({llvm::APFloat::cmpUnordered,
llvm::APFloat::cmpGreaterThan,
llvm::APFloat::cmpEqual});
break;
case arith::CmpFPredicate::ULT:
result = checkResults(
{llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
break;
case arith::CmpFPredicate::ULE:
result =
checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
llvm::APFloat::cmpEqual});
break;
case arith::CmpFPredicate::UNE:
// Not cmpEqual.
result = checkResults({llvm::APFloat::cmpLessThan,
llvm::APFloat::cmpGreaterThan,
llvm::APFloat::cmpUnordered});
break;
case arith::CmpFPredicate::UNO:
result = checkResult(llvm::APFloat::cmpUnordered);
break;
case arith::CmpFPredicate::AlwaysTrue:
result = arith::ConstantOp::create(rewriter, loc, i1Type,
rewriter.getIntegerAttr(i1Type, 1))
.getResult();
break;
}
rewriter.replaceOp(op, result);
return success();
}

SymbolOpInterface symTable;
};

namespace {
struct ArithToAPFloatConversionPass final
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
Expand Down Expand Up @@ -340,6 +481,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
/*isUnsigned=*/false);
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
/*isUnsigned=*/true);
patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
LogicalResult result = success();
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
if (diag.getSeverity() == DiagnosticSeverity::Error) {
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/ExecutionEngine/APFloatWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
llvm::RoundingMode::NearestTiesToEven);
return result.bitcastToAPInt().getZExtValue();
}

MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
uint64_t a,
uint64_t b) {
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
static_cast<llvm::APFloatBase::Semantics>(semantics));
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
return static_cast<int8_t>(x.compare(y));
}
}
15 changes: 15 additions & 0 deletions mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) {
%0 = arith.uitofp %arg0 : i32 to f4E2M1FN
return
}

// -----

// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8
// CHECK: %[[c3:.*]] = arith.constant 3 : i8
// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8
// CHECK: %[[c0:.*]] = arith.constant 0 : i8
// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8
// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1
func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
%0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
return
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func.func @entry() {
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
vector.print %cvt : f8E4M3FN

// CHECK-NEXT: 1
%cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
vector.print %cmp1 : i1

// CHECK-NEXT: 1
// Bit pattern: 01, interpreted as signed integer: 1
%cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2
Expand Down