-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir] Add maxnumf
and minnumf
to AtomicRMWKind
#66442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir ChangesThis commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.-- 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td index a833e9c8220af5b..133af893e4efa74 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -82,6 +82,8 @@ def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>; def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>; +def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>; +def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>; def AtomicRMWKindAttr : I64EnumAttr< "AtomicRMWKind", "", @@ -89,7 +91,7 @@ def AtomicRMWKindAttr : I64EnumAttr< ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI, - ATOMIC_RMW_KIND_ANDI]> { + ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> { let cppNamespace = "::mlir::arith"; } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d39c5b6051122e4..ae8a6ef350ce191 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return builder.create<arith::MaximumFOp>(loc, lhs, rhs); case AtomicRMWKind::minimumf: return builder.create<arith::MinimumFOp>(loc, lhs, rhs); + case AtomicRMWKind::maxnumf: + return builder.create<arith::MaxNumFOp>(loc, lhs, rhs); + case AtomicRMWKind::minnumf: + return builder.create<arith::MinNumFOp>(loc, lhs, rhs); case AtomicRMWKind::maxs: return builder.create<arith::MaxSIOp>(loc, lhs, rhs); case AtomicRMWKind::mins: diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index b3beaada2539dbc..faba12f5bf82f89 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { namespace memref { @@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> { target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>(); target.addDynamicallyLegalOp<memref::AtomicRMWOp>( [](memref::AtomicRMWOp op) { - return op.getKind() != arith::AtomicRMWKind::maximumf && - op.getKind() != arith::AtomicRMWKind::minimumf; + constexpr std::array shouldBeExpandedKinds = { + arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf, + arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf}; + return !llvm::is_contained(shouldBeExpandedKinds, op.getKind()); }); target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) { return !cast<MemRefType>(op.getShape().getType()).hasStaticShape(); diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir index 6c98cf978505334..f958a92b751a4ab 100644 --- a/mlir/test/Dialect/MemRef/expand-ops.mlir +++ b/mlir/test/Dialect/MemRef/expand-ops.mlir @@ -3,9 +3,11 @@ // CHECK-LABEL: func @atomic_rmw_to_generic // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { - %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 - return %x : f32 + %a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + return %a : f32 } // CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { // CHECK: ^bb0([[CUR_VAL:%.*]]: f32): @@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 // CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32 // CHECK: memref.atomic_yield [[MINIMUM]] : f32 // CHECK: } +// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): +// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32 +// CHECK: memref.atomic_yield [[MAXNUM]] : f32 +// CHECK: } +// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): +// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32 +// CHECK: memref.atomic_yield [[MINNUM]] : f32 +// CHECK: } // CHECK: return [[RESULT]] : f32 // ----- |
dcaballe
approved these changes
Sep 15, 2023
ZijunZhaoCCK
pushed a commit
to ZijunZhaoCCK/llvm-project
that referenced
this pull request
Sep 19, 2023
This commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.
zahiraam
pushed a commit
to tahonermann/llvm-project
that referenced
this pull request
Oct 24, 2023
This commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.
zahiraam
pushed a commit
to tahonermann/llvm-project
that referenced
this pull request
Oct 24, 2023
This commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This commit adds the mentioned kinds of
AtomicRMWKind
as well as code generation for them.