@@ -646,10 +646,11 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
646
646
matchAndRewrite (LvlOp op, OpAdaptor adaptor,
647
647
ConversionPatternRewriter &rewriter) const override {
648
648
std::optional<int64_t > lvl = op.getConstantLvlIndex ();
649
- if (!lvl || !getSparseTensorEncoding (adaptor.getSource ().getType ()))
649
+ RankedTensorType srcType = op.getSource ().getType ();
650
+ if (!lvl || !getSparseTensorEncoding (srcType))
650
651
return failure ();
651
652
652
- auto desc = getDescriptorFromTensorTuple (adaptor.getSource ());
653
+ auto desc = getDescriptorFromTensorTuple (adaptor.getSource (), srcType );
653
654
auto sz = desc.getLvlSize (rewriter, op.getLoc (), *lvl);
654
655
655
656
rewriter.replaceOp (op, sz);
@@ -675,8 +676,9 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
675
676
assert (dstStt.hasSameDimToLvl (srcStt));
676
677
677
678
// We don't need a mutable descriptor here as we perform sorting in-place.
678
- auto nnz = genValMemSize (rewriter, op.getLoc (), adaptor.getInputCoo ());
679
- auto desc = getDescriptorFromTensorTuple (adaptor.getInputCoo ());
679
+ auto desc = getDescriptorFromTensorTuple (adaptor.getInputCoo (),
680
+ op.getInputCoo ().getType ());
681
+ auto nnz = desc.getValMemSize (rewriter, op.getLoc ());
680
682
auto crd = desc.getAOSMemRef ();
681
683
auto val = desc.getValMemRef ();
682
684
@@ -704,7 +706,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
704
706
matchAndRewrite (Op op, typename Op::Adaptor adaptor,
705
707
ConversionPatternRewriter &rewriter) const override {
706
708
// Simply lowers to specifer.get <field> operation.
707
- auto desc = getDescriptorFromTensorTuple (adaptor.getSlice ());
709
+ auto desc = getDescriptorFromTensorTuple (adaptor.getSlice (),
710
+ op.getSlice ().getType ());
708
711
auto v = desc.getSpecifierField (rewriter, op.getLoc (), kind,
709
712
op.getDim ().getZExtValue ());
710
713
@@ -762,7 +765,8 @@ class SparseTensorAllocConverter
762
765
Location loc = op.getLoc ();
763
766
// Deal with copy.
764
767
if (op.getCopy ()) {
765
- auto desc = getDescriptorFromTensorTuple (adaptor.getCopy ());
768
+ auto desc = getDescriptorFromTensorTuple (
769
+ adaptor.getCopy (), cast<RankedTensorType>(op.getCopy ().getType ()));
766
770
SmallVector<Value> fields;
767
771
fields.reserve (desc.getNumFields ());
768
772
// Memcpy on memref fields.
@@ -868,7 +872,9 @@ class SparseTensorDeallocConverter
868
872
if (createDeallocs) {
869
873
// Replace the sparse tensor deallocation with field deallocations.
870
874
Location loc = op.getLoc ();
871
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
875
+ auto desc = getDescriptorFromTensorTuple (
876
+ adaptor.getTensor (),
877
+ cast<RankedTensorType>(op.getTensor ().getType ()));
872
878
for (auto input : desc.getMemRefFields ())
873
879
// Deallocate every buffer used to store the sparse tensor handler.
874
880
rewriter.create <memref::DeallocOp>(loc, input);
@@ -889,7 +895,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
889
895
matchAndRewrite (LoadOp op, OpAdaptor adaptor,
890
896
ConversionPatternRewriter &rewriter) const override {
891
897
// Prepare descriptor.
892
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
898
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
899
+ op.getTensor ().getType ());
893
900
// Generate optional insertion finalization code.
894
901
if (op.getHasInserts ())
895
902
genEndInsert (rewriter, op.getLoc (), desc);
@@ -909,7 +916,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
909
916
if (!getSparseTensorEncoding (op.getTensor ().getType ()))
910
917
return failure ();
911
918
Location loc = op->getLoc ();
912
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
919
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
920
+ op.getTensor ().getType ());
913
921
const auto srcType = getSparseTensorType (op.getTensor ());
914
922
Type eltType = srcType.getElementType ();
915
923
Type boolType = rewriter.getIntegerType (1 );
@@ -959,7 +967,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
959
967
ConversionPatternRewriter &rewriter) const override {
960
968
Location loc = op->getLoc ();
961
969
SmallVector<Value> fields;
962
- auto desc = getMutDescriptorFromTensorTuple (adaptor.getTensor (), fields);
970
+ auto desc = getMutDescriptorFromTensorTuple (adaptor.getTensor (), fields,
971
+ op.getTensor ().getType ());
963
972
Value values = adaptor.getValues ();
964
973
Value filled = adaptor.getFilled ();
965
974
Value added = adaptor.getAdded ();
@@ -1032,7 +1041,8 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1032
1041
assert (stt.isIdentity () && " Run reinterpret-map before conversion." );
1033
1042
1034
1043
Location loc = op.getLoc ();
1035
- auto desc = getDescriptorFromTensorTuple (adaptor.getDest ());
1044
+ auto desc =
1045
+ getDescriptorFromTensorTuple (adaptor.getDest (), op.getDest ().getType ());
1036
1046
TypeRange flatSpTensorTps = desc.getFields ().getTypes ();
1037
1047
SmallVector<Value> params = llvm::to_vector (desc.getFields ());
1038
1048
params.append (adaptor.getIndices ().begin (), adaptor.getIndices ().end ());
@@ -1059,7 +1069,8 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1059
1069
// of this operation truly observe size, not capacity!
1060
1070
Location loc = op.getLoc ();
1061
1071
Level lvl = op.getLevel ();
1062
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1072
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1073
+ op.getTensor ().getType ());
1063
1074
auto mem = desc.getPosMemRef (lvl);
1064
1075
auto size = desc.getPosMemSize (rewriter, loc, lvl);
1065
1076
rewriter.replaceOp (op, genSliceToSize (rewriter, loc, mem, size));
@@ -1081,7 +1092,8 @@ class SparseToCoordinatesConverter
1081
1092
// of this operation truly observe size, not capacity!
1082
1093
Location loc = op.getLoc ();
1083
1094
Level lvl = op.getLevel ();
1084
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1095
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1096
+ op.getTensor ().getType ());
1085
1097
auto mem = desc.getCrdMemRefOrView (rewriter, loc, lvl);
1086
1098
if (lvl < getSparseTensorType (op.getTensor ()).getAoSCOOStart ()) {
1087
1099
auto size = desc.getCrdMemSize (rewriter, loc, lvl);
@@ -1106,7 +1118,8 @@ class SparseToCoordinatesBufferConverter
1106
1118
// of this operation truly observe size, not capacity!
1107
1119
Location loc = op.getLoc ();
1108
1120
Level lvl = getSparseTensorType (op.getTensor ()).getAoSCOOStart ();
1109
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1121
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1122
+ op.getTensor ().getType ());
1110
1123
auto mem = desc.getAOSMemRef ();
1111
1124
auto size = desc.getCrdMemSize (rewriter, loc, lvl);
1112
1125
rewriter.replaceOp (op, genSliceToSize (rewriter, loc, mem, size));
@@ -1126,7 +1139,8 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1126
1139
// The view is restricted to the actual size to ensure clients
1127
1140
// of this operation truly observe size, not capacity!
1128
1141
Location loc = op.getLoc ();
1129
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1142
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1143
+ op.getTensor ().getType ());
1130
1144
auto mem = desc.getValMemRef ();
1131
1145
auto size = desc.getValMemSize (rewriter, loc);
1132
1146
rewriter.replaceOp (op, genSliceToSize (rewriter, loc, mem, size));
@@ -1172,7 +1186,8 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1172
1186
// else:
1173
1187
// dst = memref.copy(src)
1174
1188
Location loc = op.getLoc ();
1175
- auto srcDesc = getDescriptorFromTensorTuple (adaptor.getSource ());
1189
+ auto srcDesc = getDescriptorFromTensorTuple (adaptor.getSource (),
1190
+ op.getSource ().getType ());
1176
1191
SmallVector<Value> fields;
1177
1192
foreachFieldAndTypeInSparseTensor (
1178
1193
SparseTensorType (cast<RankedTensorType>(op.getResult ().getType ())),
@@ -1236,7 +1251,8 @@ class SparseExtractSliceConverter
1236
1251
assert (srcEnc.withoutDimSlices () == dstEnc.withoutDimSlices ());
1237
1252
1238
1253
SmallVector<Value> fields;
1239
- auto desc = getMutDescriptorFromTensorTuple (adaptor.getSource (), fields);
1254
+ auto desc = getMutDescriptorFromTensorTuple (adaptor.getSource (), fields,
1255
+ op.getSource ().getType ());
1240
1256
1241
1257
auto newSpec = rewriter.create <StorageSpecifierInitOp>(
1242
1258
loc, StorageSpecifierType::get (ctx, dstEnc), desc.getSpecifier ());
@@ -1285,8 +1301,9 @@ class SparseNumberOfEntriesConverter
1285
1301
// Query memSizes for the actually stored values.
1286
1302
// FIXME: the nse value computed in this way might be wrong when there is
1287
1303
// any "loose_compressed" level.
1288
- rewriter.replaceOp (
1289
- op, genValMemSize (rewriter, op.getLoc (), adaptor.getTensor ()));
1304
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1305
+ op.getTensor ().getType ());
1306
+ rewriter.replaceOp (op, desc.getValMemSize (rewriter, op.getLoc ()));
1290
1307
return success ();
1291
1308
}
1292
1309
};
@@ -1415,7 +1432,8 @@ struct SparseDisassembleOpConverter
1415
1432
LogicalResult
1416
1433
matchAndRewrite (DisassembleOp op, OpAdaptor adaptor,
1417
1434
ConversionPatternRewriter &rewriter) const override {
1418
- auto desc = getDescriptorFromTensorTuple (adaptor.getTensor ());
1435
+ auto desc = getDescriptorFromTensorTuple (adaptor.getTensor (),
1436
+ op.getTensor ().getType ());
1419
1437
Location loc = op.getLoc ();
1420
1438
SmallVector<Value> retMem;
1421
1439
SmallVector<Value> retLen;
0 commit comments