diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index 7c3ca19b789c7..b3beaada2539d 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -33,18 +33,18 @@ using namespace mlir; namespace { /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with -/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to -/// `memref.generic_atomic_rmw` with the expanded code. +/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for +/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded +/// code. /// -/// %x = atomic_rmw "maximumf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 +/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32 /// /// will be lowered to /// /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> { /// ^bb0(%current: f32): -/// %cmp = arith.cmpf "ogt", %current, %fval : f32 -/// %new_value = select %cmp, %current, %fval : f32 -/// memref.atomic_yield %new_value : f32 +/// %1 = arith.maximumf %current, %fval : f32 +/// memref.atomic_yield %1 : f32 /// } struct AtomicRMWOpConverter : public OpRewritePattern { public: @@ -52,18 +52,6 @@ struct AtomicRMWOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(memref::AtomicRMWOp op, PatternRewriter &rewriter) const final { - arith::CmpFPredicate predicate; - switch (op.getKind()) { - case arith::AtomicRMWKind::maximumf: - predicate = arith::CmpFPredicate::OGT; - break; - case arith::AtomicRMWKind::minimumf: - predicate = arith::CmpFPredicate::OLT; - break; - default: - return failure(); - } - auto loc = op.getLoc(); auto genericOp = rewriter.create( loc, op.getMemref(), op.getIndices()); @@ -72,9 +60,10 @@ struct AtomicRMWOpConverter : public OpRewritePattern { Value lhs = genericOp.getCurrentValue(); Value rhs = op.getValue(); - Value cmp = bodyBuilder.create(loc, predicate, lhs, rhs); - Value select = bodyBuilder.create(loc, cmp, lhs, rhs); - bodyBuilder.create(loc, select); + + Value arithOp = + mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs); + bodyBuilder.create(loc, arithOp); rewriter.replaceOp(op, genericOp.getResult()); return success(); diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir index 3234b35e99dcd..6c98cf9785053 100644 --- a/mlir/test/Dialect/MemRef/expand-ops.mlir +++ b/mlir/test/Dialect/MemRef/expand-ops.mlir @@ -4,15 +4,20 @@ // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { %x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 + %y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32 return %x : f32 } -// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { // CHECK: ^bb0([[CUR_VAL:%.*]]: f32): -// CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32 -// CHECK: [[SELECT:%.*]] = arith.select [[CMP]], [[CUR_VAL]], [[f]] : f32 -// CHECK: memref.atomic_yield [[SELECT]] : f32 +// CHECK: [[MAXIMUM:%.*]] = arith.maximumf [[CUR_VAL]], [[f]] : f32 +// CHECK: memref.atomic_yield [[MAXIMUM]] : f32 // CHECK: } -// CHECK: return %0 : f32 +// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> { +// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): +// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32 +// CHECK: memref.atomic_yield [[MINIMUM]] : f32 +// CHECK: } +// CHECK: return [[RESULT]] : f32 // -----