From 396f4f987122ea0574e69c41caca4fff842edd8d Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 27 Nov 2025 01:14:16 +0000 Subject: [PATCH] [mlir][arith] Add support for `cmpf` to `ArithToAPFloat` x --- .../ArithToAPFloat/ArithToAPFloat.cpp | 152 +++++++++++++++++- mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 11 ++ .../ArithToApfloat/arith-to-apfloat.mlir | 15 ++ .../Arith/CPU/test-apfloat-emulation.mlir | 4 + 4 files changed, 177 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 81fbdb1611deb..566632bd8707f 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -41,15 +41,17 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, } /// Helper function to look up or create the symbol for a runtime library -/// function with the given parameter types. Always returns an int64_t. +/// function with the given parameter types. Returns an int64_t, unless a +/// different result type is specified. static FailureOr lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, - SymbolTableCollection *symbolTables = nullptr) { - auto i64Type = IntegerType::get(symTable->getContext(), 64); - + SymbolTableCollection *symbolTables = nullptr, + Type resultType = {}) { + if (!resultType) + resultType = IntegerType::get(symTable->getContext(), 64); std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); - auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type}); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); FailureOr func = lookupFnDecl(symTable, funcName, funcT, symbolTables); // Failed due to type mismatch. 
@@ -308,6 +310,145 @@ struct IntToFpConversion final : OpRewritePattern { bool isUnsigned; }; +struct CmpFOpToAPFloatConversion final : OpRewritePattern { + CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::CmpFOp op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i8Type = IntegerType::get(symTable->getContext(), 8); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr fn = + lookupOrCreateApFloatFn(rewriter, symTable, "compare", + {i32Type, i64Type, i64Type}, nullptr, i8Type); + if (failed(fn)) + return fn; + + // Bitcast operands to integers and zero-extend them to 64 bits. + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + auto floatTy = cast(op.getLhs().getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs())); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector params = {semValue, lhsBits, rhsBits}; + Value comparisonResult = + func::CallOp::create(rewriter, loc, TypeRange(i8Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Generate an i1 SSA value that is "true" if the comparison result matches + // the given `val`. 
+ auto checkResult = [&](llvm::APFloat::cmpResult val) { + return arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, comparisonResult, + arith::ConstantOp::create( + rewriter, loc, i8Type, + rewriter.getIntegerAttr(i8Type, static_cast(val))) + .getResult()); + }; + // Generate an i1 SSA value that is "true" if the comparison result matches + // any of the given `vals`. + std::function)> checkResults = + [&](ArrayRef vals) { + Value first = checkResult(vals.front()); + if (vals.size() == 1) + return first; + Value rest = checkResults(vals.drop_front()); + return arith::OrIOp::create(rewriter, loc, first, rest).getResult(); + }; + + // This switch-case statement was taken from arith::applyCmpPredicate. + Value result; + switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: + result = arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 0)) + .getResult(); + break; + case arith::CmpFPredicate::OEQ: + result = checkResult(llvm::APFloat::cmpEqual); + break; + case arith::CmpFPredicate::OGT: + result = checkResult(llvm::APFloat::cmpGreaterThan); + break; + case arith::CmpFPredicate::OGE: + result = checkResults( + {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::OLT: + result = checkResult(llvm::APFloat::cmpLessThan); + break; + case arith::CmpFPredicate::OLE: + result = + checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ONE: + // Not cmpUnordered and not cmpEqual. + result = checkResults( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::ORD: + // Not cmpUnordered. 
+ result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UEQ: + result = + checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UGT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::UGE: + result = checkResults({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ULT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan}); + break; + case arith::CmpFPredicate::ULE: + result = + checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UNE: + // Not cmpEqual. + result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpUnordered}); + break; + case arith::CmpFPredicate::UNO: + result = checkResult(llvm::APFloat::cmpUnordered); + break; + case arith::CmpFPredicate::AlwaysTrue: + result = arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 1)) + .getResult(); + break; + } + rewriter.replaceOp(op, result); + return success(); + } + + SymbolOpInterface symTable; +}; + namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase { @@ -340,6 +481,7 @@ void ArithToAPFloatConversionPass::runOnOperation() { /*isUnsigned=*/false); patterns.add>(context, getOperation(), /*isUnsigned=*/true); + patterns.add(context, getOperation()); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 44980ccd77491..77f7137264888 100644 --- 
a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int( llvm::RoundingMode::NearestTiesToEven); return result.bitcastToAPInt().getZExtValue(); } + +MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics, + uint64_t a, + uint64_t b) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + llvm::APFloat y(sem, llvm::APInt(bitWidth, b)); + return static_cast(x.compare(y)); +} } diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index d71d81dddcd4f..78ce3640ecc67 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) { %0 = arith.uitofp %arg0 : i32 to f4E2M1FN return } + +// ----- + +// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8 +// CHECK: %[[c3:.*]] = arith.constant 3 : i8 +// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8 +// CHECK: %[[c0:.*]] = arith.constant 0 : i8 +// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8 +// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1 +func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index 8046610d479a8..433d058d025cf 100644 --- 
a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -43,6 +43,10 @@ func.func @entry() { %cvt = arith.truncf %b2 : f32 to f8E4M3FN vector.print %cvt : f8E4M3FN + // CHECK-NEXT: 1 + %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN + vector.print %cmp1 : i1 + // CHECK-NEXT: 1 // Bit pattern: 01, interpreted as signed integer: 1 %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2