Skip to content

[mlir] Add maxnumf and minnumf to AtomicRMWKind #66442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 15, 2023

Conversation

unterumarmung
Copy link
Contributor

This commit adds the mentioned kinds of AtomicRMWKind
as well as code generation for them.

This commit adds the mentioned kinds of `AtomicRMWKind`
as well as code generation for them.
@unterumarmung unterumarmung requested review from a team September 14, 2023 22:32
@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Changes This commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.

--
Full diff: https://github.com/llvm/llvm-project/pull/66442.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithBase.td (+3-1)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp (+5-2)
  • (modified) mlir/test/Dialect/MemRef/expand-ops.mlir (+15-3)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index a833e9c8220af5b..133af893e4efa74 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF     : I64EnumAttrCase<"mulf", 9>;
 def ATOMIC_RMW_KIND_MULI     : I64EnumAttrCase<"muli", 10>;
 def ATOMIC_RMW_KIND_ORI      : I64EnumAttrCase<"ori", 11>;
 def ATOMIC_RMW_KIND_ANDI     : I64EnumAttrCase<"andi", 12>;
+def ATOMIC_RMW_KIND_MAXNUMF  : I64EnumAttrCase<"maxnumf", 13>;
+def ATOMIC_RMW_KIND_MINNUMF  : I64EnumAttrCase<"minnumf", 14>;
 
 def AtomicRMWKindAttr : I64EnumAttr<
     "AtomicRMWKind", "",
@@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr<
      ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
      ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
      ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
-     ATOMIC_RMW_KIND_ANDI]> {
+     ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
   let cppNamespace = "::mlir::arith";
 }
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d39c5b6051122e4..ae8a6ef350ce191 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minimumf:
     return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
+   case AtomicRMWKind::maxnumf:
+    return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
+  case AtomicRMWKind::minnumf:
+    return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::maxs:
     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
   case AtomicRMWKind::mins:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index b3beaada2539dbc..faba12f5bf82f89 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace memref {
@@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
     target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
         [](memref::AtomicRMWOp op) {
-          return op.getKind() != arith::AtomicRMWKind::maximumf &&
-                 op.getKind() != arith::AtomicRMWKind::minimumf;
+          constexpr std::array shouldBeExpandedKinds = {
+              arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
+              arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
+          return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
         });
     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
       return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index 6c98cf978505334..f958a92b751a4ab 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -3,9 +3,11 @@
 // CHECK-LABEL: func @atomic_rmw_to_generic
 // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
 func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
-  %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
-  %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
-  return %x : f32
+  %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+  return %a : f32
 }
 // CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
 // CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
@@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32
 // CHECK:   [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
 // CHECK:   memref.atomic_yield [[MINIMUM]] : f32
 // CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MAXNUM]] : f32
+// CHECK: }
+// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
+// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
+// CHECK:   [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
+// CHECK:   memref.atomic_yield [[MINNUM]] : f32
+// CHECK: }
 // CHECK: return [[RESULT]] : f32
 
 // -----

@unterumarmung unterumarmung merged commit 01e80a0 into llvm:main Sep 15, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
This commit adds the mentioned kinds of `AtomicRMWKind`
as well as code generation for them.
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
This commit adds the mentioned kinds of `AtomicRMWKind`
as well as code generation for them.
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
This commit adds the mentioned kinds of `AtomicRMWKind`
as well as code generation for them.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants