@@ -33,37 +33,25 @@ using namespace mlir;
33
33
namespace {
34
34
35
35
// / Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
36
- // / AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
37
- // / `memref.generic_atomic_rmw` with the expanded code.
36
+ // / AtomicRMWOpLowering pattern, such as minimum and maximum operations for
37
+ // / floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
38
+ // / code.
38
39
// /
39
- // / %x = atomic_rmw " maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
40
+ // / %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
40
41
// /
41
42
// / will be lowered to
42
43
// /
43
44
// / %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
44
45
// / ^bb0(%current: f32):
45
- // / %cmp = arith.cmpf "ogt", %current, %fval : f32
46
- // / %new_value = select %cmp, %current, %fval : f32
47
- // / memref.atomic_yield %new_value : f32
46
+ // / %1 = arith.maximumf %current, %fval : f32
47
+ // / memref.atomic_yield %1 : f32
48
48
// / }
49
49
struct AtomicRMWOpConverter : public OpRewritePattern <memref::AtomicRMWOp> {
50
50
public:
51
51
using OpRewritePattern::OpRewritePattern;
52
52
53
53
LogicalResult matchAndRewrite (memref::AtomicRMWOp op,
54
54
PatternRewriter &rewriter) const final {
55
- arith::CmpFPredicate predicate;
56
- switch (op.getKind ()) {
57
- case arith::AtomicRMWKind::maximumf:
58
- predicate = arith::CmpFPredicate::OGT;
59
- break ;
60
- case arith::AtomicRMWKind::minimumf:
61
- predicate = arith::CmpFPredicate::OLT;
62
- break ;
63
- default :
64
- return failure ();
65
- }
66
-
67
55
auto loc = op.getLoc ();
68
56
auto genericOp = rewriter.create <memref::GenericAtomicRMWOp>(
69
57
loc, op.getMemref (), op.getIndices ());
@@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
72
60
73
61
Value lhs = genericOp.getCurrentValue ();
74
62
Value rhs = op.getValue ();
75
- Value cmp = bodyBuilder.create <arith::CmpFOp>(loc, predicate, lhs, rhs);
76
- Value select = bodyBuilder.create <arith::SelectOp>(loc, cmp, lhs, rhs);
77
- bodyBuilder.create <memref::AtomicYieldOp>(loc, select );
63
+
64
+ Value arithOp =
65
+ mlir::arith::getReductionOp (op.getKind (), bodyBuilder, loc, lhs, rhs);
66
+ bodyBuilder.create <memref::AtomicYieldOp>(loc, arithOp);
78
67
79
68
rewriter.replaceOp (op, genericOp.getResult ());
80
69
return success ();
0 commit comments