diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index eb7c50ae2efdf..f102f02701542 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -163,13 +163,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
 // Reordering.
 //
 
-/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-
-/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+/// Convenience method to translate the given level to the corresponding
+/// dimension.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= l < lvlRank`.
+Dimension toDim(SparseTensorEncodingAttr enc, Level l);
+
+/// Convenience method to translate the given dimension to the corresponding
+/// level.
+/// Requires: `enc` has a permuted dim2lvl map and `0 <= d < dimRank`.
+Level toLvl(SparseTensorEncodingAttr enc, Dimension d);
 
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 791aeebee5a32..fc897e7935510 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -375,14 +375,12 @@ SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
 
 std::optional<int64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
+  return getStaticDimSliceOffset(toDim(*this, lvl));
 }
 
 std::optional<int64_t>
 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
-  // FIXME: `toOrigDim` is deprecated.
-  return getStaticDimSliceStride(toOrigDim(*this, lvl));
+  return getStaticDimSliceStride(toDim(*this, lvl));
 }
 
 SmallVector<int64_t>
@@ -398,10 +396,8 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
 
   if (isPermutation()) {
     for (unsigned r = 0; r < rank; r++) {
-      // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
-      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
-                           ? toOrigDim(*this, r)
-                           : toStoredDim(*this, r);
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
+                                                             : toLvl(*this, r);
       ret.push_back(srcShape[trans]);
     }
     return ret;
@@ -925,31 +921,20 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
                                                ordered);
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
-                                         Level l) {
+Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
+    assert(enc.isPermutation() && "Non permutation map not supported");
+    if (const auto dimToLvl = enc.getDimToLvl())
       return dimToLvl.getDimPosition(l);
-    }
   }
   return l;
 }
 
-// TODO: Remove this definition once all use-sites have been fixed to
-// properly handle non-permutations.
-Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
-                                       Dimension d) {
+Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
   if (enc) {
-    if (const auto dimToLvl = enc.getDimToLvl()) {
-      assert(enc.isPermutation());
-      auto maybePos =
-          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
-      assert(maybePos.has_value());
-      return *maybePos;
-    }
+    assert(enc.isPermutation() && "Non permutation map not supported");
+    if (const auto lvlToDim = enc.getLvlToDim())
+      return lvlToDim.getDimPosition(d);
   }
   return d;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 1200b999f9a90..33d449aac5a35 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -546,32 +546,6 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
   }
 }
 
-Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                                           SparseTensorEncodingAttr enc,
-                                           ValueRange dimSizes,
-                                           Value valuesBuffer,
-                                           Value lvlCoords) {
-  // Reuse the `lvlCoords` buffer to store the level-sizes.
-  const Level lvlRank = enc.getLvlRank();
-  SmallVector<Value> lvlSizes;
-  lvlSizes.reserve(lvlRank);
-  for (Level l = 0; l < lvlRank; l++)
-    // FIXME: `toOrigDim` is deprecated.
-    lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]);
-  storeAll(builder, loc, lvlCoords, lvlSizes);
-  // The memref ReshapeOp requires the sizes buffer to have a static
-  // shape.
-  const auto iTp = builder.getIndexType();
-  const SmallVector<int64_t> lvlSizesShape{static_cast<int64_t>(lvlRank)};
-  const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
-  lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
-  // Finally, create the ReshapeOp.
-  const SmallVector<int64_t> resShape(lvlRank, ShapedType::kDynamic);
-  const Type elemTp = getMemRefType(valuesBuffer).getElementType();
-  const auto resTp = MemRefType::get(resShape, elemTp);
-  return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
-}
-
 TypedValue<BaseMemRefType>
 sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
   auto tTp = llvm::cast<TensorType>(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index cb0acdd2be9f7..0ce33427281f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -277,13 +277,6 @@ SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
 void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
               size_t offsetIdx = 0, Value offsetVal = Value());
 
-/// Reshapes the linear values buffer for an annotated all dense sparse tensor
-/// to match the shape of the corresponding dense tensor to support direct
-/// access of the buffer through `lvlCoords`.
-Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, ValueRange dimSizes,
-                            Value valuesBuffer, Value lvlCoords);
-
 // Generates code to cast a tensor to a memref.
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
index 613a8609ac097..52ee117029300 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.h
@@ -26,7 +26,6 @@ enum class SortMask : unsigned {
   // The individual mask bits.
   kIncludeDenseOutput = 0x1, // b001
   kIncludeDenseInput = 0x2,  // b010
-  kIncludeUndef = 0x4,       // b100
   // The subsets of mask bits.
   kIncludeAll = 0x7,   // b111
   kIncludeDense = 0x3, // b011
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index f8bcc0fe12a10..413a835ff14d3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -68,15 +68,13 @@ static constexpr unsigned kSliceIterWidth = 3;
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());
-  // FIXME: `toOrigDim` is deprecated
-  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
 }
 
 /// Converts a coordinate relative to the slice to the coordinate relative
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 268bd8fbe2738..c94ef8b962877 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -422,10 +422,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
     // computation. Must be ordered from more strict to less strict.
    // Ideally (though might not be guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
-    const auto allMasks = {
-        SortMask::kIncludeAll, SortMask::kIncludeDense,
-        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
-        SortMask::kIncludeUndef, SortMask::kSparseOnly};
+    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
+                           SortMask::kIncludeDenseInput,
+                           SortMask::kIncludeDenseOutput,
+                           SortMask::kSparseOnly};
     for (const SortMask mask : allMasks) {
       order = scheduler.sort(mask);
       if (order) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5374ab55c5c0d..103908b2cf5bd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -661,8 +661,7 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
     SmallVector<Value> srcDcvs;
     srcDcvs.reserve(srcRank);
    for (Dimension d = 0; d < srcRank; d++) {
-      // FIXME: `toStoredDim` is deprecated
-      Level lvl = toStoredDim(encSrc, d);
+      Level lvl = toLvl(encSrc, d);
       srcDcvs.push_back(srcLcvs[lvl]);
     }
 
@@ -766,8 +765,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
     SmallVector<Value> srcDcvs;
     srcDcvs.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++) {
-      // FIXME: `toStoredDim` is deprecated
-      Level lvl = toStoredDim(encSrc, d);
+      Level lvl = toLvl(encSrc, d);
       srcDcvs.push_back(srcLcvs[lvl]);
     }
     SmallVector<Value> dstDcvs;
@@ -872,9 +870,8 @@ struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
       return failure();
 
     if (stt.isPermutation()) {
-      // FIXME: `toStoredDim` is deprecated
      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
-                                         toStoredDim(stt.getEncoding(), *dim));
+                                         toLvl(stt.getEncoding(), *dim));
       return success();
     }
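
Usage note (illustrative, not part of the patch): the renamed helpers only translate between dimensions and levels when the encoding's dim2lvl map is a permutation, which is exactly what the new asserts enforce. Below is a minimal sketch of the intended round-trip semantics, assuming a permuted `SparseTensorEncodingAttr` such as a 2-d CSC-style encoding with dimToLvl `(d0, d1) -> (d1, d0)`; the function name `exampleDimLvlRoundTrip` is hypothetical.

```cpp
#include <cassert>

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical helper (not part of the patch): shows that toDim and toLvl
// are inverses of each other for an encoding with a permuted dim2lvl map.
static void exampleDimLvlRoundTrip(SparseTensorEncodingAttr enc) {
  // The renamed helpers require a permutation; non-permuted (e.g. block
  // sparse) maps trip the asserts added in this patch.
  assert(enc.isPermutation());
  const Level lvlRank = enc.getLvlRank();
  for (Level l = 0; l < lvlRank; l++) {
    // toDim: which original tensor dimension is stored at level `l`.
    Dimension d = toDim(enc, l);
    // toLvl: which storage level holds dimension `d`; round-trips to `l`.
    assert(toLvl(enc, d) == l);
  }
}
```

For non-permuted encodings the callers are expected to branch first, as the updated call sites above already do via `isPermutation()` checks (e.g. in `tranlateShape` and `SparseTensorDimOpRewriter`).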