Skip to content

[mlir][sparse] Replace getSparseTensorType with tryGetSparseTensorType #109435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSource();
const auto srcTp = getSparseTensorType(srcTensor);
const auto dstTp = getSparseTensorType(op.getResult());
const auto srcTp = tryGetSparseTensorType(srcTensor);
const auto dstTp = tryGetSparseTensorType(op.getResult());
if (!srcTp || !dstTp)
return failure();

if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
!dstTp.hasStaticDimShape())
if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
!dstTp->hasStaticDimShape())
return failure();

SmallVector<Value> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
SmallVector<Value> dstSizes;
for (Dimension d : dstTp.getDimShape())
for (Dimension d : dstTp->getDimShape())
dstSizes.push_back(constantIndex(rewriter, loc, d));

Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
// Only need an unordered COO buffer if input and output are not sorted
// in the same way.
Type bufferTp = getBufferType(
dstTp.withoutDimToLvl(),
!srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
dstTp->withoutDimToLvl(),
!srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
SmallVector<Value> dynSizes;
Value buffer = rewriter
.create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
Expand All @@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
// followed by an optional
// %t = sparse_tensor.cast %tmp
// depending on whether the input/output are sorted in the same way.
const auto encSrc = srcTp.getEncoding();
const auto encSrc = srcTp->getEncoding();
ForeachOp foreachOp = rewriter.create<ForeachOp>(
loc, srcTensor, buffer,
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
ValueRange reduc) {
const Dimension srcRank = srcTp.getDimRank();
const Dimension srcRank = srcTp->getDimRank();
SmallVector<Value> srcDcvs;
srcDcvs.reserve(srcRank);
for (Dimension d = 0; d < srcRank; d++) {
Expand All @@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
collapsedSizes, collapsedDcvs);

ReassociationIndices expandIdx;
for (Dimension i = 0; i < dstTp.getDimRank(); i++)
for (Dimension i = 0; i < dstTp->getDimRank(); i++)
expandIdx.push_back(i);
SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
SmallVector<Value> dstDcvs;
Expand All @@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
});

Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
if (bufferTp != dstTp) {
auto dstRTT = dstTp.getRankedTensorType();
if (bufferTp != *dstTp) {
auto dstRTT = dstTp->getRankedTensorType();
Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
rewriter.create<DeallocTensorOp>(loc, t);
t = converted;
Expand Down Expand Up @@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
LogicalResult matchAndRewrite(tensor::DimOp op,
PatternRewriter &rewriter) const override {
std::optional<int64_t> dim = op.getConstantIndex();
auto stt = getSparseTensorType(op.getSource());
if (!dim || !stt.hasEncoding())
auto stt = tryGetSparseTensorType(op.getSource());
if (!dim || !stt || !stt->hasEncoding())
return failure();

if (stt.isPermutation()) {
if (stt->isPermutation()) {
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
toLvl(stt.getEncoding(), *dim));
toLvl(stt->getEncoding(), *dim));
return success();
}

Expand All @@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
// computed simply by lvl_size * block_size.
Location loc = op.getLoc();
SmallVector<Value> maxLvlCrds;
for (Level l = 0; l < stt.getLvlRank(); l++) {
for (Level l = 0; l < stt->getLvlRank(); l++) {
Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
Value maxLvlCrd = rewriter.create<arith::SubIOp>(
loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
maxLvlCrds.push_back(maxLvlCrd);
}

AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
maxLvlCrds);

Value dimSz = rewriter.create<arith::AddIOp>(
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/SparseTensor/codegen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPN
%0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
return %0 : tensor<?x?xf32, #CooPNo>
}

// CHECK-LABEL: func.func @test_tensor_dim_unranked
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

codegen.mlir is a test that very specifically exercises the codegen path.

The added tests are really regression tests: they verify crash-before, no-crash-after behavior.
As such, could you please put them in a new file and document that there?

For example, no_lowering.mlir (or a better name) — that way you can also reduce the flags in the runner.

// CHECK: tensor.dim
// NOTE(review): regression test — tensor.dim on an unranked tensor previously
// crashed the sparse rewriter (getSparseTensorType asserted on non-ranked
// types); with tryGetSparseTensorType the pattern bails out and the op is
// left untouched, so plain tensor.dim must survive lowering.
func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
%c = arith.constant 0 : index
%0 = tensor.dim %arg0, %c : tensor<*xf32>
return %0 : index
}

// CHECK-LABEL: func.func @test_tensor_reshape_unranked
// CHECK: tensor.reshape
// NOTE(review): regression test — tensor.reshape with an unranked source
// previously crashed the sparse TensorReshapeRewriter; the pattern now
// returns failure for non-sparse/unranked operands, so the reshape must
// pass through unmodified.
func.func @test_tensor_reshape_unranked(%src: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
%dst = tensor.reshape %src(%shape)
: (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
return %dst : tensor<?xf32>
}
Loading