Skip to content

[mlir][SparseTensor][NFC] Pass tensor type to descriptor helper #116468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 16, 2024

getDescriptorFromTensorTuple and getMutDescriptorFromTensorTuple extract the tensor type from an unrealized_conversion_cast op that serves as a workaround for missing 1:N dialect conversion support.

This commit changes these functions so that they explicitly receive the tensor type as a function argument. This is in preparation of merging the 1:1 and 1:N conversion drivers. The conversion patterns in this file will soon start receiving multiple SSA values (ValueRange) from their adaptors (instead of a single value that is the result of unrealized_conversion_cast). It will no longer be possible to take the tensor type from the unrealized_conversion_cast op. The unrealized_conversion_cast workaround will disappear entirely.

@llvmbot
Copy link
Member

llvmbot commented Nov 16, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

Changes

getDescriptorFromTensorTuple and getMutDescriptorFromTensorTuple extract the tensor type from an unrealized_conversion_cast op that serves as a workaround for missing 1:N dialect conversion support.

This commit changes these functions so that they explicitly receive the tensor type as a function argument. This is in preparation of merging the 1:1 and 1:N conversion drivers. The conversion patterns in this file will soon start receiving multiple SSA values (ValueRange) from their adaptors (instead of a single value that is the result of unrealized_conversion_cast). It will no longer be possible to take the tensor type from the unrealized_conversion_cast op. The unrealized_conversion_cast workaround will disappear entirely.


Full diff: https://github.com/llvm/llvm-project/pull/116468.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+38-20)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (-5)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h (+6-6)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index bf7b3f9bec5586..25fca49cb0154a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
   matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     std::optional<int64_t> lvl = op.getConstantLvlIndex();
-    if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
+    RankedTensorType srcType = op.getSource().getType();
+    if (!lvl || !getSparseTensorEncoding(srcType))
       return failure();
 
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
 
     rewriter.replaceOp(op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
-    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
-    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
+                                             op.getInputCoo().getType());
+    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
     auto crd = desc.getAOSMemRef();
     auto val = desc.getValMemRef();
 
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply lowers to specifer.get <field> operation.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
+                                             op.getSlice().getType());
     auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                     op.getDim().getZExtValue());
 
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
     Location loc = op.getLoc();
     // Deal with copy.
     if (op.getCopy()) {
-      auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
       SmallVector<Value> fields;
       fields.reserve(desc.getNumFields());
       // Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
     if (createDeallocs) {
       // Replace the sparse tensor deallocation with field deallocations.
       Location loc = op.getLoc();
-      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getTensor(),
+          cast<RankedTensorType>(op.getTensor().getType()));
       for (auto input : desc.getMemRefFields())
         // Deallocate every buffer used to store the sparse tensor handler.
         rewriter.create<memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
     Location loc = op->getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     const auto srcType = getSparseTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
+                                                op.getTensor().getType());
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
 
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
+    auto desc =
+        getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
     TypeRange flatSpTensorTps = desc.getFields().getTypes();
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
     params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getPosMemRef(lvl);
     auto size = desc.getPosMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
     if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
       auto size = desc.getCrdMemSize(rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getAOSMemRef();
     auto size = desc.getCrdMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getValMemRef();
     auto size = desc.getValMemSize(rewriter, loc);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
     //   else:
     //     dst = memref.copy(src)
     Location loc = op.getLoc();
-    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
+                                                op.getSource().getType());
     SmallVector<Value> fields;
     foreachFieldAndTypeInSparseTensor(
         SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
     assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
+                                                op.getSource().getType());
 
     auto newSpec = rewriter.create<StorageSpecifierInitOp>(
         loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
     // Query memSizes for the actually stored values.
     // FIXME: the nse value computed in this way might be wrong when there is
     // any "loose_compressed" level.
-    rewriter.replaceOp(
-        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
+    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
     return success();
   }
 };
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     Location loc = op.getLoc();
     SmallVector<Value> retMem;
     SmallVector<Value> retLen;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index de553a5f9bf08c..f92382472b4780 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
       .getResult();
 }
 
-Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
-                                   Value tensor) {
-  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
-}
-
 Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                                Value tensor, Dimension dim) {
   auto enc = getSparseTensorEncoding(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index d0ef8a6860bb2d..dc017e6baa6dc3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
 
-/// Generates code to retrieve the values size for the sparse tensor.
-Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
-
 /// Generates code to retrieve the slice offset for the sparse tensor slice,
 /// return a constant if the offset is statically known.
 Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index c2f631605bf4b2..89858546e37e1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc,
   return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
 }
 
-inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+inline SparseTensorDescriptor
+getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
   auto tuple = getTuple(tensor);
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return SparseTensorDescriptor(stt, tuple.getInputs());
+  return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
 }
 
 inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+                                RankedTensorType type) {
   auto tuple = getTuple(tensor);
   fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return MutSparseTensorDescriptor(stt, fields);
+  return MutSparseTensorDescriptor(SparseTensorType(type), fields);
 }
 
 } // namespace sparse_tensor

@matthias-springer matthias-springer force-pushed the users/matthias-springer/sparse_dec_tt branch from e872b86 to 66fb6bb Compare November 18, 2024 11:55
@matthias-springer matthias-springer merged commit 204234a into main Nov 19, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/sparse_dec_tt branch November 19, 2024 00:27
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Nov 20, 2024
…#116468)

`getDescriptorFromTensorTuple` and `getMutDescriptorFromTensorTuple`
extract the tensor type from an `unrealized_conversion_cast` op that
serves as a workaround for missing 1:N dialect conversion support.

This commit changes these functions so that they explicitly receive the
tensor type as a function argument. This is in preparation of merging
the 1:1 and 1:N conversion drivers. The conversion patterns in this file
will soon start receiving multiple SSA values (`ValueRange`) from their
adaptors (instead of a single value that is the result of
`unrealized_conversion_cast`). It will no longer be possible to take the
tensor type from the `unrealized_conversion_cast` op. The
`unrealized_conversion_cast` workaround will disappear entirely.
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Nov 20, 2024
…#116468)

`getDescriptorFromTensorTuple` and `getMutDescriptorFromTensorTuple`
extract the tensor type from an `unrealized_conversion_cast` op that
serves as a workaround for missing 1:N dialect conversion support.

This commit changes these functions so that they explicitly receive the
tensor type as a function argument. This is in preparation of merging
the 1:1 and 1:N conversion drivers. The conversion patterns in this file
will soon start receiving multiple SSA values (`ValueRange`) from their
adaptors (instead of a single value that is the result of
`unrealized_conversion_cast`). It will no longer be possible to take the
tensor type from the `unrealized_conversion_cast` op. The
`unrealized_conversion_cast` workaround will disappear entirely.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants