@@ -45,11 +45,12 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
4545static FailureOr<FuncOp>
4646lookupOrCreateApFloatFn (OpBuilder &b, SymbolOpInterface symTable,
4747 StringRef name, TypeRange paramTypes,
48- SymbolTableCollection *symbolTables = nullptr ) {
49- auto i64Type = IntegerType::get (symTable->getContext (), 64 );
50-
48+ SymbolTableCollection *symbolTables = nullptr ,
49+ Type resultType = {}) {
50+ if (!resultType)
51+ resultType = IntegerType::get (symTable->getContext (), 64 );
5152 std::string funcName = (llvm::Twine (" _mlir_apfloat_" ) + name).str ();
52- auto funcT = FunctionType::get (b.getContext (), paramTypes, {i64Type });
53+ auto funcT = FunctionType::get (b.getContext (), paramTypes, {resultType });
5354 FailureOr<FuncOp> func =
5455 lookupFnDecl (symTable, funcName, funcT, symbolTables);
5556 // Failed due to type mismatch.
@@ -308,6 +309,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
308309 bool isUnsigned;
309310};
310311
312+ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
313+ CmpFOpToAPFloatConversion (MLIRContext *context, SymbolOpInterface symTable,
314+ PatternBenefit benefit = 1 )
315+ : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
316+
317+ LogicalResult matchAndRewrite (arith::CmpFOp op,
318+ PatternRewriter &rewriter) const override {
319+ // Get APFloat function from runtime library.
320+ auto i1Type = IntegerType::get (symTable->getContext (), 1 );
321+ auto i8Type = IntegerType::get (symTable->getContext (), 8 );
322+ auto i32Type = IntegerType::get (symTable->getContext (), 32 );
323+ auto i64Type = IntegerType::get (symTable->getContext (), 64 );
324+ FailureOr<FuncOp> fn =
325+ lookupOrCreateApFloatFn (rewriter, symTable, " compare" ,
326+ {i32Type, i64Type, i64Type}, nullptr , i8Type);
327+ if (failed (fn))
328+ return fn;
329+
330+ // Cast operands to 64-bit integers.
331+ rewriter.setInsertionPoint (op);
332+ Location loc = op.getLoc ();
333+ auto floatTy = cast<FloatType>(op.getLhs ().getType ());
334+ auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
335+ Value lhsBits = arith::ExtUIOp::create (
336+ rewriter, loc, i64Type,
337+ arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
338+ Value rhsBits = arith::ExtUIOp::create (
339+ rewriter, loc, i64Type,
340+ arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
341+
342+ // Call APFloat function.
343+ Value semValue = getSemanticsValue (rewriter, loc, floatTy);
344+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
345+ Value comparisonResult =
346+ func::CallOp::create (rewriter, loc, TypeRange (i8Type),
347+ SymbolRefAttr::get (*fn), params)
348+ ->getResult (0 );
349+
350+ // Generate an i1 SSA value that is "true" if the comparison result matches
351+ // the given `val`.
352+ auto checkValue = [&](llvm::APFloat::cmpResult val) {
353+ return arith::CmpIOp::create (
354+ rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
355+ arith::ConstantOp::create (
356+ rewriter, loc, i8Type,
357+ rewriter.getIntegerAttr (i8Type, static_cast <int8_t >(val)))
358+ .getResult ());
359+ };
360+ // Generate an i1 SSA value that is "true" if the comparison result matches
361+ // any of the given `vals`.
362+ std::function<Value (ArrayRef<llvm::APFloat::cmpResult>)> checkValues =
363+ [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
364+ Value first = checkValue (vals.front ());
365+ if (vals.size () == 1 )
366+ return first;
367+ Value rest = checkValues (vals.drop_front ());
368+ return arith::OrIOp::create (rewriter, loc, first, rest).getResult ();
369+ };
370+
371+ // This switch-case statement was taken from arith::applyCmpPredicate.
372+ Value result;
373+ switch (op.getPredicate ()) {
374+ case arith::CmpFPredicate::AlwaysFalse:
375+ result = arith::ConstantOp::create (rewriter, loc, i1Type,
376+ rewriter.getIntegerAttr (i1Type, 0 ))
377+ .getResult ();
378+ break ;
379+ case arith::CmpFPredicate::OEQ:
380+ result = checkValue (llvm::APFloat::cmpEqual);
381+ break ;
382+ case arith::CmpFPredicate::OGT:
383+ result = checkValue (llvm::APFloat::cmpGreaterThan);
384+ break ;
385+ case arith::CmpFPredicate::OGE:
386+ result =
387+ checkValues ({llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
388+ break ;
389+ case arith::CmpFPredicate::OLT:
390+ result = checkValue (llvm::APFloat::cmpLessThan);
391+ break ;
392+ case arith::CmpFPredicate::OLE:
393+ result =
394+ checkValues ({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
395+ break ;
396+ case arith::CmpFPredicate::ONE:
397+ // Not cmpUnordered and not cmpUnordered.
398+ result = checkValues (
399+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
400+ break ;
401+ case arith::CmpFPredicate::ORD:
402+ // Not cmpUnordered.
403+ result =
404+ checkValues ({llvm::APFloat::cmpLessThan,
405+ llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
406+ break ;
407+ case arith::CmpFPredicate::UEQ:
408+ result =
409+ checkValues ({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
410+ break ;
411+ case arith::CmpFPredicate::UGT:
412+ result = checkValues (
413+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
414+ break ;
415+ case arith::CmpFPredicate::UGE:
416+ result =
417+ checkValues ({llvm::APFloat::cmpUnordered,
418+ llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
419+ break ;
420+ case arith::CmpFPredicate::ULT:
421+ result = checkValues (
422+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
423+ break ;
424+ case arith::CmpFPredicate::ULE:
425+ result =
426+ checkValues ({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
427+ llvm::APFloat::cmpEqual});
428+ break ;
429+ case arith::CmpFPredicate::UNE:
430+ // Not cmpEqual.
431+ result = checkValues ({llvm::APFloat::cmpLessThan,
432+ llvm::APFloat::cmpGreaterThan,
433+ llvm::APFloat::cmpUnordered});
434+ break ;
435+ case arith::CmpFPredicate::UNO:
436+ result = checkValue (llvm::APFloat::cmpUnordered);
437+ break ;
438+ case arith::CmpFPredicate::AlwaysTrue:
439+ result = arith::ConstantOp::create (rewriter, loc, i1Type,
440+ rewriter.getIntegerAttr (i1Type, 1 ))
441+ .getResult ();
442+ break ;
443+ }
444+ rewriter.replaceOp (op, result);
445+ return success ();
446+ }
447+
448+ SymbolOpInterface symTable;
449+ };
450+
311451namespace {
312452struct ArithToAPFloatConversionPass final
313453 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -340,6 +480,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
340480 /* isUnsigned=*/ false );
341481 patterns.add <IntToFpConversion<arith::UIToFPOp>>(context, getOperation (),
342482 /* isUnsigned=*/ true );
483+ patterns.add <CmpFOpToAPFloatConversion>(context, getOperation ());
343484 LogicalResult result = success ();
344485 ScopedDiagnosticHandler scopedHandler (context, [&result](Diagnostic &diag) {
345486 if (diag.getSeverity () == DiagnosticSeverity::Error) {
0 commit comments