diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index cc372ed1be621..60db71d96547f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -881,25 +881,27 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSource();
-    const auto srcTp = getSparseTensorType(srcTensor);
-    const auto dstTp = getSparseTensorType(op.getResult());
+    const auto srcTp = tryGetSparseTensorType(srcTensor);
+    const auto dstTp = tryGetSparseTensorType(op.getResult());
+    if (!srcTp || !dstTp)
+      return failure();
 
-    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
-        !dstTp.hasStaticDimShape())
+    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
+        !dstTp->hasStaticDimShape())
       return failure();
 
     SmallVector<Value> srcSizes;
-    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
     SmallVector<Value> dstSizes;
-    for (Dimension d : dstTp.getDimShape())
+    for (Dimension d : dstTp->getDimShape())
       dstSizes.push_back(constantIndex(rewriter, loc, d));
 
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
     // Only need an unordered COO buffer if input and output are not sorted
     // in the same way.
     Type bufferTp = getBufferType(
-        dstTp.withoutDimToLvl(),
-        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
+        dstTp->withoutDimToLvl(),
+        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
     SmallVector<Value> dynSizes;
     Value buffer = rewriter
                        .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
@@ -917,12 +919,12 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
     // followed by an optional
     //   %t = sparse_tensor.cast %tmp
     // depending on whether the input/output are sorted in the same way.
-    const auto encSrc = srcTp.getEncoding();
+    const auto encSrc = srcTp->getEncoding();
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
         loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
-          const Dimension srcRank = srcTp.getDimRank();
+          const Dimension srcRank = srcTp->getDimRank();
           SmallVector<Value> srcDcvs;
           srcDcvs.reserve(srcRank);
           for (Dimension d = 0; d < srcRank; d++) {
@@ -945,7 +947,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
                      collapsedSizes, collapsedDcvs);
 
           ReassociationIndices expandIdx;
-          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
             expandIdx.push_back(i);
           SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
           SmallVector<Value> dstDcvs;
@@ -958,8 +960,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
         });
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-    if (bufferTp != dstTp) {
-      auto dstRTT = dstTp.getRankedTensorType();
+    if (bufferTp != *dstTp) {
+      auto dstRTT = dstTp->getRankedTensorType();
       Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
       rewriter.create<DeallocTensorOp>(loc, t);
       t = converted;
@@ -1139,13 +1141,13 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
   LogicalResult matchAndRewrite(tensor::DimOp op,
                                 PatternRewriter &rewriter) const override {
     std::optional<int64_t> dim = op.getConstantIndex();
-    auto stt = getSparseTensorType(op.getSource());
-    if (!dim || !stt.hasEncoding())
+    auto stt = tryGetSparseTensorType(op.getSource());
+    if (!dim || !stt || !stt->hasEncoding())
      return failure();
 
-    if (stt.isPermutation()) {
+    if (stt->isPermutation()) {
       rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toLvl(stt.getEncoding(), *dim));
+                                         toLvl(stt->getEncoding(), *dim));
       return success();
     }
 
@@ -1157,16 +1159,16 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
     // computed simply by lvl_size * block_size.
     Location loc = op.getLoc();
     SmallVector<Value> maxLvlCrds;
-    for (Level l = 0; l < stt.getLvlRank(); l++) {
+    for (Level l = 0; l < stt->getLvlRank(); l++) {
       Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
       Value maxLvlCrd = rewriter.create<arith::SubIOp>(
           loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
       maxLvlCrds.push_back(maxLvlCrd);
     }
 
-    AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
+    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
     Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
-        op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
+        op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
         maxLvlCrds);
 
     Value dimSz = rewriter.create<arith::AddIOp>(
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index af78458f10932..df03d871ba3a3 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -826,3 +826,19 @@ func.func @sparse_new_coo_permute_no(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CooPNo> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CooPNo>
   return %0 : tensor<?x?xf32, #CooPNo>
 }
+
+// CHECK-LABEL: func.func @test_tensor_dim_unranked
+//       CHECK: tensor.dim
+func.func @test_tensor_dim_unranked(%arg0: tensor<*xf32>) -> index {
+  %c = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c : tensor<*xf32>
+  return %0 : index
+}
+
+// CHECK-LABEL: func.func @test_tensor_reshape_unranked
+//       CHECK: tensor.reshape
+func.func @test_tensor_reshape_unranked(%src: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
+  %dst = tensor.reshape %src(%shape)
+         : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
+  return %dst : tensor<?xf32>
+}