Skip to content

Commit 11705af

Browse files
authored
[mlir][sparse] deallocate tmp coo buffer generated during stage-spars… (#82017)
…e-ops pass.
1 parent 164055f commit 11705af

File tree

5 files changed

+35
-14
lines changed

5 files changed

+35
-14
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//===- SparseTensorInterfaces.h - sparse tensor operations
2-
//interfaces-------===//
1+
//===- SparseTensorInterfaces.h - sparse tensor operations interfaces------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
@@ -20,7 +19,7 @@ class StageWithSortSparseOp;
2019

2120
namespace detail {
2221
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
23-
PatternRewriter &rewriter);
22+
PatternRewriter &rewriter, Value &tmpBufs);
2423
} // namespace detail
2524
} // namespace sparse_tensor
2625
} // namespace mlir

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
3434
/*desc=*/"Stage the operation, return the final result value after staging.",
3535
/*retTy=*/"::mlir::LogicalResult",
3636
/*methodName=*/"stageWithSort",
37-
/*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
37+
/*args=*/(ins "::mlir::PatternRewriter &":$rewriter,
38+
"Value &":$tmpBuf),
3839
/*methodBody=*/[{
39-
return detail::stageWithSortImpl($_op, rewriter);
40+
return detail::stageWithSortImpl($_op, rewriter, tmpBuf);
4041
}]>,
4142
];
4243
}

mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,14 @@ using namespace mlir::sparse_tensor;
1616

1717
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
1818

19-
LogicalResult
20-
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
21-
PatternRewriter &rewriter) {
19+
/// Stage the operations into a sequence of simple operations as follow:
20+
/// op -> unsorted_coo +
21+
/// unsorted_coo -> sorted_coo +
22+
/// sorted_coo -> dstTp.
23+
///
24+
/// return `tmpBuf` if a intermediate memory is allocated.
25+
LogicalResult sparse_tensor::detail::stageWithSortImpl(
26+
StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
2227
if (!op.needsExtraSort())
2328
return failure();
2429

@@ -44,9 +49,15 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
4449
rewriter.replaceOp(op, dstCOO);
4550
} else {
4651
// Need an extra conversion if the target type is not COO.
47-
rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
52+
auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
53+
rewriter.setInsertionPointAfter(c);
54+
// Informs the caller about the intermediate buffer we allocated. We can not
55+
// create a bufferization::DeallocateTensorOp here because it would
56+
// introduce cyclic dependency between the SparseTensorDialect and the
57+
// BufferizationDialect. Besides, whether the buffer need to be deallocated
58+
// by SparseTensorDialect or by BufferDeallocationPass is still TBD.
59+
tmpBufs = dstCOO;
4860
}
49-
// TODO: deallocate extra COOs, we should probably delegate it to buffer
50-
// deallocation pass.
61+
5162
return success();
5263
}

mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
910
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1011
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
1112
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -21,8 +22,16 @@ struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
2122

2223
LogicalResult matchAndRewrite(StageWithSortOp op,
2324
PatternRewriter &rewriter) const override {
24-
return llvm::cast<StageWithSortSparseOp>(op.getOperation())
25-
.stageWithSort(rewriter);
25+
Location loc = op.getLoc();
26+
Value tmpBuf = nullptr;
27+
auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
28+
LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
29+
// Deallocate tmpBuf.
30+
// TODO: Delegate to buffer deallocation pass in the future.
31+
if (succeeded(stageResult) && tmpBuf)
32+
rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
33+
34+
return stageResult;
2635
}
2736
};
2837
} // namespace

mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,11 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
8282
// CHECK: scf.if
8383
// CHECK: tensor.insert
8484
// CHECK: sparse_tensor.load
85-
// CHECK: sparse_tensor.reorder_coo
85+
// CHECK: %[[TMP:.*]] = sparse_tensor.reorder_coo
8686
// CHECK: sparse_tensor.foreach
8787
// CHECK: tensor.insert
8888
// CHECK: sparse_tensor.load
89+
// CHECK: bufferization.dealloc_tensor %[[TMP]]
8990
func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
9091
%0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
9192
return %0 : tensor<?x?x?xf64, #SparseTensor>

0 commit comments

Comments
 (0)