Commit 204234a

[mlir][SparseTensor][NFC] Pass tensor type to descriptor helper (#116468)
`getDescriptorFromTensorTuple` and `getMutDescriptorFromTensorTuple` extract the tensor type from an `unrealized_conversion_cast` op that serves as a workaround for missing 1:N dialect conversion support. This commit changes these functions so that they explicitly receive the tensor type as a function argument. This is in preparation of merging the 1:1 and 1:N conversion drivers. The conversion patterns in this file will soon start receiving multiple SSA values (`ValueRange`) from their adaptors (instead of a single value that is the result of `unrealized_conversion_cast`). It will no longer be possible to take the tensor type from the `unrealized_conversion_cast` op. The `unrealized_conversion_cast` workaround will disappear entirely.
1 parent 5ae4d50 commit 204234a
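
For orientation, here is a minimal sketch (not part of the commit) of the signature change: the descriptor helper now takes the tensor type from the caller, so a conversion pattern passes the type of the op's original operand instead of relying on the `unrealized_conversion_cast` tuple. The snippet is abridged from the `SparseTensorLoadConverter` hunk below; surrounding pattern boilerplate is omitted.

  // Sketch only -- abridged from the SparseTensorLoadConverter hunk below.
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Old form: the helper recovered the tensor type from the
    // unrealized_conversion_cast tuple internally:
    //   auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    // New form: the pattern passes the type of the original operand.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ... rest of the pattern is unchanged ...
  }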

4 files changed: 44 additions, 34 deletions

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 38 additions & 20 deletions
@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
   matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     std::optional<int64_t> lvl = op.getConstantLvlIndex();
-    if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
+    RankedTensorType srcType = op.getSource().getType();
+    if (!lvl || !getSparseTensorEncoding(srcType))
       return failure();
 
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
 
     rewriter.replaceOp(op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
-    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
-    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
+                                             op.getInputCoo().getType());
+    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
     auto crd = desc.getAOSMemRef();
     auto val = desc.getValMemRef();
 
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply lowers to specifer.get <field> operation.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
+                                             op.getSlice().getType());
     auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                     op.getDim().getZExtValue());
 
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
     Location loc = op.getLoc();
     // Deal with copy.
     if (op.getCopy()) {
-      auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
       SmallVector<Value> fields;
       fields.reserve(desc.getNumFields());
       // Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
     if (createDeallocs) {
       // Replace the sparse tensor deallocation with field deallocations.
       Location loc = op.getLoc();
-      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+      auto desc = getDescriptorFromTensorTuple(
+          adaptor.getTensor(),
+          cast<RankedTensorType>(op.getTensor().getType()));
       for (auto input : desc.getMemRefFields())
         // Deallocate every buffer used to store the sparse tensor handler.
         rewriter.create<memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
     Location loc = op->getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     const auto srcType = getSparseTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
+                                                op.getTensor().getType());
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
 
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
+    auto desc =
+        getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
     TypeRange flatSpTensorTps = desc.getFields().getTypes();
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
     params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getPosMemRef(lvl);
     auto size = desc.getPosMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
     if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
       auto size = desc.getCrdMemSize(rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getAOSMemRef();
     auto size = desc.getCrdMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     auto mem = desc.getValMemRef();
     auto size = desc.getValMemSize(rewriter, loc);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
     // else:
     // dst = memref.copy(src)
     Location loc = op.getLoc();
-    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
+                                                op.getSource().getType());
     SmallVector<Value> fields;
     foreachFieldAndTypeInSparseTensor(
         SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
     assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
+                                                op.getSource().getType());
 
     auto newSpec = rewriter.create<StorageSpecifierInitOp>(
         loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
     // Query memSizes for the actually stored values.
     // FIXME: the nse value computed in this way might be wrong when there is
     // any "loose_compressed" level.
-    rewriter.replaceOp(
-        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
+    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
     return success();
   }
 };
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
   LogicalResult
   matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
+                                             op.getTensor().getType());
     Location loc = op.getLoc();
     SmallVector<Value> retMem;
     SmallVector<Value> retLen;

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp

Lines changed: 0 additions & 5 deletions
@@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
       .getResult();
 }
 
-Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
-                                   Value tensor) {
-  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
-}
-
 Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                                Value tensor, Dimension dim) {
   auto enc = getSparseTensorEncoding(tensor.getType());

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h

Lines changed: 0 additions & 3 deletions
@@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
 TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                        Value tensor);
 
-/// Generates code to retrieve the values size for the sparse tensor.
-Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
-
 /// Generates code to retrieve the slice offset for the sparse tensor slice,
 /// return a constant if the offset is statically known.
 Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h

Lines changed: 6 additions & 6 deletions
@@ -245,18 +245,18 @@ inline Value genTuple(OpBuilder &builder, Location loc,
   return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
 }
 
-inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
+inline SparseTensorDescriptor
+getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
   auto tuple = getTuple(tensor);
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return SparseTensorDescriptor(stt, tuple.getInputs());
+  return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
 }
 
 inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
+getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+                                RankedTensorType type) {
   auto tuple = getTuple(tensor);
   fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
-  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
-  return MutSparseTensorDescriptor(stt, fields);
+  return MutSparseTensorDescriptor(SparseTensorType(type), fields);
 }
 
 } // namespace sparse_tensor
