@@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//

- /// External model implementation of PartialReductionInterface for LinalgOps.
+ /// Return an AffineMap for a partial result for the given result number,
+ /// assuming the partial tiling strategy is outer-reduction loop +
+ /// inner-parallel tile. The returned AffineMap can be used as the
+ /// replacement AffineMap for the inner-parallel tile linalg op for the
+ /// given result number.
+ ///
+ /// The new AffineMap is the old AffineMap with the reduction dimensions
+ /// appended at the end.
+ static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+                                            ArrayRef<int> reductionDims,
+                                            unsigned resultNumber) {
+   AffineMap map =
+       linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+   for (int redPos : reductionDims) {
+     map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                            map.getNumResults());
+   }
+   return map;
+ }
+
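
For intuition, here is what the helper computes for a plain matmul (a
hypothetical sketch, not part of the patch; `matmulOp` stands for a
`linalg.matmul` whose init map is `(d0, d1, d2) -> (d0, d1)` and whose
reduction dimension is `d2`):

    // The reduction dim d2 is appended as a trailing result, so the partial
    // result map becomes (d0, d1, d2) -> (d0, d1, d2).
    SmallVector<int> reductionDims = {2};
    AffineMap partialMap =
        getPartialResultAffineMap(matmulOp, reductionDims, /*resultNumber=*/0);
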
+ /// External model implementation of PartialReductionInterface for
+ /// LinalgOps.
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
    : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
    if (linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have tensor semantics");

+     // LinalgOp implements TilingInterface.
+     auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+     SmallVector<OpFoldResult> shape =
+         llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+                             [](Range x) { return x.size; });
+
+     SmallVector<OpFoldResult> tiledShape;
+     for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+       if (isZeroIndex(tileSize)) {
+         tiledShape.push_back(dimSize);
+       } else {
+         tiledShape.push_back(tileSize);
+       }
+     }
+
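A tile size of zero is the tiling infrastructure's convention for "leave this
dimension untiled", so `tiledShape` keeps the full iteration-domain extent in
that case. A minimal sketch under assumed sizes:

    // With an iteration domain of {128, 256, 512} and tile sizes {0, 0, 64},
    // tiledShape becomes {128, 256, 64}: d0 and d1 stay whole, d2 is tiled.
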
    SmallVector<Value> inits;
    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
         ++initIdx) {
-       // Insert the new parallel dimension based on the index of the reduction
-       // loops. This could be controlled by user for more flexibility.
      SmallVector<Operation *, 4> combinerOps;
      if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                          combinerOps) ||
@@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
        return op->emitOpError(
            "Failed to get an identity value for the reduction operation.");

-       ArrayRef<int64_t> oldShape =
-           linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
-       // Calculate the new shape, we insert the new dimensions based on the
-       // index of the reduction dimensions.
-       SmallVector<int64_t> newOutputShape;
-       SmallVector<Value> dynamicDims;
-       int64_t currReductionDims = 0;
-       DenseSet<int> reductionDimsSet(reductionDims.begin(),
-                                      reductionDims.end());
-       for (int64_t idx :
-            llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-         if (reductionDimsSet.contains(idx)) {
-           dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-           currReductionDims++;
-           continue;
-         }
-         int64_t oldIdx = idx - currReductionDims;
-         int64_t dim = oldShape[oldIdx];
-         newOutputShape.push_back(dim);
-         if (ShapedType::isDynamic(dim))
-           dynamicDims.push_back(b.create<tensor::DimOp>(
-               loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+       // Append the new partial result dimensions.
+       AffineMap partialMap =
+           getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+       SmallVector<OpFoldResult> partialResultShape;
+       for (AffineExpr dimExpr : partialMap.getResults()) {
+         auto dim = cast<AffineDimExpr>(dimExpr);
+         partialResultShape.push_back(tiledShape[dim.getPosition()]);
      }
-       Value emptyTensor = b.create<tensor::EmptyOp>(
-           loc, newOutputShape,
-           linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+       Type elType =
+           getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+       Value emptyTensor =
+           b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
      Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
      auto identityTensor =
          b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
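
Concretely, a sketch under assumed shapes (hypothetical values: an f32 matmul
with a 128x256 result and reduction tile size 64, so `partialResultShape` is
{128, 256, 64}) — the calls above materialize the equivalent of:

    Value empty = b.create<tensor::EmptyOp>(
        loc,
        ArrayRef<OpFoldResult>{b.getIndexAttr(128), b.getIndexAttr(256),
                               b.getIndexAttr(64)},
        b.getF32Type());
    Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0f));
    Value init = b.create<linalg::FillOp>(loc, zero, empty)->getResult(0);
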
@@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
      // this with a for range loop when we have it.
      AffineMap newMap =
-           linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-       for (int redPos : reductionDims) {
-         newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
-                                      newMap.getNumResults());
-       }
+           getPartialResultAffineMap(linalgOp, reductionDims, idx);
      newInitMaps.push_back(newMap);
    }
@@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
                        Location loc, ValueRange partialReduce,
                        ArrayRef<int> reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);
-     SmallVector<int64_t> reductionDimsInt64(reductionDims);
-     auto reduction = b.create<linalg::ReduceOp>(
-         loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
-         [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
-           int64_t numInits = linalgOp.getNumDpsInits();
-           SmallVector<Value> yieldedValues;
-           for (int idx : llvm::seq<int>(0, numInits)) {
+
+     // Permute the reduction dims according to the partial result map.
+
+     int64_t numInits = linalgOp.getNumDpsInits();
+     SmallVector<Operation *> mergeOperations;
+     SmallVector<Value> replacements;
+     for (int idx : llvm::seq(numInits)) {
+       // linalg.reduce's iteration space is the tiled result's iteration
+       // space (and not the tiled operation's iteration space). To account
+       // for this, permute the reduction dimensions based on the partial
+       // result map of the tiled result.
+       AffineMap partialMap =
+           getPartialResultAffineMap(linalgOp, reductionDims, idx);
+       SmallVector<int64_t> partialReductionDims;
+       for (auto [resultNum, dimExpr] :
+            llvm::enumerate(partialMap.getResults())) {
+         unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+         if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+           partialReductionDims.push_back(resultNum);
+         }
+       }
+
+       Value partialResult = partialReduce[idx];
+       Value init = linalgOp.getDpsInits()[idx];
+
+       auto reduction = b.create<linalg::ReduceOp>(
+           loc, partialResult, init, partialReductionDims,
+           [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Get the combiner op.
            SmallVector<Operation *, 4> combinerOps;
            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
            // Combine the input at idx and output at numInits + idx.
-             clonedReductionOp->setOperand(0, inputs[idx]);
-             clonedReductionOp->setOperand(1, inputs[numInits + idx]);
-             // Yield.
-             yieldedValues.push_back(clonedReductionOp->getResult(0));
-           }
-           b.create<linalg::YieldOp>(loc, yieldedValues);
-         });
-     return MergeResult{
-         {reduction.getOperation()},
-         llvm::map_to_vector(reduction->getResults(),
-                             [](OpResult r) -> Value { return r; })};
+             clonedReductionOp->setOperand(0, inputs[0]);
+             clonedReductionOp->setOperand(1, inputs[1]);
+             b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+           });
+
+       mergeOperations.push_back(reduction);
+       replacements.push_back(reduction->getResult(0));
+     }
+
+     return MergeResult{mergeOperations, replacements};
+   }
+
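The remapping matters whenever a reduction loop's index and its position in
the partial result differ. A hedged illustration (not taken from the patch):
for a column-sum-style op with init map `(d0, d1) -> (d1)` and reduction dim
`d0`, the partial map is `(d0, d1) -> (d1, d0)`, so `d0` lives at result
position 1 of the partial tensor:

    // Hypothetical trace:
    //   partialMap results: [d1, d0]; reductionDims = {0}
    //   only result #1 refers to a reduction dim
    //   => partialReductionDims == {1}; the linalg.reduce reduces tensor
    //      dimension 1 of the partial result, not dimension 0.
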
+   LogicalResult getPartialResultTilePosition(
+       Operation *op, OpBuilder &b, unsigned resultNumber,
+       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+       SmallVector<OpFoldResult> &resultOffsets,
+       SmallVector<OpFoldResult> &resultSizes,
+       ArrayRef<int> reductionDims) const {
+     auto linalgOp = cast<LinalgOp>(op);
+
+     AffineMap partialMap =
+         getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+     for (AffineExpr dimExpr : partialMap.getResults()) {
+       unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+       resultSizes.push_back(sizes[dim]);
+
+       if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+         // Reduction dims are reduced, and are always outputted in the same
+         // place. So use offset 0 for them.
+         resultOffsets.push_back(b.getIndexAttr(0));
+       } else {
+         resultOffsets.push_back(offsets[dim]);
+       }
+     }
+
+     return success();
  }
};
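
A usage sketch with hypothetical values: for the matmul example with tile
offsets `{%i, %j, %k}`, tile sizes `{4, 8, 64}`, and `reductionDims = {2}`,
the loop above produces

    // resultSizes   == {4, 8, 64}
    // resultOffsets == {%i, %j, 0}
    // The reduction dim is pinned to offset 0: every outer reduction
    // iteration writes the whole trailing tile of the partial result
    // in place.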