-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][sparse] Replace getSparseTensorType
with tryGetSparseTensorType
#109435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…Type` This PR fixes a bug in `SparseTensorDimOpRewriter` when `tensor.dim` has an unranked tensor type. To prevent crashes, we now use `tryGetSparseTensorType` instead of `getSparseTensorType`.
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR fixes a bug in Full diff: https://github.com/llvm/llvm-project/pull/109435.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index cc372ed1be6217..60db71d96547fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSource();
- const auto srcTp = getSparseTensorType(srcTensor);
- const auto dstTp = getSparseTensorType(op.getResult());
+ const auto srcTp = tryGetSparseTensorType(srcTensor);
+ const auto dstTp = tryGetSparseTensorType(op.getResult());
+ if (!srcTp || !dstTp)
+ return failure();
- if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
- !dstTp.hasStaticDimShape())
+ if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
+ !dstTp->hasStaticDimShape())
return failure();
SmallVector<Value> srcSizes;
- sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+ sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
SmallVector<Value> dstSizes;
- for (Dimension d : dstTp.getDimShape())
+ for (Dimension d : dstTp->getDimShape())
dstSizes.push_back(constantIndex(rewriter, loc, d));
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
// Only need an unordered COO buffer if input and output are not sorted
// in the same way.
Type bufferTp = getBufferType(
- dstTp.withoutDimToLvl(),
- !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
+ dstTp->withoutDimToLvl(),
+ !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
SmallVector<Value> dynSizes;
Value buffer = rewriter
.create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
@@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
// followed by an optional
// %t = sparse_tensor.cast %tmp
// depending on whether the input/output are sorted in the same way.
- const auto encSrc = srcTp.getEncoding();
+ const auto encSrc = srcTp->getEncoding();
ForeachOp foreachOp = rewriter.create<ForeachOp>(
loc, srcTensor, buffer,
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
ValueRange reduc) {
- const Dimension srcRank = srcTp.getDimRank();
+ const Dimension srcRank = srcTp->getDimRank();
SmallVector<Value> srcDcvs;
srcDcvs.reserve(srcRank);
for (Dimension d = 0; d < srcRank; d++) {
@@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
collapsedSizes, collapsedDcvs);
ReassociationIndices expandIdx;
- for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+ for (Dimension i = 0; i < dstTp->getDimRank(); i++)
expandIdx.push_back(i);
SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
});
Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
- if (bufferTp != dstTp) {
- auto dstRTT = dstTp.getRankedTensorType();
+ if (bufferTp != *dstTp) {
+ auto dstRTT = dstTp->getRankedTensorType();
Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
rewriter.create<DeallocTensorOp>(loc, t);
t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
LogicalResult matchAndRewrite(tensor::DimOp op,
PatternRewriter &rewriter) const override {
std::optional<int64_t> dim = op.getConstantIndex();
- auto stt = getSparseTensorType(op.getSource());
- if (!dim || !stt.hasEncoding())
+ auto stt = tryGetSparseTensorType(op.getSource());
+ if (!dim || !stt || !stt->hasEncoding())
return failure();
- if (stt.isPermutation()) {
+ if (stt->isPermutation()) {
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
- toLvl(stt.getEncoding(), *dim));
+ toLvl(stt->getEncoding(), *dim));
return success();
}
@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
// computed simply by lvl_size * block_size.
Location loc = op.getLoc();
SmallVector<Value> maxLvlCrds;
- for (Level l = 0; l < stt.getLvlRank(); l++) {
+ for (Level l = 0; l < stt->getLvlRank(); l++) {
Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
Value maxLvlCrd = rewriter.create<arith::SubIOp>(
loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
maxLvlCrds.push_back(maxLvlCrd);
}
- AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
+ AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
- op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
+ op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
maxLvlCrds);
Value dimSz = rewriter.create<arith::AddIOp>(
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index af78458f109329..df03d871ba3a3e 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPN
%0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
return %0 : tensor<?x?xf32, #CooPNo>
}
+
+// CHECK-LABEL: func.func @test_tensor_dim_unranked
+// CHECK: tensor.dim
+func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
+ %c = arith.constant 0 : index
+ %0 = tensor.dim %arg0, %c : tensor<*xf32>
+ return %0 : index
+}
+
+// CHECK-LABEL: func.func @test_tensor_reshape_unranked
+// CHECK: tensor.reshape
+func.func @test_tensor_reshape_unranked(%src: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
+ %dst = tensor.reshape %src(%shape)
+ : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+ return %dst : tensor<?xf32>
+}
|
I don't think sparse compiler support unranked tensor. By nature, it seems impossible to define the sparse layout without knowing the rank. |
Thanks for the response. You’re correct that sparse tensors are ranked. However, to maintain the robustness of this pattern, the compiler should handle invalid inputs, such as
|
Ping~ |
I think we need a better way to specify those reused (yet sparsifiable) tensor operations, which is out of the scope of the PR. Hence LGTM. |
Okay, thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, I started a review and never pressed the button. I had one request originally. Do you mind incorporating this post-commit? If not, not a big deal
@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPN | |||
%0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo> | |||
return %0 : tensor<?x?xf32, #CooPNo> | |||
} | |||
|
|||
// CHECK-LABEL: func.func @test_tensor_dim_unranked |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
codegen.mlir is a test that very specifically tests the codegen path
the added tests are really regression tests that test crash-before, no-crash after behavior
as such, can you please put them in a new file and document that
e.g. no_lowering.mlir (or better name)
that way you can also reduce the flags in the runner
Thanks for your review. I’m happy to do that. I’ll open a new PR to address the tests once my holiday is over. |
This PR relocates the tests added in llvm#109435 to a new file named `no_lowering.mlir` and adds some new tests.
This PR relocates the tests added in llvm#109435 to a new file named `no_lowering.mlir` and adds some new tests.
) This PR relocates the tests added in #109435 to a new file named `no_lowering.mlir` and adds some new tests.
…#110976) This PR relocates the tests added in llvm#109435 to a new file named `no_lowering.mlir` and adds some new tests.
This PR fixes a bug in
SparseTensorDimOpRewriter
whentensor.dim
has an unranked tensor type. To prevent crashes, we now usetryGetSparseTensorType
instead ofgetSparseTensorType
. Fixes #107807.