
[mlir][sparse] Replace getSparseTensorType with tryGetSparseTensorType #109435


Merged: 1 commit from CoTinker:lower_sparse into llvm:main on Sep 30, 2024

Conversation

CoTinker
Contributor

This PR fixes a bug in SparseTensorDimOpRewriter when tensor.dim has an unranked tensor type. To prevent crashes, we now use tryGetSparseTensorType instead of getSparseTensorType. Fixes #107807.

[mlir][sparse] Replace `getSparseTensorType` with `tryGetSparseTensorType`

This PR fixes a bug in `SparseTensorDimOpRewriter` when `tensor.dim`
has an unranked tensor type. To prevent crashes, we now use
`tryGetSparseTensorType` instead of `getSparseTensorType`.
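
For context, here is roughly what distinguishes the two helpers. This is a sketch based on how the PR uses them; the actual definitions live in mlir/Dialect/SparseTensor/IR/SparseTensorType.h and may differ in detail (the "Sketch" names below are ours, not the library's):

#include <optional>

#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Unconditionally treats the value as a ranked tensor; an unranked
// tensor<*xf32> fails the cast and crashes.
SparseTensorType getSparseTensorTypeSketch(Value val) {
  return SparseTensorType(cast<RankedTensorType>(val.getType()));
}

// Returns std::nullopt for non-ranked types, so a rewrite pattern can
// bail out with failure() instead of crashing.
std::optional<SparseTensorType> tryGetSparseTensorTypeSketch(Value val) {
  if (auto rtt = dyn_cast<RankedTensorType>(val.getType()))
    return SparseTensorType(rtt);
  return std::nullopt;
}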
@llvmbot
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a bug in SparseTensorDimOpRewriter when tensor.dim has an unranked tensor type. To prevent crashes, we now use tryGetSparseTensorType instead of getSparseTensorType. Fixes #107807.


Full diff: https://github.com/llvm/llvm-project/pull/109435.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+22-20)
  • (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+16)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index cc372ed1be6217..60db71d96547fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSource();
-    const auto srcTp = getSparseTensorType(srcTensor);
-    const auto dstTp = getSparseTensorType(op.getResult());
+    const auto srcTp = tryGetSparseTensorType(srcTensor);
+    const auto dstTp = tryGetSparseTensorType(op.getResult());
+    if (!srcTp || !dstTp)
+      return failure();
 
-    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
-        !dstTp.hasStaticDimShape())
+    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
+        !dstTp->hasStaticDimShape())
       return failure();
 
     SmallVector<Value> srcSizes;
-    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
     SmallVector<Value> dstSizes;
-    for (Dimension d : dstTp.getDimShape())
+    for (Dimension d : dstTp->getDimShape())
       dstSizes.push_back(constantIndex(rewriter, loc, d));
 
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
     // Only need an unordered COO buffer if input and output are not sorted
     // in the same way.
     Type bufferTp = getBufferType(
-        dstTp.withoutDimToLvl(),
-        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
+        dstTp->withoutDimToLvl(),
+        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
     SmallVector<Value> dynSizes;
     Value buffer = rewriter
                        .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
@@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
     // followed by an optional
     //   %t = sparse_tensor.cast %tmp
     // depending on whether the input/output are sorted in the same way.
-    const auto encSrc = srcTp.getEncoding();
+    const auto encSrc = srcTp->getEncoding();
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
         loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
-          const Dimension srcRank = srcTp.getDimRank();
+          const Dimension srcRank = srcTp->getDimRank();
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
@@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
                      collapsedSizes, collapsedDcvs);
 
           ReassociationIndices expandIdx;
-          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
             expandIdx.push_back(i);
           SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
           SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
         });
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-    if (bufferTp != dstTp) {
-      auto dstRTT = dstTp.getRankedTensorType();
+    if (bufferTp != *dstTp) {
+      auto dstRTT = dstTp->getRankedTensorType();
       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
       rewriter.create<DeallocTensorOp>(loc, t);
       t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
   LogicalResult matchAndRewrite(tensor::DimOp op,
                                 PatternRewriter &rewriter) const override {
     std::optional<int64_t> dim = op.getConstantIndex();
-    auto stt = getSparseTensorType(op.getSource());
-    if (!dim || !stt.hasEncoding())
+    auto stt = tryGetSparseTensorType(op.getSource());
+    if (!dim || !stt || !stt->hasEncoding())
       return failure();
 
-    if (stt.isPermutation()) {
+    if (stt->isPermutation()) {
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toLvl(stt.getEncoding(), *dim));
+                                         toLvl(stt->getEncoding(), *dim));
       return success();
     }
 
@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
     // computed simply by lvl_size * block_size.
     Location loc = op.getLoc();
     SmallVector<Value> maxLvlCrds;
-    for (Level l = 0; l < stt.getLvlRank(); l++) {
+    for (Level l = 0; l < stt->getLvlRank(); l++) {
       Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
       Value maxLvlCrd = rewriter.create<arith::SubIOp>(
           loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
       maxLvlCrds.push_back(maxLvlCrd);
     }
 
-    AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
+    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
     Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
-        op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
+        op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
         maxLvlCrds);
 
     Value dimSz = rewriter.create<arith::AddIOp>(
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index af78458f109329..df03d871ba3a3e 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPN
   %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
   return %0 : tensor<?x?xf32, #CooPNo>
 }
+
+// CHECK-LABEL: func.func @test_tensor_dim_unranked
+//       CHECK: tensor.dim
+func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
+  %c = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c : tensor<*xf32>
+  return %0 : index
+}
+
+// CHECK-LABEL: func.func @test_tensor_reshape_unranked
+//       CHECK: tensor.reshape
+func.func @test_tensor_reshape_unranked(%src: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
+  %dst = tensor.reshape %src(%shape)
+         : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  return %dst : tensor<?xf32>
+}

@PeimingLiu
Member

I don't think the sparse compiler supports unranked tensors. By nature, it seems impossible to define a sparse layout without knowing the rank.
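
To illustrate the point with a hypothetical snippet (not from this PR; the function and parameter names are made up): a sparse encoding specifies storage level by level, so it can only attach to a type whose rank is known.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

Type encodingNeedsRank(OpBuilder &b, Attribute csrEncoding) {
  // Ranked: the encoding's per-level specification lines up with rank 2,
  // giving tensor<8x8xf64, #CSR>.
  auto ranked = RankedTensorType::get({8, 8}, b.getF64Type(), csrEncoding);
  // Unranked: UnrankedTensorType::get takes no encoding parameter at all;
  // there is no rank for the level types to line up with.
  auto unranked = UnrankedTensorType::get(b.getF64Type());
  (void)unranked;
  return ranked;
}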

@CoTinker
Contributor Author

Thanks for the response. You're correct that sparse tensors are ranked. However, to keep this pattern robust, the compiler should handle invalid inputs such as unranked tensors by gracefully returning failure() rather than crashing. For example:

auto stt = tryGetSparseTensorType(op.getSource());
if (!dim || !stt || !stt->hasEncoding())
   return failure();

@CoTinker
Contributor Author

Ping~

@PeimingLiu
Member

I think we need a better way to specify those reused (yet sparsifiable) tensor operations, but that is out of the scope of this PR. Hence LGTM.

@CoTinker
Contributor Author

Okay, thanks.

CoTinker merged commit 129ade2 into llvm:main on Sep 30, 2024
11 checks passed
CoTinker deleted the lower_sparse branch on September 30, 2024 at 01:17

aartbik (Contributor) left a comment


Oops, I started a review and never pressed the button. I had one request originally. Do you mind incorporating this post-commit? If not, not a big deal.

@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPN
%0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
return %0 : tensor<?x?xf32, #CooPNo>
}

// CHECK-LABEL: func.func @test_tensor_dim_unranked

codegen.mlir is a test file that very specifically exercises the codegen path. The added tests are really regression tests that check crash-before, no-crash-after behavior. As such, can you please put them in a new file, e.g. no_lowering.mlir (or a better name), and document that? That way you can also reduce the flags in the runner.

@CoTinker
Contributor Author

CoTinker commented Oct 2, 2024

> Oops, I started a review and never pressed the button. I had one request originally. Do you mind incorporating this post-commit? If not, not a big deal.

Thanks for your review. I’m happy to do that. I’ll open a new PR to address the tests once my holiday is over.

CoTinker added a commit to CoTinker/llvm-project that referenced this pull request Oct 3, 2024
This PR relocates the tests added in llvm#109435 to a new file named
`no_lowering.mlir` and adds some new tests.
CoTinker added a commit to CoTinker/llvm-project that referenced this pull request Oct 9, 2024
This PR relocates the tests added in llvm#109435 to a new file named
`no_lowering.mlir` and adds some new tests.
CoTinker added a commit that referenced this pull request Oct 11, 2024

This PR relocates the tests added in #109435 to a new file named
`no_lowering.mlir` and adds some new tests.
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
…#110976)

This PR relocates the tests added in llvm#109435 to a new file named
`no_lowering.mlir` and adds some new tests.
Labels: mlir:sparse (Sparse compiler in MLIR), mlir
Participants: 4