Skip to content

Conversation

@matthias-springer
Copy link
Member

Add support for arith.cmpf.

@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2025

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-execution-engine

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add support for arith.cmpf.


Full diff: https://github.com/llvm/llvm-project/pull/169753.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+145-4)
  • (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+11)
  • (modified) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+15)
  • (modified) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+4)
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 81fbdb1611deb..20b99e4a6e516 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -45,11 +45,12 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
 static FailureOr<FuncOp>
 lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
                         StringRef name, TypeRange paramTypes,
-                        SymbolTableCollection *symbolTables = nullptr) {
-  auto i64Type = IntegerType::get(symTable->getContext(), 64);
-
+                        SymbolTableCollection *symbolTables = nullptr,
+                        Type resultType = {}) {
+  if (!resultType)
+    resultType = IntegerType::get(symTable->getContext(), 64);
   std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
-  auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
+  auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
   FailureOr<FuncOp> func =
       lookupFnDecl(symTable, funcName, funcT, symbolTables);
   // Failed due to type mismatch.
@@ -308,6 +309,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
   bool isUnsigned;
 };
 
+struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
+  CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+                            PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+
+  LogicalResult matchAndRewrite(arith::CmpFOp op,
+                                PatternRewriter &rewriter) const override {
+    // Get APFloat function from runtime library.
+    auto i1Type = IntegerType::get(symTable->getContext(), 1);
+    auto i8Type = IntegerType::get(symTable->getContext(), 8);
+    auto i32Type = IntegerType::get(symTable->getContext(), 32);
+    auto i64Type = IntegerType::get(symTable->getContext(), 64);
+    FailureOr<FuncOp> fn =
+        lookupOrCreateApFloatFn(rewriter, symTable, "compare",
+                                {i32Type, i64Type, i64Type}, nullptr, i8Type);
+    if (failed(fn))
+      return fn;
+
+    // Cast operands to 64-bit integers.
+    rewriter.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto floatTy = cast<FloatType>(op.getLhs().getType());
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    Value lhsBits = arith::ExtUIOp::create(
+        rewriter, loc, i64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
+    Value rhsBits = arith::ExtUIOp::create(
+        rewriter, loc, i64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+
+    // Call APFloat function.
+    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+    Value comparisonResult =
+        func::CallOp::create(rewriter, loc, TypeRange(i8Type),
+                             SymbolRefAttr::get(*fn), params)
+            ->getResult(0);
+
+    // Generate an i1 SSA value that is "true" if the comparison result matches
+    // the given `val`.
+    auto checkValue = [&](llvm::APFloat::cmpResult val) {
+      return arith::CmpIOp::create(
+          rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
+          arith::ConstantOp::create(
+              rewriter, loc, i8Type,
+              rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
+              .getResult());
+    };
+    // Generate an i1 SSA value that is "true" if the comparison result matches
+    // any of the given `vals`.
+    std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkValues =
+        [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
+          Value first = checkValue(vals.front());
+          if (vals.size() == 1)
+            return first;
+          Value rest = checkValues(vals.drop_front());
+          return arith::OrIOp::create(rewriter, loc, first, rest).getResult();
+        };
+
+    // This switch-case statement was taken from arith::applyCmpPredicate.
+    Value result;
+    switch (op.getPredicate()) {
+    case arith::CmpFPredicate::AlwaysFalse:
+      result = arith::ConstantOp::create(rewriter, loc, i1Type,
+                                         rewriter.getIntegerAttr(i1Type, 0))
+                   .getResult();
+      break;
+    case arith::CmpFPredicate::OEQ:
+      result = checkValue(llvm::APFloat::cmpEqual);
+      break;
+    case arith::CmpFPredicate::OGT:
+      result = checkValue(llvm::APFloat::cmpGreaterThan);
+      break;
+    case arith::CmpFPredicate::OGE:
+      result =
+          checkValues({llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::OLT:
+      result = checkValue(llvm::APFloat::cmpLessThan);
+      break;
+    case arith::CmpFPredicate::OLE:
+      result =
+          checkValues({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::ONE:
+      // Not cmpUnordered and not cmpUnordered.
+      result = checkValues(
+          {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
+      break;
+    case arith::CmpFPredicate::ORD:
+      // Not cmpUnordered.
+      result =
+          checkValues({llvm::APFloat::cmpLessThan,
+                       llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UEQ:
+      result =
+          checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UGT:
+      result = checkValues(
+          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
+      break;
+    case arith::CmpFPredicate::UGE:
+      result =
+          checkValues({llvm::APFloat::cmpUnordered,
+                       llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::ULT:
+      result = checkValues(
+          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
+      break;
+    case arith::CmpFPredicate::ULE:
+      result =
+          checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
+                       llvm::APFloat::cmpEqual});
+      break;
+    case arith::CmpFPredicate::UNE:
+      // Not cmpEqual.
+      result = checkValues({llvm::APFloat::cmpLessThan,
+                            llvm::APFloat::cmpGreaterThan,
+                            llvm::APFloat::cmpUnordered});
+      break;
+    case arith::CmpFPredicate::UNO:
+      result = checkValue(llvm::APFloat::cmpUnordered);
+      break;
+    case arith::CmpFPredicate::AlwaysTrue:
+      result = arith::ConstantOp::create(rewriter, loc, i1Type,
+                                         rewriter.getIntegerAttr(i1Type, 1))
+                   .getResult();
+      break;
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  SymbolOpInterface symTable;
+};
+
 namespace {
 struct ArithToAPFloatConversionPass final
     : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -340,6 +480,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
                                                    /*isUnsigned=*/false);
   patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
                                                    /*isUnsigned=*/true);
+  patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
   LogicalResult result = success();
   ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
     if (diag.getSeverity() == DiagnosticSeverity::Error) {
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 44980ccd77491..77f7137264888 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
                           llvm::RoundingMode::NearestTiesToEven);
   return result.bitcastToAPInt().getZExtValue();
 }
+
+MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
+                                                          uint64_t a,
+                                                          uint64_t b) {
+  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
+      static_cast<llvm::APFloatBase::Semantics>(semantics));
+  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
+  llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
+  llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
+  return static_cast<int8_t>(x.compare(y));
+}
 }
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index d71d81dddcd4f..78ce3640ecc67 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) {
   %0 = arith.uitofp %arg0 : i32 to f4E2M1FN
   return
 }
+
+// -----
+
+// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8
+// CHECK: %[[sem:.*]] = arith.constant 18 : i32
+// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8
+// CHECK: %[[c3:.*]] = arith.constant 3 : i8
+// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8
+// CHECK: %[[c0:.*]] = arith.constant 0 : i8
+// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8
+// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1
+func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
+  %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 8046610d479a8..433d058d025cf 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -43,6 +43,10 @@ func.func @entry() {
   %cvt = arith.truncf %b2 : f32 to f8E4M3FN
   vector.print %cvt : f8E4M3FN
 
+  // CHECK-NEXT: 1
+  %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
+  vector.print %cmp1 : i1
+
   // CHECK-NEXT: 1
   // Bit pattern: 01, interpreted as signed integer: 1
   %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2

@matthias-springer matthias-springer force-pushed the users/matthias-springer/apfloat_cmpf branch from 97977f8 to 98853ee Compare November 27, 2025 01:16
@matthias-springer matthias-springer force-pushed the users/matthias-springer/apfloat_cmpf branch from 98853ee to 396f4f9 Compare November 27, 2025 01:19
@matthias-springer matthias-springer merged commit 4d7abe5 into main Dec 1, 2025
10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/apfloat_cmpf branch December 1, 2025 08:12
aahrun pushed a commit to aahrun/llvm-project that referenced this pull request Dec 1, 2025
augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Dec 3, 2025
kcloudy0717 pushed a commit to kcloudy0717/llvm-project that referenced this pull request Dec 4, 2025
honeygoyal pushed a commit to honeygoyal/llvm-project that referenced this pull request Dec 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants