diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 81a5398dabcb7..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
     (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
   }
 };
+
+struct SoftmaxOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+                                                     linalg::SoftmaxOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Output operand is not read.
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    return &opOperand == &softmaxOp.getInputMutable();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    FailureOr<Value> inputBuffer =
+        getBuffer(rewriter, softmaxOp.getInput(), options);
+    if (failed(inputBuffer))
+      return failure();
+    FailureOr<Value> outputBuffer =
+        getBuffer(rewriter, softmaxOp.getOutput(), options);
+    if (failed(outputBuffer))
+      return failure();
+    rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
+                                       /*result=*/TypeRange(), *inputBuffer,
+                                       *outputBuffer, softmaxOp.getDimension());
+    replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
         >::registerOpInterface(ctx);
+
+    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..f416cd9fcf0b2 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
   // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
   // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_softmax(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+//       CHECK:   %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc()
+//   CHECK-NOT:   memref.copy
+//       CHECK:   linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc:.*]] : {{.*}})
+//       CHECK:   %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+//       CHECK:   return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<2x16x32xf32>)
+      outs(%arg1 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
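
For reference, the bufferized IR the new test expects looks roughly as follows. This is a sketch reconstructed from the CHECK lines above, not verbatim pass output; the SSA value names and the printed memref types are illustrative:

    func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
      // The input is converted to a memref. Because bufferizesToMemoryRead
      // reports that the output operand is not read, bufferization can write
      // into a fresh allocation without first copying %arg1 into it (hence
      // the CHECK-NOT: memref.copy in the test).
      %m0 = bufferization.to_memref %arg0 : memref<2x16x32xf32>
      %alloc = memref.alloc() : memref<2x16x32xf32>
      // The memref form of linalg.softmax has no results (TypeRange() in the
      // bufferize() method above), so no arrow type appears here.
      linalg.softmax dimension(2)
          ins(%m0 : memref<2x16x32xf32>)
          outs(%alloc : memref<2x16x32xf32>)
      %result = bufferization.to_tensor %alloc : memref<2x16x32xf32>
      return %result : tensor<2x16x32xf32>
    }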