Skip to content

Commit 97977f8

Browse files
[mlir][arith] Add support for cmpf to ArithToAPFloat
x
1 parent 5d38cdd commit 97977f8

File tree

4 files changed

+175
-4
lines changed

4 files changed

+175
-4
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 145 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
4545
static FailureOr<FuncOp>
4646
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
4747
StringRef name, TypeRange paramTypes,
48-
SymbolTableCollection *symbolTables = nullptr) {
49-
auto i64Type = IntegerType::get(symTable->getContext(), 64);
50-
48+
SymbolTableCollection *symbolTables = nullptr,
49+
Type resultType = {}) {
50+
if (!resultType)
51+
resultType = IntegerType::get(symTable->getContext(), 64);
5152
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
52-
auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
53+
auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
5354
FailureOr<FuncOp> func =
5455
lookupFnDecl(symTable, funcName, funcT, symbolTables);
5556
// Failed due to type mismatch.
@@ -308,6 +309,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
308309
bool isUnsigned;
309310
};
310311

312+
struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
313+
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
314+
PatternBenefit benefit = 1)
315+
: OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
316+
317+
LogicalResult matchAndRewrite(arith::CmpFOp op,
318+
PatternRewriter &rewriter) const override {
319+
// Get APFloat function from runtime library.
320+
auto i1Type = IntegerType::get(symTable->getContext(), 1);
321+
auto i8Type = IntegerType::get(symTable->getContext(), 8);
322+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
323+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
324+
FailureOr<FuncOp> fn =
325+
lookupOrCreateApFloatFn(rewriter, symTable, "compare",
326+
{i32Type, i64Type, i64Type}, nullptr, i8Type);
327+
if (failed(fn))
328+
return fn;
329+
330+
// Cast operands to 64-bit integers.
331+
rewriter.setInsertionPoint(op);
332+
Location loc = op.getLoc();
333+
auto floatTy = cast<FloatType>(op.getLhs().getType());
334+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
335+
Value lhsBits = arith::ExtUIOp::create(
336+
rewriter, loc, i64Type,
337+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
338+
Value rhsBits = arith::ExtUIOp::create(
339+
rewriter, loc, i64Type,
340+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
341+
342+
// Call APFloat function.
343+
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
344+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
345+
Value comparisonResult =
346+
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
347+
SymbolRefAttr::get(*fn), params)
348+
->getResult(0);
349+
350+
// Generate an i1 SSA value that is "true" if the comparison result matches
351+
// the given `val`.
352+
auto checkValue = [&](llvm::APFloat::cmpResult val) {
353+
return arith::CmpIOp::create(
354+
rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
355+
arith::ConstantOp::create(
356+
rewriter, loc, i8Type,
357+
rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
358+
.getResult());
359+
};
360+
// Generate an i1 SSA value that is "true" if the comparison result matches
361+
// any of the given `vals`.
362+
std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkValues =
363+
[&](ArrayRef<llvm::APFloat::cmpResult> vals) {
364+
Value first = checkValue(vals.front());
365+
if (vals.size() == 1)
366+
return first;
367+
Value rest = checkValues(vals.drop_front());
368+
return arith::OrIOp::create(rewriter, loc, first, rest).getResult();
369+
};
370+
371+
// This switch-case statement was taken from arith::applyCmpPredicate.
372+
Value result;
373+
switch (op.getPredicate()) {
374+
case arith::CmpFPredicate::AlwaysFalse:
375+
result = arith::ConstantOp::create(rewriter, loc, i1Type,
376+
rewriter.getIntegerAttr(i1Type, 0))
377+
.getResult();
378+
break;
379+
case arith::CmpFPredicate::OEQ:
380+
result = checkValue(llvm::APFloat::cmpEqual);
381+
break;
382+
case arith::CmpFPredicate::OGT:
383+
result = checkValue(llvm::APFloat::cmpGreaterThan);
384+
break;
385+
case arith::CmpFPredicate::OGE:
386+
result =
387+
checkValues({llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
388+
break;
389+
case arith::CmpFPredicate::OLT:
390+
result = checkValue(llvm::APFloat::cmpLessThan);
391+
break;
392+
case arith::CmpFPredicate::OLE:
393+
result =
394+
checkValues({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
395+
break;
396+
case arith::CmpFPredicate::ONE:
397+
// Not cmpUnordered and not cmpEqual.
398+
result = checkValues(
399+
{llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
400+
break;
401+
case arith::CmpFPredicate::ORD:
402+
// Not cmpUnordered.
403+
result =
404+
checkValues({llvm::APFloat::cmpLessThan,
405+
llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
406+
break;
407+
case arith::CmpFPredicate::UEQ:
408+
result =
409+
checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
410+
break;
411+
case arith::CmpFPredicate::UGT:
412+
result = checkValues(
413+
{llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
414+
break;
415+
case arith::CmpFPredicate::UGE:
416+
result =
417+
checkValues({llvm::APFloat::cmpUnordered,
418+
llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
419+
break;
420+
case arith::CmpFPredicate::ULT:
421+
result = checkValues(
422+
{llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
423+
break;
424+
case arith::CmpFPredicate::ULE:
425+
result =
426+
checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
427+
llvm::APFloat::cmpEqual});
428+
break;
429+
case arith::CmpFPredicate::UNE:
430+
// Not cmpEqual.
431+
result = checkValues({llvm::APFloat::cmpLessThan,
432+
llvm::APFloat::cmpGreaterThan,
433+
llvm::APFloat::cmpUnordered});
434+
break;
435+
case arith::CmpFPredicate::UNO:
436+
result = checkValue(llvm::APFloat::cmpUnordered);
437+
break;
438+
case arith::CmpFPredicate::AlwaysTrue:
439+
result = arith::ConstantOp::create(rewriter, loc, i1Type,
440+
rewriter.getIntegerAttr(i1Type, 1))
441+
.getResult();
442+
break;
443+
}
444+
rewriter.replaceOp(op, result);
445+
return success();
446+
}
447+
448+
SymbolOpInterface symTable;
449+
};
450+
311451
namespace {
312452
struct ArithToAPFloatConversionPass final
313453
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -340,6 +480,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
340480
/*isUnsigned=*/false);
341481
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
342482
/*isUnsigned=*/true);
483+
patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
343484
LogicalResult result = success();
344485
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
345486
if (diag.getSeverity() == DiagnosticSeverity::Error) {

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
131131
llvm::RoundingMode::NearestTiesToEven);
132132
return result.bitcastToAPInt().getZExtValue();
133133
}
134+
135+
MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
136+
uint64_t a,
137+
uint64_t b) {
138+
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
139+
static_cast<llvm::APFloatBase::Semantics>(semantics));
140+
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
141+
llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
142+
llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
143+
return static_cast<int8_t>(x.compare(y));
144+
}
134145
}

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) {
198198
%0 = arith.uitofp %arg0 : i32 to f4E2M1FN
199199
return
200200
}
201+
202+
// -----
203+
204+
// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8
205+
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
206+
// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8
207+
// CHECK: %[[c3:.*]] = arith.constant 3 : i8
208+
// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8
209+
// CHECK: %[[c0:.*]] = arith.constant 0 : i8
210+
// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8
211+
// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1
212+
func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
213+
%0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
214+
return
215+
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ func.func @entry() {
4343
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
4444
vector.print %cvt : f8E4M3FN
4545

46+
// CHECK-NEXT: 1
47+
%cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
48+
vector.print %cmp1 : i1
49+
4650
// CHECK-NEXT: 1
4751
// Bit pattern: 01, interpreted as signed integer: 1
4852
%cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2

0 commit comments

Comments
 (0)