Skip to content

[mlir][sparse] deallocate tmp coo buffer generated during stage-spars… #82017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 17, 2024

Conversation

PeimingLiu
Copy link
Member

…e-ops pass.

@llvmbot
Copy link
Member

llvmbot commented Feb 16, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

…e-ops pass.


Full diff: https://github.com/llvm/llvm-project/pull/82017.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h (+2-3)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td (+3-2)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp (+11-6)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp (+11-2)
  • (modified) mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir (+2-1)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
index ebbc522123a599..c0f31762ee071f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h
@@ -1,5 +1,4 @@
-//===- SparseTensorInterfaces.h - sparse tensor operations
-//interfaces-------===//
+//===- SparseTensorInterfaces.h - sparse tensor operations interfaces------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -20,7 +19,7 @@ class StageWithSortSparseOp;
 
 namespace detail {
 LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
-                                PatternRewriter &rewriter);
+                                PatternRewriter &rewriter, Value &tmpBufs);
 } // namespace detail
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
index 1379363ff75f42..05eed0483f2c8a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td
@@ -34,9 +34,10 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
     /*desc=*/"Stage the operation, return the final result value after staging.",
     /*retTy=*/"::mlir::LogicalResult",
     /*methodName=*/"stageWithSort",
-    /*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
+    /*args=*/(ins "::mlir::PatternRewriter &":$rewriter,
+                  "Value &":$tmpBuf),
     /*methodBody=*/[{
-        return detail::stageWithSortImpl($_op, rewriter);
+        return detail::stageWithSortImpl($_op, rewriter, tmpBuf);
     }]>,
   ];
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index d33eb9d2877ae3..4866971af08e7d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -16,9 +16,8 @@ using namespace mlir::sparse_tensor;
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
 
-LogicalResult
-sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
-                                         PatternRewriter &rewriter) {
+LogicalResult sparse_tensor::detail::stageWithSortImpl(
+    StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
   if (!op.needsExtraSort())
     return failure();
 
@@ -44,9 +43,15 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
     rewriter.replaceOp(op, dstCOO);
   } else {
     // Need an extra conversion if the target type is not COO.
-    rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+    auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
+    rewriter.setInsertionPointAfter(c);
+    // Informs the caller about the intermediate buffer we allocated. We can not
+    // create a bufferization::DeallocateTensorOp here because it would
+    // introduce cyclic dependency between the SparseTensorDialect and the
+    // BufferizationDialect. Besides, whether the buffer need to be deallocated
+    // by SparseTensorDialect or by BufferDeallocationPass is still TBD.
+    tmpBufs = dstCOO;
   }
-  // TODO: deallocate extra COOs, we should probably delegate it to buffer
-  // deallocation pass.
+
   return success();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
index 5875cd4f9fd9d1..992f4faafc099b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -21,8 +22,16 @@ struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
 
   LogicalResult matchAndRewrite(StageWithSortOp op,
                                 PatternRewriter &rewriter) const override {
-    return llvm::cast<StageWithSortSparseOp>(op.getOperation())
-        .stageWithSort(rewriter);
+    Location loc = op.getLoc();
+    Value tmpBuf = nullptr;
+    auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation());
+    LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf);
+    // Deallocate tmpBuf, maybe delegate to buffer deallocation pass in the
+    // future.
+    if (succeeded(stageResult) && tmpBuf)
+      rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf);
+
+    return stageResult;
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 96a1140372bd6c..83dbc9568c7a36 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -82,10 +82,11 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
 // CHECK:             scf.if
 // CHECK:               tensor.insert
 // CHECK:           sparse_tensor.load
-// CHECK:           sparse_tensor.reorder_coo
+// CHECK:           %[[TMP:.*]] = sparse_tensor.reorder_coo
 // CHECK:           sparse_tensor.foreach
 // CHECK:             tensor.insert
 // CHECK:           sparse_tensor.load
+// CHECK:           bufferization.dealloc_tensor %[[TMP]]
 func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
   return %0 : tensor<?x?x?xf64, #SparseTensor>

@PeimingLiu PeimingLiu merged commit 11705af into llvm:main Feb 17, 2024
@PeimingLiu PeimingLiu deleted the fix-leak branch February 17, 2024 20:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants