Skip to content

Commit 6f4a528

Browse files
[mlir][memref] Use dedicated ops in AtomicRMWOpConverter (#66437)
This patch refactors the `AtomicRMWOpConverter` class to use the dedicated operations from the Arith dialect instead of the `cmpf` + `select` pattern. It also adds a test for the `minimumf` kind of `atomic_rmw`.
1 parent ca8cba7 commit 6f4a528

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,37 +33,25 @@ using namespace mlir;
3333
namespace {
3434

3535
/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
36-
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
37-
/// `memref.generic_atomic_rmw` with the expanded code.
36+
/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
37+
/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
38+
/// code.
3839
///
39-
/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
40+
/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
4041
///
4142
/// will be lowered to
4243
///
4344
/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
4445
/// ^bb0(%current: f32):
45-
/// %cmp = arith.cmpf "ogt", %current, %fval : f32
46-
/// %new_value = select %cmp, %current, %fval : f32
47-
/// memref.atomic_yield %new_value : f32
46+
/// %1 = arith.maximumf %current, %fval : f32
47+
/// memref.atomic_yield %1 : f32
4848
/// }
4949
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
5050
public:
5151
using OpRewritePattern::OpRewritePattern;
5252

5353
LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
5454
PatternRewriter &rewriter) const final {
55-
arith::CmpFPredicate predicate;
56-
switch (op.getKind()) {
57-
case arith::AtomicRMWKind::maximumf:
58-
predicate = arith::CmpFPredicate::OGT;
59-
break;
60-
case arith::AtomicRMWKind::minimumf:
61-
predicate = arith::CmpFPredicate::OLT;
62-
break;
63-
default:
64-
return failure();
65-
}
66-
6755
auto loc = op.getLoc();
6856
auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
6957
loc, op.getMemref(), op.getIndices());
@@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
7260

7361
Value lhs = genericOp.getCurrentValue();
7462
Value rhs = op.getValue();
75-
Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
76-
Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
77-
bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
63+
64+
Value arithOp =
65+
mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
66+
bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
7867

7968
rewriter.replaceOp(op, genericOp.getResult());
8069
return success();

mlir/test/Dialect/MemRef/expand-ops.mlir

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@
44
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
55
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
66
%x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
7+
%y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
78
return %x : f32
89
}
9-
// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
10+
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
1011
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
11-
// CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
12-
// CHECK: [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32
13-
// CHECK: memref.atomic_yield [[SELECT]] : f32
12+
// CHECK: [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
13+
// CHECK: memref.atomic_yield [[MAXIMUM]] : f32
1414
// CHECK: }
15-
// CHECK: return %0 : f32
15+
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
16+
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
17+
// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
18+
// CHECK: memref.atomic_yield [[MINIMUM]] : f32
19+
// CHECK: }
20+
// CHECK: return [[RESULT]] : f32
1621

1722
// -----
1823

0 commit comments

Comments
 (0)