Skip to content

[mlir][memref] Use dedicated ops in AtomicRMWOpConverter #66437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 14, 2023

Conversation

unterumarmung
Copy link
Contributor

This patch refactors the AtomicRMWOpConverter class to use
the dedicated operations from Arith dialect instead of using
cmpf + select pattern.
Also, a test for minimumf kind of atomic_rmw has been added.

This patch refactors the `AtomicRMWOpConverter` class to use
the dedicated operations from Arith dialect instead of using
`cmpf` + `select` pattern.
Also, a test for `minimumf` kind of `atomic_rmw` has been added.
@unterumarmung unterumarmung requested a review from a team September 14, 2023 21:29
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:memref labels Sep 14, 2023
@unterumarmung unterumarmung requested review from a team September 14, 2023 21:29
@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-core

Changes This patch refactors the `AtomicRMWOpConverter` class to use the dedicated operations from Arith dialect instead of using `cmpf` + `select` pattern. Also, a test for `minimumf` kind of `atomic_rmw` has been added.

--
Full diff: https://github.com/llvm/llvm-project/pull/66437.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp (+10-21)
  • (modified) mlir/test/Dialect/MemRef/expand-ops.mlir (+10-5)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 7c3ca19b789c750..b3beaada2539dbc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -33,18 +33,18 @@ using namespace mlir;
 namespace {
 
 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
-/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
-/// `memref.generic_atomic_rmw` with the expanded code.
+/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
+/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
+/// code.
 ///
-/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
 ///
 /// will be lowered to
 ///
 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
 /// ^bb0(%current: f32):
-///   %cmp = arith.cmpf "ogt", %current, %fval : f32
-///   %new_value = select %cmp, %current, %fval : f32
-///   memref.atomic_yield %new_value : f32
+///   %1 = arith.maximumf %current, %fval : f32
+///   memref.atomic_yield %1 : f32
 /// }
 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 public:
@@ -52,18 +52,6 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 
   LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
                                 PatternRewriter &rewriter) const final {
-    arith::CmpFPredicate predicate;
-    switch (op.getKind()) {
-    case arith::AtomicRMWKind::maximumf:
-      predicate = arith::CmpFPredicate::OGT;
-      break;
-    case arith::AtomicRMWKind::minimumf:
-      predicate = arith::CmpFPredicate::OLT;
-      break;
-    default:
-      return failure();
-    }
-
     auto loc = op.getLoc();
     auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
         loc, op.getMemref(), op.getIndices());
@@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
 
     Value lhs = genericOp.getCurrentValue();
     Value rhs = op.getValue();
-    Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
-    Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
-    bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
+
+    Value arithOp =
+        mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
+    bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
 
     rewriter.replaceOp(op, genericOp.getResult());
     return success();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 3234b35e99dcdfe..6c98cf978505334 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -4,15 +4,20 @@
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
   %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
   return %x : f32
 }
-// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
-// CHECK:   [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
-// CHECK:   [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32
-// CHECK:   memref.atomic_yield [[SELECT]] : f32
+// CHECK:   [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MAXIMUM]] : f32
 // CHECK: }
-// CHECK: return %0 : f32
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MINIMUM]] : f32
+// CHECK: }
+// CHECK: return [[RESULT]] : f32
 
 // -----
 

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@unterumarmung unterumarmung merged commit 6f4a528 into llvm:main Sep 14, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
This patch refactors the `AtomicRMWOpConverter` class to use
the dedicated operations from Arith dialect instead of using
`cmpf` + `select` pattern.
Also, a test for `minimumf` kind of `atomic_rmw` has been added.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:memref mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants