diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 965ef9e203be2..aaa2dbdbcd947 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -14,6 +14,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include @@ -239,6 +240,19 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options); +/// Fuse the consumer of the source of `candidateSliceOp` by computing the +/// required slice of the consumer in-place. Note that the method +/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer +/// value but does not delete the slice operation. +struct SCFFuseConsumerOfSliceResult { + OpOperand *origConsumerOperand; // Original untiled consumer's operand. + OpOperand + *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand. + SmallVector tiledOps; +}; +FailureOr +tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp); + /// Method to lower an `op` that implements the `TilingInterface` to /// loops/scalars. 
FailureOr> diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index e8a09c4741043..bf6c88f7b77a8 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { @@ -22,7 +23,7 @@ namespace tensor { // Patterns //===----------------------------------------------------------------------===// -/// Pattern to swap an `tensor.extract_slice` with its producer when the +/// Method to swap an `tensor.extract_slice` with its producer when the /// producer implements the `TilingInterface`. The pattern itself does not /// provide a mechanism to control where the application happens. With use of /// transform dialect that control is done within the transform dialect. Other @@ -30,6 +31,13 @@ namespace tensor { FailureOr replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); +/// Method to swap an `tensor.insert_slice` with its consumer when the +/// consumer implements the `TilingInterface`. +FailureOr +replaceInsertSliceWithTiledConsumer(OpBuilder &builder, + OffsetSizeAndStrideOpInterface sliceOp, + OpOperand &consumerOp); + //===----------------------------------------------------------------------===// // Populate functions. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 66382f29c2424..67612ffc14736 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { The method returns the operation that is the tiled implementation. 
}], - /*retType=*/"FailureOr", + /*retType=*/"FailureOr<::mlir::TilingResult>", /*methodName=*/"getTiledImplementation", /*args=*/(ins "OpBuilder &":$b, @@ -82,7 +82,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { by the tiled implementation. Expects the same `offsets` and `sizes` as used to obtain the tiled implementation of the operation. }], - /*retType=*/"LogicalResult", + /*retType=*/"::mlir::LogicalResult", /*methodName=*/"getResultTilePosition", /*args=*/(ins "OpBuilder &":$b, @@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> { return failure(); }] >, + InterfaceMethod< + /*desc=*/[{ + Method to return the tile of the iteration domain where + values from the given tile of the operand are used. + }], + /*retType=*/"::mlir::LogicalResult", + /*methodName=*/"getIterationDomainTileFromOperandTile", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$operandNumber, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVectorImpl &":$iterDomainOffsets, + "SmallVectorImpl &":$iterDomainSizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] + >, InterfaceMethod< /*desc=*/[{ Method to generate the code that produces a tile of the result. @@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { iteration space). - `sizes` provides the size of the tile. }], - /*retType=*/"FailureOr", + /*retType=*/"FailureOr<::mlir::TilingResult>", /*methodName=*/"generateResultTileValue", /*args=*/(ins "OpBuilder &":$b, @@ -131,6 +150,45 @@ def TilingInterface : OpInterface<"TilingInterface"> { return failure(); }] >, + InterfaceMethod< + /*desc=*/[{ + Method to generate the tiled implementation of an operation from + operand tile position. + + NOTE: For most operations, this should be a trivial composition of + getIterationDomainTileFromOperandTile and getTiledImplementation. + + Generates the IR that computes the tiled implementation of an + operation from operand tile. 
The `offsets` and `sizes` + describe the tile of the operand required. This is different from + `getTiledImplementation` which generates the tiled + implementation of the operation given a tile of the + iteration space. This method generates a tiled + implementation of the operation based on the tile of the + operand required. This method enables consumer fusion by using + tile and fuse. The method returns failure if the operation + can't be tiled to generate the operand tile. In practical terms + this implies it cannot be tiled and fused with its producers. + + - `offsets` provides the offset of the tile in the coordinate system + of the original iteration space, i.e., if an iteration space + dimension had non-zero offset, it must be included in the offset + provided here (as opposed to zero-based offset "relative" to the + iteration space). + - `sizes` provides the size of the tile. + }], + /*retType=*/"FailureOr<::mlir::TilingResult>", + /*methodName=*/"getTiledImplementationFromOperandTile", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$operandNumber, + "ArrayRef":$offsets, + "ArrayRef":$sizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] + >, InterfaceMethod< /*desc=*/[{ Generates the scalar implementation of the operation. @@ -142,7 +200,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { transformations are done, this method can be used to lower to scalar code that can then be lowered to LLVM or SPIR-V dialects. 
}], - /*retType=*/"LogicalResult", + /*retType=*/"::mlir::LogicalResult", /*methodName=*/"generateScalarImplementation", /*args=*/(ins "OpBuilder &":$b, diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index bd870d4f982e5..1d9c9b26616f7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -110,7 +110,7 @@ struct LinalgOpTilingInterface })); } - // Instantiate the tiled implementation of the operation. + /// Instantiate the tiled implementation of the operation. FailureOr getTiledImplementation(Operation *op, OpBuilder &b, ArrayRef offsets, @@ -132,8 +132,63 @@ struct LinalgOpTilingInterface return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; } - // Return the details of the output tile generated by the tiled - // implementation. + /// Utility to fetch the offsets and sizes when applied as per the indexing + /// map of the linalg op. This helps in fusing the linalg op as a consumer of + /// a given slice op. 
+ void + getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap, + ArrayRef offsets, + ArrayRef sizes, + SmallVectorImpl &mappedOffsets, + SmallVectorImpl &mappedSizes) const { + unsigned numLoops = linalgOp.getNumLoops(); + auto tilingInterfaceOp = cast(linalgOp.getOperation()); + mappedOffsets.resize(numLoops); + mappedSizes.resize(numLoops); + if (!indexingMap.isPermutation()) { + SmallVector iterationDomain = + tilingInterfaceOp.getIterationDomain(b); + for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) { + mappedOffsets[index] = value.offset; + mappedSizes[index] = value.size; + } + } + for (const auto &&[index, value] : + llvm::enumerate(indexingMap.getResults())) { + unsigned dimPosition = cast(value).getPosition(); + mappedOffsets[dimPosition] = offsets[index]; + mappedSizes[dimPosition] = sizes[index]; + } + } + + /// Method to return the iteration domain tile that uses the given tile + /// of the operand. + LogicalResult getIterationDomainTileFromOperandTile( + Operation *op, OpBuilder &b, unsigned operandNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl &iterDomainOffsets, + SmallVectorImpl &iterDomainSizes) const { + auto linalgOp = cast(op); + + // Check that the indexing map used for the operand is a projected + // permutation. This could be relaxed with a more general approach that can + // map the offsets and sizes from the operand to iteration space tiles + // (filling in full extent for dimensions not used to access the operand). 
+ AffineMap indexingMap = + linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber)); + if (!indexingMap.isProjectedPermutation()) { + return op->emitError() + << "unhandled get iter domain position when operand is not " + "accessed using a permuted projection"; + } + + getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, + iterDomainOffsets, iterDomainSizes); + return success(); + } + + /// Return the details of the output tile generated by the tiled + /// implementation. LogicalResult getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, ArrayRef offsets, @@ -177,29 +232,16 @@ struct LinalgOpTilingInterface "unhandled tiled implementation generation when result is not " "accessed using a permuted projection"); } - - auto numLoops = linalgOp.getNumLoops(); + SmallVector mappedOffsets, mappedSizes; + getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, + mappedOffsets, mappedSizes); auto tilingInterfaceOp = cast(op); - SmallVector iterationTileOffsets(numLoops), - iterationTileSizes(numLoops); - if (!indexingMap.isPermutation()) { - SmallVector iterationDomain = - tilingInterfaceOp.getIterationDomain(b); - for (const auto &range : llvm::enumerate(iterationDomain)) { - iterationTileOffsets[range.index()] = range.value().offset; - iterationTileSizes[range.index()] = range.value().size; - } - } - for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) { - unsigned dimPosition = - cast(resultExpr.value()).getPosition(); - iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; - iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; - } - FailureOr tilingResult = - tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets, - iterationTileSizes); + tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes); + + if (failed(tilingResult)) + return failure(); + if (tilingResult->tiledOps.size() != 1) return op->emitOpError("failed to generate tiled implementation"); @@ 
-208,6 +250,20 @@ struct LinalgOpTilingInterface SmallVector{tilingResult->tiledValues[resultNumber]}}; } + /// Method to generate the tiled implementation of an operation from the tile + /// of the operand. + FailureOr getTiledImplementationFromOperandTile( + Operation *op, OpBuilder &b, unsigned operandNumber, + ArrayRef offsets, ArrayRef sizes) const { + SmallVector mappedOffsets, mappedSizes; + if (failed(getIterationDomainTileFromOperandTile( + op, b, operandNumber, offsets, sizes, mappedOffsets, + mappedSizes))) { + return failure(); + } + return getTiledImplementation(op, b, mappedOffsets, mappedSizes); + } + LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, Location loc, ValueRange ivs) const { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 1a84a59ddb69d..4729e04af9d49 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -16,9 +16,11 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" @@ -1100,6 +1102,412 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( replacements}; } +//===----------------------------------------------------------------------===// +// tileAndFuseConsumerOfSlice implementation. +//===----------------------------------------------------------------------===// + +/// A utility function that checks whether the only use of the result of a +/// tensor.insert_slice op is in a scf.yield op of the same block. 
+static LogicalResult +checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { + Value result = candidateSliceOp.getResult(); + Value::use_range uses = result.getUses(); + if (!llvm::hasSingleElement(uses)) { + LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); + return failure(); + } + OpOperand &operandUse = (*uses.begin()); + Operation *userOp = operandUse.getOwner(); + if (!isa(userOp)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected scf.yield to be the only user, but got -> " + << (*userOp)); + return failure(); + } + if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { + LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " + "be in the same block\n"); + return failure(); + } + return success(); +} + +/// Fetches the single use of the value `val`, provided the op owning that use +/// implements `TilingInterface` and `DestinationStyleOpInterface` and is in +/// `containingOpBlock`. Returns failure otherwise. +static FailureOr getConsumerFromUses(Value val, + Block *containingOpBlock) { + // Step 1. Check that the value has exactly one use. + if (!llvm::hasSingleElement(val.getUses())) + return failure(); + // Step 2. Get the single use and the op that owns it. + OpOperand &operand = (*val.getUses().begin()); + Operation *consumerOp = operand.getOwner(); + // TODO: The results of the consumer must be initialized before the loop; + // for now use DestinationStyleOpInterface to get their shape from the + // inits. Add support for other ops, e.g. those with InferTypeOpInterface. + if (!isa(consumerOp) || + !isa(consumerOp)) + return failure(); + if (containingOpBlock != consumerOp->getBlock()) + return failure(); + return &operand; +} + +/// Fetch the untiled consumer of a scf.for's result which is yielded by a +/// tensor.insert_slice. This function makes the following assumptions: +/// 1. tensor.insert_slice has scf.yield as its only user. +/// 2. scf.for's corresponding result has only one use. 
+static FailureOr +getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) { + if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) + return failure(); + Value sliceResult = candidateSliceOp.getResult(); + // Step 1. The yield operand index of the slice is the loop result index. + OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); + unsigned resultNumber = yieldOpOperand.getOperandNumber(); + // Step 2. Check that the containing op is scf.for. + Operation *containingOp = candidateSliceOp->getParentOp(); + auto forOp = dyn_cast(containingOp); + if (!forOp) + return failure(); + Value resultingValue = forOp->getResult(resultNumber); + + return getConsumerFromUses(resultingValue, containingOp->getBlock()); +} + +/// Fetch the untiled consumer of a scf.forall's result which is yielded by a +/// tensor.parallel_insert_slice. The result must have exactly one use. +static FailureOr +getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) { + // Step 1. Fetch the corresponding output via the slice destination. + Value sliceDest = candidateSliceOp.getDest(); + auto iterArg = dyn_cast(sliceDest); + if (!iterArg) + return failure(); + Operation *containingOp = iterArg.getOwner()->getParentOp(); + if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) + return failure(); + // Step 2. Check that the containing op is scf.forall. + auto forallOp = dyn_cast(containingOp); + if (!forallOp) + return failure(); + Value resultingValue = + forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)); + + return getConsumerFromUses(resultingValue, containingOp->getBlock()); +} + +/// This utility checks that the loop satisfies one of the following: +/// 1. It yields exactly one result. +/// 2. The consumer op is its first user and all other users are in the same +/// block as the consumer op. +/// The loop op is cloned right before the consumer op, so these conditions +/// ensure that the def-use chain stays valid and no invalid IR is formed by +/// that cloning. 
+static LogicalResult checkAssumptionForLoop(Operation *loopOp, + Operation *consumerOp) { + // Accept right away if the loop op yields exactly one result. + if (loopOp->getNumResults() == 1) + return success(); + // Otherwise every user other than the consumer must be in the consumer's + // block and come after the consumer. + Block *parentBlock = consumerOp->getBlock(); + for (Operation *userOp : loopOp->getUsers()) { + if (userOp == consumerOp) + continue; + if (parentBlock != userOp->getBlock() || + !consumerOp->isBeforeInBlock(userOp)) + return failure(); + } + return success(); +} + +/// A utility to fetch the untiled consumer of a +/// tensor.insert_slice/tensor.parallel_insert_slice; fails for other ops. +static FailureOr getUntiledConsumerFromSlice(Operation *sliceOp) { + if (auto insertSlice = dyn_cast(sliceOp)) { + return getUntiledConsumerFromSlice(insertSlice); + } else if (auto parallelInsertSlice = + dyn_cast(sliceOp)) { + return getUntiledConsumerFromSlice(parallelInsertSlice); + } else { + return failure(); + } +} + +/// After fusing the consumer into scf.for, replace scf.yield with one that +/// also yields a value for each result of the tiled consumer. 
+static void +fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp, + TilingResult &tilingResult, + ArrayRef> &resultOffsets, + ArrayRef> &resultSizes, + ArrayRef bbArgs) { + scf::YieldOp oldTerminatorOp = + cast(newForOp.getBody()->getTerminator()); + unsigned totalOldResults = oldTerminatorOp->getNumOperands(); + unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults(); + SmallVector newYieldOperands; + newYieldOperands.reserve(totalOldResults + totalTiledResults); + // scf.yield has no op results; the values it yields are its operands. + // Carry them over ahead of the values of the tiled consumer. + llvm::append_range(newYieldOperands, oldTerminatorOp.getResults()); + rewriter.setInsertionPointAfter(oldTerminatorOp); + Location loc = newForOp.getLoc(); + for (auto [tiledResult, bbArg, resultOffset, resultSize] : + llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs, + resultOffsets, resultSizes)) { + SmallVector strides(resultOffset.size(), + rewriter.getIndexAttr(1)); + Value newInsertSliceOp = rewriter.create( + loc, tiledResult, bbArg, resultOffset, resultSize, strides); + newYieldOperands.push_back(newInsertSliceOp); + } + rewriter.create(loc, newYieldOperands); + rewriter.eraseOp(oldTerminatorOp); +} + +/// After fusing the consumer into scf.forall, yield each value produced by the +/// tiled consumer from within the scf.forall.in_parallel region. 
+static void +fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp, + SmallVector tiledResults, + ArrayRef> &resultOffsets, + ArrayRef> &resultSizes, + ArrayRef bbArgs) { + scf::InParallelOp newTerminatorOp = newForallOp.getTerminator(); + rewriter.setInsertionPointToStart(newTerminatorOp.getBody()); + Location firstYieldOpLoc = + (*(newTerminatorOp.getYieldingOps().begin())).getLoc(); + for (auto [tiledResult, bbArg, resultOffset, resultSize] : + llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) { + SmallVector strides(resultOffset.size(), + rewriter.getIndexAttr(1)); + rewriter.create( + firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides); + } +} + +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop. +FailureOr +mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, + Operation *candidateSliceOp) { + if (!isa( + candidateSliceOp)) + return failure(); + + bool isInsertSliceOp = isa(candidateSliceOp); + + // 1. Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. 
+ FailureOr maybeConsumerOpOperand = + getUntiledConsumerFromSlice(candidateSliceOp); + if (failed(maybeConsumerOpOperand)) { + return rewriter.notifyMatchFailure(candidateSliceOp, + "could not fetch consumer to fuse"); + } + OpOperand *consumerOpOperand = *maybeConsumerOpOperand; + Operation *consumerOp = consumerOpOperand->getOwner(); + unsigned operandNumber = consumerOpOperand->getOperandNumber(); + unsigned resultNumber = 0; + if (auto producerResult = dyn_cast(consumerOpOperand->get())) { + resultNumber = producerResult.getResultNumber(); + } else { + return rewriter.notifyMatchFailure( + consumerOp, "consumer op's operand doesn't seem to be an OpResult"); + } + + Operation *oldLoopOp = nullptr; + SmallVector newOuts; + Block *oldLoopBody = nullptr; + unsigned initSize = 0; + unsigned rank = 1; + if (isInsertSliceOp) { + auto forOp = candidateSliceOp->getParentOfType(); + oldLoopOp = forOp; + llvm::append_range(newOuts, forOp.getInits()); + oldLoopBody = forOp.getBody(); + initSize = forOp.getInits().size(); + } else { + auto forallOp = candidateSliceOp->getParentOfType(); + oldLoopOp = forallOp; + llvm::append_range(newOuts, forallOp.getOutputs()); + oldLoopBody = forallOp.getBody(); + initSize = forallOp.getOutputs().size(); + rank = forallOp.getRank(); + } + + if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) { + return rewriter.notifyMatchFailure( + oldLoopOp, "containing loop op should either yield just one value or " + "have the consumer op as its first user"); + } + + OpBuilder::InsertionGuard g(rewriter); + + // 2. Check consumer is not using scf loop's output as init. 
+ auto dstOp = cast(consumerOp); + SmallVector dpsInits = + llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); + if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) { + return rewriter.notifyMatchFailure( + consumerOp, + "consumer op taking the result of scf.for as init is not supported"); + } + newOuts.append(dpsInits); + + Location loc = oldLoopOp->getLoc(); + + // 3. Create new scf loop op. + rewriter.setInsertionPoint(consumerOp); + Operation *newLoopOp = nullptr; + Block *newLoopBody = nullptr; + if (isInsertSliceOp) { + auto forOp = cast(oldLoopOp); + auto newForOp = rewriter.create(loc, forOp.getLowerBound(), + forOp.getUpperBound(), + forOp.getStep(), newOuts); + newLoopOp = newForOp; + newLoopBody = newForOp.getBody(); + } else { + auto forallOp = cast(oldLoopOp); + auto newForallOp = rewriter.create( + loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), + forallOp.getMixedStep(), newOuts, forallOp.getMapping()); + newLoopOp = newForallOp; + rewriter.eraseOp(newForallOp.getTerminator()); + newLoopBody = newForallOp.getBody(); + } + + // 4. Move the loop body to the new op. + unsigned oldNumArguments = oldLoopBody->getNumArguments(); + rewriter.mergeBlocks(oldLoopBody, newLoopBody, + newLoopBody->getArguments().take_front(oldNumArguments)); + + // 5. Set insertion point before terminator op of the loop and create a new + // tensor.insert_slice. In the scf.for case this is a clone of the + // candidateSliceOp whereas in the scf.forall case this is created from the + // operands of tensor.parallel_insert_slice. 
+ tensor::InsertSliceOp clonedInsertSliceOp; + if (auto sliceOp = + dyn_cast(candidateSliceOp)) { + auto newForallOp = cast(newLoopOp); + rewriter.setInsertionPoint(newForallOp.getTerminator()); + clonedInsertSliceOp = rewriter.create( + loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(), + sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); + } else { + rewriter.setInsertionPoint(candidateSliceOp); + clonedInsertSliceOp = + cast(rewriter.clone(*candidateSliceOp)); + } + + // 6.a. Clone consumer op. + auto newForOpBlockArgsForConsumerDest = + newLoopBody->getArguments().drop_front(oldNumArguments); + auto clonedConsumerOp = cast(cloneOpAndUpdateDestinationArgs( + rewriter, consumerOp, newForOpBlockArgsForConsumerDest)); + + // 6.b. Replace all uses of the loop result with the result of the cloned + // tensor.insert_slice. + OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber); + rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { + operandToReplace.set(clonedInsertSliceOp.getResult()); + }); + + // 7 - Perform tiling of the cloned consumer and replace the operand at + // `operandNumber` with the source of the cloned tensor.insert_slice op. + auto ossSliceOp = + cast(clonedInsertSliceOp.getOperation()); + FailureOr tileAndFuseResult = + tensor::replaceInsertSliceWithTiledConsumer( + rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber)); + if (failed(tileAndFuseResult)) { + return failure(); + } + rewriter.replaceAllUsesWith( + tileAndFuseResult->tiledOps[0]->getOperand(operandNumber), + clonedInsertSliceOp.getSource()); + + // 8 - Extract offset/sizes/strides required to create the + // tensor.insert_slice/parallel_insert_slice for each result of the consumer. + SmallVector offsets = ossSliceOp.getMixedOffsets(); + SmallVector sizes = ossSliceOp.getMixedSizes(); + SmallVector strides = ossSliceOp.getMixedStrides(); + + // 9. Check all insert stride is 1. 
+ if (llvm::any_of(strides, [](OpFoldResult stride) { + return !isConstantIntValue(stride, 1); + })) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "containingOp's result yield with stride"); + } + + // 10. Try to get iter domain position from input position. + SmallVector iterDomainOffsets, iterDomainSizes; + if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile( + rewriter, operandNumber, offsets, sizes, iterDomainOffsets, + iterDomainSizes))) { + return rewriter.notifyMatchFailure( + clonedConsumerOp, "can't get iter domain position from input position"); + } + + // 11. Try to fetch the offset and size for all results of the cloned + // consumer. This would then be used to form the corresponding + // tensor.insert_slice/parallel_insert_slice later. + unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults(); + SmallVector> resultOffsets( + totalNumResultsOfConsumer); + SmallVector> resultSizes(totalNumResultsOfConsumer); + for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) { + if (failed(clonedConsumerOp.getResultTilePosition( + rewriter, idx, iterDomainOffsets, iterDomainSizes, + resultOffsets[idx], resultSizes[idx]))) { + return rewriter.notifyMatchFailure( + clonedConsumerOp, + "can't get result domain position from iter domain position"); + } + } + + auto arrayRefOffsets = ArrayRef>(resultOffsets); + auto arrayRefSizes = ArrayRef>(resultSizes); + if (isInsertSliceOp) { + auto newForOp = cast(newLoopOp); + fixTerminatorSCFYield( + rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes, + newForOp.getBody()->getArguments().drop_front(1 + initSize)); + } else { + auto newForallOp = cast(newLoopOp); + fixTerminatorSCFInParallel( + rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(), + arrayRefOffsets, arrayRefSizes, + newForallOp.getBody()->getArguments().drop_front(rank + initSize)); + } + + // 12. Replace the result of scf loop and consumer op with new loop's results. 
+ for (auto &&[oldResult, newResult] : + llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) { + rewriter.replaceAllUsesWith(oldResult, newResult); + } + + for (auto &&[oldResult, newResult] : + llvm::zip(consumerOp->getResults(), + newLoopOp->getResults().drop_front(initSize))) { + rewriter.replaceAllUsesWith(oldResult, newResult); + } + + // 13. Need to erase the old scf loop and the cloned consumer op. + rewriter.eraseOp(oldLoopOp); + rewriter.eraseOp(clonedConsumerOp); + + return scf::SCFFuseConsumerOfSliceResult{ + consumerOpOperand, + &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)), + tileAndFuseResult->tiledOps}; +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index d25efcf50ec56..9b2a97eb2b006 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -469,6 +469,106 @@ struct UnPackOpTiling return failure(); return tilingResult.value(); } + + /// Method to return the position of iteration domain tile computed by the + /// tiled operation. + LogicalResult getIterationDomainTileFromOperandTile( + Operation *op, OpBuilder &b, unsigned operandNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl &resultOffsets, + SmallVectorImpl &resultSizes) const { + auto unPackOp = cast(op); + Location loc = unPackOp.getLoc(); + + int64_t numTiles = unPackOp.getInnerDimsPos().size(); + auto destOffsets = offsets.drop_back(numTiles); + auto destSizes = sizes.drop_back(numTiles); + // The tiling is applied on interchanged dimensions. We have to undo the + // interchange to map sizes and offsets to the original input. 
+    int64_t outputRank = unPackOp.getDestRank();
+    SmallVector origOffsets(destOffsets.begin(),
+                            destOffsets.end());
+    SmallVector origSizes(destSizes.begin(), destSizes.end());
+    applyPermToRange(origOffsets, origSizes,
+                     invertPermutationVector(unPackOp.getOuterDimsPerm()));
+
+    DenseMap dimAndTileMapping =
+        unPackOp.getDimAndTileMapping();
+
+    for (auto dim : llvm::seq(0, outputRank)) {
+      using AV = affine::AffineValueExpr;
+      affine::AffineBuilder ab(b, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(b.getContext(), dim0, dim1);
+      bindSymbols(b.getContext(), sym);
+      if (dimAndTileMapping.count(dim)) {
+        // If the data dimension is tiled, the i-th index is the product of
+        // offset_i and tile_i, and the i-th size is the product of sizes_i and
+        // tile_i.
+        auto avOffset = AV(dim0).bind(origOffsets[dim]);
+        auto avSize = AV(dim0).bind(origSizes[dim]);
+        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
+        resultSizes.push_back(ab.mul(avSize, avTileSize));
+      } else {
+        resultOffsets.push_back(origOffsets[dim]);
+        resultSizes.push_back(origSizes[dim]);
+      }
+    }
+    return success();
+  }
+
+  /// Method to return the tiled implementation of tensor.unpack as a consumer.
+  /// `offsets`/`sizes` are the tile of the source operand (operand 0); the
+  /// result unpacks that tile into the matching slice of the dest operand.
+  /// Fails for any other operand or if the inner dims are tiled.
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    if (operandNumber != 0)
+      return failure();
+
+    auto unPackOp = cast<UnPackOp>(op);
+    Location loc = unPackOp.getLoc();
+    // Fusible (as a consumer) only if inner dims are not tiled.
+    int64_t numTiles = unPackOp.getInnerDimsPos().size();
+    for (auto iter :
+         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
+      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
+        return failure();
+    }
+
+    // Fetch offset/size for creating the slice of the dest operand of
+    // unpack op.
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getIterationDomainTileFromOperandTile(
+            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
+            outputSizes)))
+      return failure();
+
+    auto oneAttr = b.getI64IntegerAttr(1);
+    int64_t outputRank = unPackOp.getDestRank();
+    SmallVector<OpFoldResult> strides(outputRank, oneAttr);
+
+    // Create slice of the dest operand.
+    auto extractDestSlice = b.create<ExtractSliceOp>(
+        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
+
+    // Create slice of the source operand (one unit stride per source dim).
+    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
+    auto extractSourceSlice = b.create<ExtractSliceOp>(
+        loc, unPackOp.getSource(), offsets, sizes, strides);
+
+    // Create tiled unpack op on: source slice, dest slice, inner tiles.
+    SmallVector<Value> tiledOperands = {extractSourceSlice, extractDestSlice};
+    llvm::append_range(tiledOperands, unPackOp.getInnerTiles());
+    Operation *tiledUnPackOp =
+        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
+                           tiledOperands, op->getAttrs());
+
+    return TilingResult{{tiledUnPackOp},
+                        SmallVector<Value>(tiledUnPackOp->getResults())};
+  }
 };
 
 } // namespace
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 40d79c2053817..858adfc436164 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -40,3 +40,30 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
 
   return *tiledResult;
 }
+
+/// Tiles the consumer that owns `consumer` (it must implement
+/// `TilingInterface`) to the tile described by the offsets and sizes of
+/// `sliceOp`. Fails if the owner is not a `TilingInterface` op, if any stride
+/// of `sliceOp` is not the constant 1, or if the op cannot produce the tile.
+FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
+    OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
+    OpOperand &consumer) {
+  auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
+  if (!consumerOp)
+    return failure();
+
+  // `TilingInterface` currently only supports strides being 1.
+  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 1);
+      }))
+    return failure();
+
+  FailureOr<TilingResult> tiledResult =
+      consumerOp.getTiledImplementationFromOperandTile(
+          builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
+          sliceOp.getMixedSizes());
+  if (failed(tiledResult))
+    return failure();
+
+  return *tiledResult;
+}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
new file mode 100644
index 0000000000000..400b558e37fcd
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -0,0 +1,317 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>)
outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn}
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT]] :
+//      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0:
tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %in_operand_2 = tensor.empty() : tensor<64x64xf32> + %out_operand_3 = tensor.empty() : tensor<64x64xf32> + %2 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %2 : tensor<64x64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %first_slice_op + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: 
tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn} +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> 
(tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %out_operand_4 = tensor.empty() : tensor<64xf32> + %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) { + ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.subf %out_0, %13 : f32 + %15 = arith.addf %out_1, %in : f32 + linalg.yield %14, %15 : f32, f32 + } -> (tensor<64xf32>, tensor<64xf32>) + return %2#1 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %yield + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// 
CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+//      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
+// CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
+//      CHECK:      %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
+//      CHECK:      %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
+//      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#3 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] :
tensor<64x32xf32> to tensor<32x32xf32> + %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<64x64xf32> + %2 = tensor.empty() : tensor<64x64xf32> + %3 = tensor.empty() : tensor<64x64xf32> + %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32): + %6 = arith.mulf %in, %in_0 : f32 + %7 = arith.subf %out, %6 : f32 + %8 = arith.addf %out_1, %in : f32 + linalg.yield %7, %8 : f32, f32 + } -> (tensor<64x64xf32>, tensor<64x64xf32>) + %5 = tensor.empty() : tensor<2048xf32> + %unpack = tensor.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32> + return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %first_slice_op + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// 
CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: 
%[[UNPACK:.*]] = tensor.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32> +// CHECK: return %[[FINAL_RESULT]]#3, %[[UNPACK]] : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<2048xf32> + %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32> + return %unpack : tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[UNPACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> +// 
CHECK: func.func @fuse_unpack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1] +// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[TILED_UNPACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1] +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 335db1a61f476..833fb3cc65b81 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -160,6 +160,59 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, : DiagnosedSilenceableFailure::success(); } 
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Fuses the consumer of every slice op in `payloadOps` with
+/// `scf::tileAndFuseConsumerOfSlice`. Sets result 0 of `transformOp` to the
+/// original consumers and result 1 to the tiled-and-fused consumers. Emits an
+/// error and fails on the first payload op that cannot be fused.
+template <typename Range>
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+                  Range &&payloadOps, TransformResults &transformResults) {
+  SmallVector<Operation *> originalConsumerOps, fusedConsumerOps;
+  for (Operation *target : payloadOps) {
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+        scf::tileAndFuseConsumerOfSlice(rewriter, target);
+    if (failed(fuseConsumerResults))
+      return target->emitOpError("failed to fuse consumer of slice");
+
+    // Report back the relevant handles to the transform op.
+    originalConsumerOps.push_back(
+        fuseConsumerResults->origConsumerOperand->getOwner());
+    fusedConsumerOps.push_back(
+        fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+  }
+
+  transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
+  transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
+  return success();
+}
+
+/// Fuses the consumers of all payload ops associated with the target handle.
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  LogicalResult result =
+      applyFuseConsumer(rewriter, getOperation(),
+                        state.getPayloadOps(getTarget()), transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+/// The target handle is consumed; both result handles are produced.
+void transform::TestFuseConsumerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getConsumer(), effects);
+  producesHandle(getFusedConsumer(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // TestTileUsingForallOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index ef42375e5286d..d55d746bd6aa9 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op,
+        DeclareOpInterfaceMethods,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the consumer of the slice operation pointed to by the target handle;
+    returns handles to the original consumer and to the tiled, fused consumer.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$consumer,
+                      TransformHandleTypeInterface:$fused_consumer);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def TestTileUsingForallOp : Op, DeclareOpInterfaceMethods,