[mlir][linalg] Add bufferization for linalg.softmax
#97019
Conversation
Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch all" `LinalgOp` interface implementation.
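For context, a rough sketch of the rewrite this enables, using the shapes from the test added below (the exact syntax of the materialization casts may differ slightly between versions):

// Tensor form (before bufferization).
%res = linalg.softmax dimension(2)
    ins(%arg0 : tensor<2x16x32xf32>)
    outs(%arg1 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>

// Buffer form (after bufferization): the operands are replaced by memrefs,
// the op is recreated without results, and the output buffer is wrapped
// back into a tensor to replace the original result.
%m0 = bufferization.to_memref %arg0 : memref<2x16x32xf32>
%alloc = memref.alloc() : memref<2x16x32xf32>
linalg.softmax dimension(2)
    ins(%m0 : memref<2x16x32xf32>)
    outs(%alloc : memref<2x16x32xf32>)
%res = bufferization.to_tensor %alloc : memref<2x16x32xf32>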
@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes: Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch all" `LinalgOp` interface implementation.

Full diff: https://github.com/llvm/llvm-project/pull/97019.diff

2 Files Affected:
- mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+31)
- mlir/test/Dialect/Linalg/bufferize.mlir (+17)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 81a5398dabcb7..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
(Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
}
};
+
+struct SoftmaxOpInterface
+ : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+ linalg::SoftmaxOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ // Output operand is not read.
+ auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+ return &opOperand == &softmaxOp.getInputMutable();
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+ FailureOr<Value> inputBuffer =
+ getBuffer(rewriter, softmaxOp.getInput(), options);
+ if (failed(inputBuffer))
+ return failure();
+ FailureOr<Value> outputBuffer =
+ getBuffer(rewriter, softmaxOp.getOutput(), options);
+ if (failed(outputBuffer))
+ return failure();
+ rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
+ /*result=*/TypeRange(), *inputBuffer,
+ *outputBuffer, softmaxOp.getDimension());
+ replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
+ return success();
+ }
+};
} // namespace
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::registerOpInterface(ctx);
+
+ SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
});
}
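As a side note, these external models only take effect once they are registered with a dialect registry. A minimal, hypothetical C++ sketch of how a tool or pass pipeline might do that, using the registration function extended above (the surrounding setup code is illustrative and not part of this patch):

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

// Hypothetical tool setup: attach the linalg BufferizableOpInterface
// external models (now including SoftmaxOpInterface) to the context.
void setupContext(mlir::MLIRContext &context) {
  mlir::DialectRegistry registry;
  mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry);
  context.appendDialectRegistry(registry);
}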
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..f416cd9fcf0b2 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
// CHECK: return %[[OUT_TENSOR]]
}
+
+// -----
+
+// CHECK-LABEL: func @bufferize_softmax(
+// CHECK-SAME: %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+// CHECK: %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+// CHECK: %[[alloc:.*]] = memref.alloc()
+// CHECK-NOT: memref.copy
+// CHECK: linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc:.*]] : {{.*}})
+// CHECK: %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK: return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+ %1 = linalg.softmax dimension(2)
+ ins(%arg0 : tensor<2x16x32xf32>)
+ outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+ return %1 : tensor<2x16x32xf32>
+}
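For reference, the new test case follows the existing ones in this file and is presumably exercised by the file's RUN line (not shown in this diff), something along the lines of `mlir-opt -one-shot-bufferize -split-input-file bufferize.mlir | FileCheck bufferize.mlir`; the exact pass options are an assumption here, so check the RUN line at the top of the file. The `CHECK-NOT: memref.copy` line verifies that no copy into the newly allocated output buffer is inserted, which follows from `bufferizesToMemoryRead` reporting that the output operand is not read.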
@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes and full diff: same as the @llvm/pr-subscribers-mlir comment above.
LGTM