[mlir][SparseTensor][NFC] Pass tensor type to descriptor helper #116468
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

Changes
`getDescriptorFromTensorTuple` and `getMutDescriptorFromTensorTuple` extract the tensor type from an `unrealized_conversion_cast` op that serves as a workaround for missing 1:N dialect conversion support. This commit changes these functions so that they explicitly receive the tensor type as a function argument. This is in preparation for merging the 1:1 and 1:N conversion drivers. The conversion patterns in this file will soon start receiving multiple SSA values (`ValueRange`) from their adaptors (instead of a single value that is the result of `unrealized_conversion_cast`). It will no longer be possible to take the tensor type from the `unrealized_conversion_cast` op. The `unrealized_conversion_cast` workaround will disappear entirely.

Full diff: https://github.com/llvm/llvm-project/pull/116468.diff

4 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index bf7b3f9bec5586..25fca49cb0154a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::optional<int64_t> lvl = op.getConstantLvlIndex();
- if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
+ RankedTensorType srcType = op.getSource().getType();
+ if (!lvl || !getSparseTensorEncoding(srcType))
return failure();
- auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
rewriter.replaceOp(op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
assert(dstStt.hasSameDimToLvl(srcStt));
// We don't need a mutable descriptor here as we perform sorting in-place.
- auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
- auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
+ op.getInputCoo().getType());
+ auto nnz = desc.getValMemSize(rewriter, op.getLoc());
auto crd = desc.getAOSMemRef();
auto val = desc.getValMemRef();
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Simply lowers to specifer.get <field> operation.
- auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
+ op.getSlice().getType());
auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
op.getDim().getZExtValue());
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
Location loc = op.getLoc();
// Deal with copy.
if (op.getCopy()) {
- auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
+ auto desc = getDescriptorFromTensorTuple(
+ adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
SmallVector<Value> fields;
fields.reserve(desc.getNumFields());
// Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
if (createDeallocs) {
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(
+ adaptor.getTensor(),
+ cast<RankedTensorType>(op.getTensor().getType()));
for (auto input : desc.getMemRefFields())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Prepare descriptor.
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
// Generate optional insertion finalization code.
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
Location loc = op->getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
const auto srcType = getSparseTensorType(op.getTensor());
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+ auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
+ op.getTensor().getType());
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
Location loc = op.getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
+ auto desc =
+ getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getPosMemRef(lvl);
auto size = desc.getPosMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getAOSMemRef();
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
auto mem = desc.getValMemRef();
auto size = desc.getValMemSize(rewriter, loc);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
// else:
// dst = memref.copy(src)
Location loc = op.getLoc();
- auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
+ op.getSource().getType());
SmallVector<Value> fields;
foreachFieldAndTypeInSparseTensor(
SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
SmallVector<Value> fields;
- auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+ auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
+ op.getSource().getType());
auto newSpec = rewriter.create<StorageSpecifierInitOp>(
loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
// Query memSizes for the actually stored values.
// FIXME: the nse value computed in this way might be wrong when there is
// any "loose_compressed" level.
- rewriter.replaceOp(
- op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
+ rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
return success();
}
};
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
LogicalResult
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+ op.getTensor().getType());
Location loc = op.getLoc();
SmallVector<Value> retMem;
SmallVector<Value> retLen;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index de553a5f9bf08c..f92382472b4780 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
.getResult();
}
-Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
- Value tensor) {
- return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
-}
-
Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
Value tensor, Dimension dim) {
auto enc = getSparseTensorEncoding(tensor.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index d0ef8a6860bb2d..dc017e6baa6dc3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);
-/// Generates code to retrieve the values size for the sparse tensor.
-Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
-
/// Generates code to retrieve the slice offset for the sparse tensor slice,
/// return a constant if the offset is statically known.
Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index c2f631605bf4b2..89858546e37e1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc,
return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
}
-inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+inline SparseTensorDescriptor
+getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
auto tuple = getTuple(tensor);
- SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
- return SparseTensorDescriptor(stt, tuple.getInputs());
+ return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
}
inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+ RankedTensorType type) {
auto tuple = getTuple(tensor);
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
- SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
- return MutSparseTensorDescriptor(stt, fields);
+ return MutSparseTensorDescriptor(SparseTensorType(type), fields);
}
} // namespace sparse_tensor
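In short, every affected pattern keeps obtaining the tensor type from the original op (which is unaffected by the upcoming adaptor change) and forwards it to the descriptor helper. Below is a minimal sketch of one migrated call site, modeled on the SparseToValuesConverter hunk above; the class boilerplate and includes from SparseTensorCodegen.cpp are assumed and omitted.

```cpp
// Illustrative sketch only (not standalone code): the body of
// SparseToValuesConverter::matchAndRewrite after this change.
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Location loc = op.getLoc();
  // Before: the helper recovered the tensor type by peeking at the
  // unrealized_conversion_cast that wraps the flattened fields:
  //   auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
  // After: the type is read from the original op and passed explicitly.
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  auto mem = desc.getValMemRef();
  auto size = desc.getValMemSize(rewriter, loc);
  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
  return success();
}
```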
Force-pushed from e872b86 to 66fb6bb.
[mlir][SparseTensor][NFC] Pass tensor type to descriptor helper (#116468)

`getDescriptorFromTensorTuple` and `getMutDescriptorFromTensorTuple` extract the tensor type from an `unrealized_conversion_cast` op that serves as a workaround for missing 1:N dialect conversion support. This commit changes these functions so that they explicitly receive the tensor type as a function argument. This is in preparation for merging the 1:1 and 1:N conversion drivers. The conversion patterns in this file will soon start receiving multiple SSA values (`ValueRange`) from their adaptors (instead of a single value that is the result of `unrealized_conversion_cast`). It will no longer be possible to take the tensor type from the `unrealized_conversion_cast` op. The `unrealized_conversion_cast` workaround will disappear entirely.
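For reference, here is a minimal sketch of how the two updated helpers from SparseTensorDescriptor.h are now called; the identifiers come from the hunks above, and the surrounding conversion-pattern boilerplate is assumed rather than shown.

```cpp
// Illustrative sketch only, based on the SparseNumberOfEntriesConverter and
// SparseCompressConverter hunks above.

// Read-only descriptor: the caller now supplies the ranked tensor type.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                         op.getTensor().getType());
Value nnz = desc.getValMemSize(rewriter, op.getLoc());

// Mutable descriptor: the caller additionally provides the fields storage,
// which is filled from the tuple's inputs and may be updated in place.
SmallVector<Value> fields;
auto mutDesc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                               op.getTensor().getType());
```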