Commit 91bbebc

[mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface (#120465)
This PR adds a new interface method to PartialReductionOpInterface that lets callers query the position of the partial result tile produced by the tiled operation. Previously, tiling the reduction dimension with SplitReductionOuterReduction produced wrong results when the result has transposed parallel dimensions.

Other fixes that were needed to make this PR work:

- Instead of ad-hoc logic that decided where to place the new reduction dimensions in the partial result based on the iteration space, the reduction dimensions are now always appended to the partial result tensor.
- Remove the usage of PartialReductionOpInterface in the Mesh dialect. That code only needed a neutral element, but went through PartialReductionOpInterface to get it, which is not the right tool; it was also passing the wrong sizes to it.
1 parent 07ba457 commit 91bbebc
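For illustration, here is a minimal, self-contained sketch of the convention this change adopts (not part of the commit, and deliberately using plain C++ containers instead of MLIR's AffineMap/OpFoldResult types): the partial result keeps the op's result dimensions, possibly permuted, and the tiled reduction dimensions are appended at the end. The indexing map, reduction dimension, and tile sizes below are made-up example values.

```cpp
// Toy model of the "append reduction dims" convention described above.
// Example: an op with iteration space (d0, d1, d2), where d2 is a reduction
// dim and the result has transposed parallel dims, i.e. output map (d1, d0).
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> resultMap = {1, 0};     // result dim -> iteration dim
  std::vector<int> reductionDims = {2};    // tiled reduction dims
  std::vector<int> tileSizes = {4, 8, 16}; // one tile size per iteration dim

  // Partial result map = original result map + appended reduction dims.
  std::vector<int> partialMap = resultMap;
  partialMap.insert(partialMap.end(), reductionDims.begin(),
                    reductionDims.end());

  // The partial result tile has one dimension per entry of the partial map.
  std::printf("partial result map: (");
  for (int d : partialMap)
    std::printf(" d%d", d);
  std::printf(" )\npartial result tile shape:");
  for (int d : partialMap)
    std::printf(" %d", tileSizes[d]);
  std::printf("\n"); // -> map ( d1 d0 d2 ), shape 8 4 16
  return 0;
}
```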

File tree

5 files changed: +225 -91 lines changed


mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 22 additions & 0 deletions
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the partial result tile computed by
+          the tiled operation. This is same as
+          TilingInterface:::getResultTilePosition, but determines the result
+          tile position for partial reduction.
+        }],
+        /*retType=*/"::llvm::LogicalResult",
+        /*methodName=*/"getPartialResultTilePosition",
+        /*args=*/(ins
+            "::mlir::OpBuilder &":$b,
+            "unsigned":$resultNumber,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
+            "::mlir::ArrayRef<int>":$reductionDims),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
       >
   ];
 }
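For readers less familiar with ODS, the InterfaceMethod declaration above roughly corresponds to a C++ hook of the following shape, with a default implementation that fails unless an op or external model overrides it. This is only a hand-written sketch; the real code is generated by TableGen through the interface/external-model machinery and does not use a virtual base class.

```cpp
// Hand-written sketch of the C++ surface implied by the ODS method above;
// not the TableGen-generated code.
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"

namespace sketch {

// Mirrors getPartialResultTilePosition: given the tile offsets/sizes of the
// tiled iteration space and the tiled reduction dimensions, report where the
// partial result tile lives inside the partial result tensor.
struct PartialResultTilePositionHook {
  virtual mlir::LogicalResult getPartialResultTilePosition(
      mlir::OpBuilder &b, unsigned resultNumber,
      mlir::ArrayRef<mlir::OpFoldResult> offsets,
      mlir::ArrayRef<mlir::OpFoldResult> sizes,
      mlir::SmallVector<mlir::OpFoldResult> &resultOffsets,
      mlir::SmallVector<mlir::OpFoldResult> &resultSizes,
      mlir::ArrayRef<int> reductionDims) {
    return mlir::failure(); // default implementation, as in the .td file
  }
  virtual ~PartialResultTilePositionHook() = default;
};

} // namespace sketch
```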

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 21 additions & 13 deletions
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
                       ArrayRef<MeshSharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
-  for (const MeshSharding& sharding : operandShardings) {
+  for (const MeshSharding &sharding : operandShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }
 
-  for (const MeshSharding& sharding : resultShardings) {
+  for (const MeshSharding &sharding : resultShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
 // the original operand.
 // The other processes would use the reduction operation neutral tensor.
 static Value createDestinationPassingStyleInitOperand(
-    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
-    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+    LinalgOp op, int operandNumber, Value spmdizedOperand,
+    ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+    ImplicitLocOpBuilder &builder) {
   Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
       meshOp.getSymName(), reductionMeshAxes, builder);
   Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
     SmallVector<OpFoldResult> shape =
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
-    PartialReductionOpInterface partialReductionIface =
-        llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    assert(op->getNumResults() == 1 && "Multiple results not supported.");
-    FailureOr<SmallVector<Value>> reductionNeutralTensor =
-        partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensor));
-    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
+
+    SmallVector<Operation *> combinerOps;
+    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
+    assert(combinerOps.size() == 1);
+    std::optional<TypedAttr> neutralEl =
+        arith::getNeutralElement(combinerOps[0]);
+
+    Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
+                                                 neutralEl.value().getType());
+    Value constant =
+        builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
+    Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
+                     .getResult(0);
+
+    builder.create<scf::YieldOp>(fill);
   }
   return ifOp.getResult(0);
 }
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
   Value spmdizedInitOperand =
       spmdizationMap.lookup(op->getOperands()[operandIdx]);
   newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
-      op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+      op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
   return newOperands;
 }

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 113 additions & 52 deletions
@@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// External model implementation of PartialReductionInterface for LinalgOps.
+/// Return an AffineMap for a partial result for the given result number,
+/// assuming the partial tiling strategy is outer-reduction loop +
+/// inner-parallel tile. The returned AffineMap can be used as the replacement
+/// AffineMap for the inner-parallel tile linalg op for the given result number.
+///
+/// The new AffineMap is the old AffineMap with reduction dimensions appended
+/// at end.
+static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+                                           ArrayRef<int> reductionDims,
+                                           unsigned resultNumber) {
+  AffineMap map =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+  for (int redPos : reductionDims) {
+    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                           map.getNumResults());
+  }
+  return map;
+}
+
+/// External model implementation of PartialReductionInterface for
+/// LinalgOps.
 template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    // LinalgOp implements TilingInterface.
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    SmallVector<OpFoldResult> shape =
+        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+                            [](Range x) { return x.size; });
+
+    SmallVector<OpFoldResult> tiledShape;
+    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+      if (isZeroIndex(tileSize)) {
+        tiledShape.push_back(dimSize);
+      } else {
+        tiledShape.push_back(tileSize);
+      }
+    }
+
     SmallVector<Value> inits;
     for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
          ++initIdx) {
-      // Insert the new parallel dimension based on the index of the reduction
-      // loops. This could be controlled by user for more flexibility.
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
         return op->emitOpError(
             "Failed to get an identity value for the reduction operation.");
 
-      ArrayRef<int64_t> oldShape =
-          linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
-      // Calculate the new shape, we insert the new dimensions based on the
-      // index of the reduction dimensions.
-      SmallVector<int64_t> newOutputShape;
-      SmallVector<Value> dynamicDims;
-      int64_t currReductionDims = 0;
-      DenseSet<int> reductionDimsSet(reductionDims.begin(),
-                                     reductionDims.end());
-      for (int64_t idx :
-           llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-        if (reductionDimsSet.contains(idx)) {
-          dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-          currReductionDims++;
-          continue;
-        }
-        int64_t oldIdx = idx - currReductionDims;
-        int64_t dim = oldShape[oldIdx];
-        newOutputShape.push_back(dim);
-        if (ShapedType::isDynamic(dim))
-          dynamicDims.push_back(b.create<tensor::DimOp>(
-              loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+      // Append the new partial result dimensions.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+      SmallVector<OpFoldResult> partialResultShape;
+      for (AffineExpr dimExpr : partialMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
-      Value emptyTensor = b.create<tensor::EmptyOp>(
-          loc, newOutputShape,
-          linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+      Type elType =
+          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Value emptyTensor =
+          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
       auto identityTensor =
           b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
       // this with a for range loop when we have it.
       AffineMap newMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      for (int redPos : reductionDims) {
-        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
-                                     newMap.getNumResults());
-      }
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       newInitMaps.push_back(newMap);
     }
 
@@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
                       Location loc, ValueRange partialReduce,
                       ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    SmallVector<int64_t> reductionDimsInt64(reductionDims);
-    auto reduction = b.create<linalg::ReduceOp>(
-        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
-        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          int64_t numInits = linalgOp.getNumDpsInits();
-          SmallVector<Value> yieldedValues;
-          for (int idx : llvm::seq<int>(0, numInits)) {
+
+    // Permute the reduction dims as permuted by the partial result map.
+
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<Operation *> mergeOperations;
+    SmallVector<Value> replacements;
+    for (int idx : llvm::seq(numInits)) {
+      // linalg.reduce's iteration space is the tiled result's iteration space
+      // (and not the tiled operation's iteration space). To account for this,
+      // permute the reduction dimensions based on the partial result map of
+      // the tiled result.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
+      SmallVector<int64_t> partialReductionDims;
+      for (auto [resultNum, dimExpr] :
+           llvm::enumerate(partialMap.getResults())) {
+        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+        if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+          partialReductionDims.push_back(resultNum);
+        }
+      }
+
+      Value partialResult = partialReduce[idx];
+      Value init = linalgOp.getDpsInits()[idx];
+
+      auto reduction = b.create<linalg::ReduceOp>(
+          loc, partialResult, init, partialReductionDims,
+          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
             // Get the combiner op.
             SmallVector<Operation *, 4> combinerOps;
             matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
             Operation *clonedReductionOp = b.clone(*combinerOps[0]);
             // Combine the input at idx and output at numInits + idx.
-            clonedReductionOp->setOperand(0, inputs[idx]);
-            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
-            // Yield.
-            yieldedValues.push_back(clonedReductionOp->getResult(0));
-          }
-          b.create<linalg::YieldOp>(loc, yieldedValues);
-        });
-    return MergeResult{
-        {reduction.getOperation()},
-        llvm::map_to_vector(reduction->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+            clonedReductionOp->setOperand(0, inputs[0]);
+            clonedReductionOp->setOperand(1, inputs[1]);
+            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+          });
+
+      mergeOperations.push_back(reduction);
+      replacements.push_back(reduction->getResult(0));
+    }
+
+    return MergeResult{mergeOperations, replacements};
+  }
+
+  LogicalResult getPartialResultTilePosition(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &resultOffsets,
+      SmallVector<OpFoldResult> &resultSizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    AffineMap partialMap =
+        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+    for (AffineExpr dimExpr : partialMap.getResults()) {
+      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+      resultSizes.push_back(sizes[dim]);
+
+      if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+        // Reduction dims are reduced, and are always outputed in the same
+        // place. So use offset 0 for them.
+        resultOffsets.push_back(b.getIndexAttr(0));
+      } else {
+        resultOffsets.push_back(offsets[dim]);
+      }
+    }
+
+    return success();
   }
 };
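To make the effect of the new getPartialResultTilePosition concrete, here is a small standalone model of the rule it implements (toy C++, not the MLIR code above; the map, offsets, and sizes are invented example values): sizes are read per mapped iteration dimension, parallel dimensions keep their tile offsets, and the appended reduction dimensions always start at offset 0 because every outer reduction iteration writes its partial values to the same place.

```cpp
// Toy model of the partial-result tile position rule from the external model
// above, for an op with output map (d1, d0) and reduction dim d2, so the
// partial result map is (d1, d0, d2). Offsets/sizes are made-up values.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> partialMap = {1, 0, 2}; // partial result dim -> iter dim
  std::vector<int> reductionDims = {2};
  std::vector<int> offsets = {10, 20, 30}; // per iteration dim
  std::vector<int> sizes = {4, 8, 16};     // per iteration dim (tile sizes)

  std::vector<int> resultOffsets, resultSizes;
  for (int dim : partialMap) {
    resultSizes.push_back(sizes[dim]);
    bool isReduction = std::find(reductionDims.begin(), reductionDims.end(),
                                 dim) != reductionDims.end();
    // Reduction dims are always written at the same place: offset 0.
    resultOffsets.push_back(isReduction ? 0 : offsets[dim]);
  }

  std::printf("resultOffsets:");
  for (int o : resultOffsets)
    std::printf(" %d", o);
  std::printf("\nresultSizes:");
  for (int s : resultSizes)
    std::printf(" %d", s);
  std::printf("\n"); // -> resultOffsets: 20 10 0, resultSizes: 8 4 16
  return 0;
}
```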

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 18 additions & 10 deletions
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                                            resultOffset, resultSize);
   case scf::SCFTilingOptions::ReductionTilingStrategy::
       PartialReductionOuterReduction: {
-    // TODO: This does not work for non identity accesses to the result tile.
-    // The proper fix is to add a getPartialResultTilePosition method to
-    // PartialReductionOpInterface.
-    resultOffset =
-        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
-    for (size_t i = 0; i < offsets.size(); i++) {
-      resultSize.push_back(
-          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only supported"
+              "for operations implementing PartialReductionOpInterface");
     }
-    return success();
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                              resultOffset, resultSize,
+                                              reductionDims);
+  }
   default:
     return rewriter.notifyMatchFailure(op,
                                        "unhandled reduction tiling strategy");
   }
-  }
 }
 
 static FailureOr<MergeResult>
