@@ -162,38 +162,6 @@ static SmallVector<Value> getDimShape(OpBuilder &builder, Location loc,
162
162
return out;
163
163
}
164
164
165
- // / Populates the given sizes array for concatenation from type (for static
166
- // / sizes) and from an already-converted opaque pointer source (for dynamic
167
- // / sizes).
168
- static void concatDimSizesFromInputs (OpBuilder &builder, Location loc,
169
- SparseTensorType dstTp, ValueRange srcs,
170
- Dimension dim,
171
- SmallVectorImpl<Value> &dimSizes) {
172
- assert (dim < dstTp.getDimRank () && " Dimension is out of bounds" );
173
- dimSizes.clear ();
174
-
175
- // We first fills the sizes from an input tensor, and then
176
- // compute the size of the concatenation dimension if necessary.
177
- const auto srcTp = getSparseTensorType (srcs[0 ]);
178
- if (srcTp.hasEncoding ())
179
- // Reuses sizes from an arbitrary input tensor is fine.
180
- fillDimSizes (builder, loc, srcTp, srcs[0 ], dimSizes);
181
- else
182
- sizesFromSrc (builder, dimSizes, loc, srcs[0 ]);
183
-
184
- if (const auto sz = dstTp.getStaticDimSize (dim)) {
185
- // Faithfully take the static size.
186
- dimSizes[dim] = constantIndex (builder, loc, *sz);
187
- } else {
188
- // Else, dynamically compute the size.
189
- for (const auto src : srcs.drop_front ()) {
190
- const auto srcTp = getSparseTensorType (src);
191
- Value srcSz = createOrFoldDimCall (builder, loc, srcTp, src, dim);
192
- dimSizes[dim] = builder.create <arith::AddIOp>(loc, dimSizes[dim], srcSz);
193
- }
194
- }
195
- }
196
-
197
165
// / Generates an uninitialized buffer of the given size and type,
198
166
// / but returns it as type `memref<? x $tp>` (rather than as type
199
167
// / `memref<$sz x $tp>`). Unlike temporary buffers on the stack,
@@ -467,107 +435,6 @@ static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) {
467
435
return true ;
468
436
}
469
437
470
- // Generates a while loop that iterates over the COO list extracted
471
- // from `t`, using `bodyBuilder` to build the loop body.
472
- // while (elem = coo->getNext()) {
473
- // bodyBuilder
474
- // }
475
- // TODO: It can be used by other operators (ReshapeOp, ConvertOP) conversion to
476
- // reduce code repetition!
477
- // TODO: rename to `genSparseIterationLoop`?
478
- static void genSparseCOOIterationLoop (
479
- ConversionPatternRewriter &rewriter, Location loc, Value t,
480
- SparseTensorType stt,
481
- function_ref<void (OpBuilder &, Location, Value, Value)> bodyBuilder) {
482
- assert (stt.hasEncoding () &&
483
- " Generating Sparse Tensor COO Loop on a Dense Tensor!" );
484
- const Dimension dimRank = stt.getDimRank ();
485
- const Type elemTp = stt.getElementType ();
486
-
487
- // Start an iterator over the tensor (in coordinate order).
488
- const auto noPerm = stt.withoutDimToLvl ();
489
- SmallVector<Value> dimSizes = getDimSizes (rewriter, loc, noPerm, t);
490
- Value iter = NewCallParams (rewriter, loc)
491
- .genBuffers (noPerm, dimSizes)
492
- .genNewCall (Action::kToIterator , t);
493
-
494
- // Construct a while loop over the iterator.
495
- const Type iTp = rewriter.getIndexType ();
496
- Value srcDimCoords = genAlloca (rewriter, loc, dimRank, iTp);
497
- Value elemPtr = genAllocaScalar (rewriter, loc, elemTp);
498
- const SmallVector<Value> noArgs;
499
- const SmallVector<Type> noTypes;
500
- auto whileOp = rewriter.create <scf::WhileOp>(loc, noTypes, noArgs);
501
- Block *before = rewriter.createBlock (&whileOp.getBefore (), {}, noTypes);
502
- rewriter.setInsertionPointToEnd (before);
503
- Value cond = genGetNextCall (rewriter, loc, iter, srcDimCoords, elemPtr);
504
- rewriter.create <scf::ConditionOp>(loc, cond, before->getArguments ());
505
- Block *after = rewriter.createBlock (&whileOp.getAfter (), {}, noTypes);
506
- rewriter.setInsertionPointToStart (after);
507
-
508
- const bool hasDenseDim =
509
- llvm::any_of (stt.getEncoding ().getLvlTypes (), isDenseDLT);
510
- if (hasDenseDim) {
511
- Value elemV = rewriter.create <memref::LoadOp>(loc, elemPtr);
512
- Value isZero = genIsNonzero (rewriter, loc, elemV);
513
- scf::IfOp ifOp = rewriter.create <scf::IfOp>(loc, isZero, /* else*/ false );
514
- rewriter.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
515
- }
516
- // Callback here to build loop body.
517
- bodyBuilder (rewriter, loc, srcDimCoords, elemPtr);
518
-
519
- // Exit the scope from the IfOp.
520
- if (hasDenseDim)
521
- rewriter.setInsertionPointToEnd (after);
522
-
523
- rewriter.create <scf::YieldOp>(loc);
524
- // Finish generating loop.
525
- rewriter.setInsertionPointAfter (whileOp);
526
-
527
- // Free memory for iterator.
528
- genDelIteratorCall (rewriter, loc, elemTp, iter);
529
- }
530
-
531
- // Generate loop that iterates over a dense tensor.
532
- // for i1 in dim1
533
- // ..
534
- // for ik in dimk
535
- // val = a[i1,..,ik]
536
- // if val != 0
537
- // bodyBuilder(v, [i1, ..., ik])
538
- // TODO: It can be used by other operators (ReshapeOp, ConvertOP) conversion to
539
- // reduce code repetition!
540
- static void genDenseTensorIterationLoop (
541
- ConversionPatternRewriter &rewriter, Location loc, Value t,
542
- SparseTensorType stt,
543
- function_ref<void (OpBuilder &, Location, ValueRange)> bodyBuilder) {
544
- assert (!stt.hasEncoding () &&
545
- " Generating Dense Tensor Loop on a Sparse Tensor!" );
546
-
547
- const Dimension dimRank = stt.getDimRank ();
548
- Value zero = constantIndex (rewriter, loc, 0 );
549
- Value one = constantIndex (rewriter, loc, 1 );
550
-
551
- SmallVector<Value> lo;
552
- SmallVector<Value> hi;
553
- SmallVector<Value> st;
554
-
555
- // Fill out loop iteration information.
556
- for (Dimension d = 0 ; d < dimRank; d++) {
557
- lo.push_back (zero);
558
- hi.push_back (linalg::createOrFoldDimOp (rewriter, loc, t, d));
559
- st.push_back (one);
560
- }
561
-
562
- scf::buildLoopNest (rewriter, loc, lo, hi, st, {},
563
- [&](OpBuilder &builder, Location loc, ValueRange ivs,
564
- ValueRange args) -> scf::ValueVector {
565
- // Invoke callback to build the body of the loop.
566
- bodyBuilder (builder, loc, ivs);
567
- return {};
568
- });
569
- }
570
-
571
438
// ===----------------------------------------------------------------------===//
572
439
// Conversion rules.
573
440
// ===----------------------------------------------------------------------===//
@@ -1198,168 +1065,6 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
1198
1065
}
1199
1066
};
1200
1067
1201
- // / Sparse conversion rule for the concatenate operator.
1202
- class SparseTensorConcatConverter : public OpConversionPattern <ConcatenateOp> {
1203
- public:
1204
- using OpConversionPattern::OpConversionPattern;
1205
- LogicalResult
1206
- matchAndRewrite (ConcatenateOp op, OpAdaptor adaptor,
1207
- ConversionPatternRewriter &rewriter) const override {
1208
- // The conversion works as follow:
1209
- // (1). When output is sparse and not all dims are dense, and mix of inputs:
1210
- // a_sparse = concat (b_dense, c_sparse, ....)
1211
- // =>
1212
- // coo_for_a = newSparseCOO(shapeOf(a))
1213
- // for i, j, k // dense input
1214
- // coo->add(adjustForOffset(i,j,k), b[i,j,k])
1215
- //
1216
- // for elem in sparse_input
1217
- // coo->add(adjustForOffset(elem.coords), elem.value)
1218
- // ...
1219
- // a = newSparseTensor(coo_for_a)
1220
- // return a
1221
- //
1222
- // (2). When output is dense or annotated all dense, and mix of inputs:
1223
- // a_dense = concat (b_dense, c_sparse, ....)
1224
- // =>
1225
- // a = malloc(shapeOf(a)) or newSparseAllDense(shapeOf(a))
1226
- // for i, j, k // dense input
1227
- // a[ adjustForOffset(i,j,k) ] = b[i,j,k]
1228
- //
1229
- // for elem in sparse_input
1230
- // a[ adjustForOffset(elem.coords) ] = elem.value
1231
- // return a
1232
- Location loc = op.getLoc ();
1233
- const auto dstTp = getSparseTensorType (op);
1234
- const auto dstEnc = dstTp.getEncoding ();
1235
- const Type elemTp = dstTp.getElementType ();
1236
- const Dimension concatDim = op.getDimension ();
1237
- const Dimension dimRank = dstTp.getDimRank ();
1238
-
1239
- Value dst; // destination tensor
1240
- Value dstDimToLvl; // destination tensor permutation (if sparse out)
1241
- // A pointer to the value being inserted (if dense => sparse)
1242
- Value elemPtr;
1243
- // Memory that holds the dim-coords for destination tensor (if sparse out)
1244
- Value dstDimCoords;
1245
- // The offset applied to the dimension to be concated (starting from 0)
1246
- Value offset = constantIndex (rewriter, loc, 0 );
1247
-
1248
- SmallVector<Value> dimSizes;
1249
- concatDimSizesFromInputs (rewriter, loc, dstTp, op.getInputs (), concatDim,
1250
- dimSizes);
1251
-
1252
- NewCallParams params (rewriter, loc);
1253
- const bool allDense = dstTp.hasEncoding () && dstTp.isAllDense ();
1254
- Value dstTensor;
1255
- if (dstTp.hasEncoding ()) {
1256
- // Start a new COO or an initialized annotated all dense sparse tensor.
1257
- dst = params.genBuffers (dstTp, dimSizes)
1258
- .genNewCall (allDense ? Action::kEmpty : Action::kEmptyCOO );
1259
- dstDimCoords = genAlloca (rewriter, loc, dimRank, rewriter.getIndexType ());
1260
- if (allDense) {
1261
- dstTensor = dst;
1262
- // Get the values buffer for the sparse tensor and reshape it to the
1263
- // corresponding dense tensor shape.
1264
- dst = genValuesCall (rewriter, loc,
1265
- MemRefType::get ({ShapedType::kDynamic }, elemTp),
1266
- {dst});
1267
- // Pass the `dstDimCoords` buffer for `reshapeValuesToLevels`
1268
- // to reuse for storing level-sizes (yes, "level-sizes").
1269
- // This is safe to do because `dstTp` is a dense-tensor type,
1270
- // and therefore lvlRank == dimRank.
1271
- dst = reshapeValuesToLevels (rewriter, loc, dstEnc, dimSizes, dst,
1272
- dstDimCoords);
1273
- } else {
1274
- dstDimToLvl = params.getDimToLvl ();
1275
- elemPtr = genAllocaScalar (rewriter, loc, elemTp);
1276
- }
1277
- } else {
1278
- // TODO: Dense buffers should be allocated/deallocated via the callback
1279
- // in BufferizationOptions.
1280
- dst = allocDenseTensor (rewriter, loc, dstTp, dimSizes);
1281
- }
1282
- const Level lvlRank = dstTp.getLvlRank ();
1283
- const auto dcvs2lcvs = [&](ValueRange dcvs) -> SmallVector<Value> {
1284
- SmallVector<Value> lcvs;
1285
- lcvs.reserve (lvlRank);
1286
- for (Level l = 0 ; l < lvlRank; l++)
1287
- // FIXME: `toOrigDim` is deprecated
1288
- lcvs.push_back (dcvs[toOrigDim (dstEnc, l)]);
1289
- return lcvs;
1290
- };
1291
- for (const auto &it : llvm::zip (op.getInputs (), adaptor.getInputs ())) {
1292
- Value orignalOp = std::get<0 >(it); // Input (with encoding) from Op
1293
- Value adaptedOp = std::get<1 >(it); // Input (type converted) from adaptor
1294
- const auto srcTp = getSparseTensorType (orignalOp);
1295
- if (srcTp.hasEncoding ()) {
1296
- genSparseCOOIterationLoop (
1297
- rewriter, loc, adaptedOp, srcTp,
1298
- [&](OpBuilder &builder, Location loc, Value dimCoords,
1299
- Value elemPtr) -> void {
1300
- const auto dcvs =
1301
- loadAll (builder, loc, dimRank, dimCoords, concatDim, offset);
1302
- if (dstTp.hasEncoding () && !allDense) {
1303
- // Case: sparse => sparse, except for annotated all dense.
1304
- storeAll (builder, loc, dstDimCoords, dcvs);
1305
- genAddEltCall (builder, loc, elemTp, dst, elemPtr, dstDimCoords,
1306
- dstDimToLvl);
1307
- } else {
1308
- // Case: sparse => dense, or annotated all dense.
1309
- const auto lcvs = allDense ? dcvs2lcvs (dcvs) : dcvs;
1310
- insertScalarIntoDenseTensor (builder, loc, elemPtr, dst, lcvs);
1311
- }
1312
- });
1313
- } else {
1314
- genDenseTensorIterationLoop (
1315
- rewriter, loc, adaptedOp, srcTp,
1316
- [&](OpBuilder &builder, Location loc, ValueRange dcvs) -> void {
1317
- if (dstTp.hasEncoding () && !allDense) {
1318
- // Case: dense => sparse, except for annotated all dense.
1319
- assert (dcvs.size () == static_cast <size_t >(dimRank));
1320
- storeAll (builder, loc, dstDimCoords, dcvs, concatDim, offset);
1321
- Value val = genValueForDense (builder, loc, adaptedOp, dcvs);
1322
- builder.create <memref::StoreOp>(loc, val, elemPtr);
1323
- genAddEltCall (builder, loc, elemTp, dst, elemPtr, dstDimCoords,
1324
- dstDimToLvl);
1325
- } else {
1326
- // Case: dense => dense, or annotated all dense.
1327
- Value val = genValueForDense (builder, loc, adaptedOp, dcvs);
1328
- // Despite the name, this isn't actually level-cvs until
1329
- // after the `dcvs2lcvs` call.
1330
- SmallVector<Value> lcvs (dcvs);
1331
- // Apply offset.
1332
- lcvs[concatDim] =
1333
- builder.create <arith::AddIOp>(loc, lcvs[concatDim], offset);
1334
- if (allDense)
1335
- lcvs = dcvs2lcvs (lcvs);
1336
- builder.create <memref::StoreOp>(loc, val, dst, lcvs);
1337
- }
1338
- });
1339
- }
1340
- // Accumulate offset.
1341
- // TODO: avoid calling sparseDimSize multiple times by caching the result!
1342
- Value curDim =
1343
- createOrFoldDimCall (rewriter, loc, srcTp, adaptedOp, concatDim);
1344
- offset = rewriter.create <arith::AddIOp>(loc, offset, curDim);
1345
- }
1346
- if (!dstTp.hasEncoding ()) {
1347
- rewriter.replaceOpWithNewOp <bufferization::ToTensorOp>(
1348
- op, dstTp.getRankedTensorType (), dst);
1349
- } else if (allDense) {
1350
- rewriter.replaceOp (op, dstTensor);
1351
- } else {
1352
- // In sparse output case, the destination holds the COO.
1353
- Value coo = dst;
1354
- dst = params.genNewCall (Action::kFromCOO , coo);
1355
- // Release resources.
1356
- genDelCOOCall (rewriter, loc, elemTp, coo);
1357
- rewriter.replaceOp (op, dst);
1358
- }
1359
- return success ();
1360
- }
1361
- };
1362
-
1363
1068
// / Sparse conversion rule for the output operator.
1364
1069
class SparseTensorOutConverter : public OpConversionPattern <OutOp> {
1365
1070
public:
@@ -1434,17 +1139,16 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
1434
1139
void mlir::populateSparseTensorConversionPatterns (
1435
1140
TypeConverter &typeConverter, RewritePatternSet &patterns,
1436
1141
const SparseTensorConversionOptions &options) {
1437
- patterns
1438
- .add <SparseReturnConverter, SparseTensorToDimSizeConverter,
1439
- SparseCastConverter, SparseTensorNewConverter,
1440
- SparseTensorConcatConverter, SparseTensorAllocConverter,
1441
- SparseTensorEmptyConverter, SparseTensorDeallocConverter,
1442
- SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
1443
- SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
1444
- SparseTensorLoadConverter, SparseTensorInsertConverter,
1445
- SparseTensorExpandConverter, SparseTensorCompressConverter,
1446
- SparseTensorOutConverter, SparseTensorAssembleConverter>(
1447
- typeConverter, patterns.getContext ());
1142
+ patterns.add <SparseReturnConverter, SparseTensorToDimSizeConverter,
1143
+ SparseCastConverter, SparseTensorNewConverter,
1144
+ SparseTensorAllocConverter, SparseTensorEmptyConverter,
1145
+ SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
1146
+ SparseTensorToCoordinatesConverter,
1147
+ SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
1148
+ SparseTensorLoadConverter, SparseTensorInsertConverter,
1149
+ SparseTensorExpandConverter, SparseTensorCompressConverter,
1150
+ SparseTensorOutConverter, SparseTensorAssembleConverter>(
1151
+ typeConverter, patterns.getContext ());
1448
1152
patterns.add <SparseTensorConvertConverter>(typeConverter,
1449
1153
patterns.getContext (), options);
1450
1154
}
0 commit comments