diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index dac79111af3c9..fecd33193eb0d 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -85,6 +85,7 @@ struct SCFTilingResult { /// Values to use as replacements for the untiled op. Is the same size as the /// number of results of the untiled op. SmallVector replacements; + SmallVector extractSliceOps; }; /// Method to tile an op that implements the `TilingInterface` using @@ -135,6 +136,7 @@ struct SCFFuseProducerOfSliceResult { OpResult origProducer; // Original untiled producer. Value tiledAndFusedProducer; // Tile and fused producer value. SmallVector tiledOps; + SmallVector extractSliceOps; }; std::optional tileAndFuseProducerOfSlice(RewriterBase &rewriter, diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h index ca570490ccf5b..e5ed016d53fc1 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -28,9 +28,13 @@ namespace mlir { /// are returned to the caller for further transformations. /// - `tiledValues` contains the tiled value corresponding to the result of the /// untiled operation. +/// - `extractSliceOps` contains all the `tensor.extract_slice` ops used in +/// generating the `tiledOps`. Usually these are operands to the `tiledOps` +/// but they can be embedded in regions owned by `tiledOps`. struct TilingResult { SmallVector tiledOps; SmallVector tiledValues; + SmallVector extractSliceOps; }; } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b79afebfa8158..5198e0bceaa6e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2501,7 +2501,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder, Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; + SmallVector sliceOps; + for (Value operand : tiledOperands) + if (auto sliceOp = operand.getDefiningOp()) + sliceOps.push_back(sliceOp); + + return TilingResult{ + {tiledOp}, SmallVector(tiledOp->getResults()), sliceOps}; } LogicalResult SoftmaxOp::getResultTilePosition( diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index c3ab3cecfada7..f25ccc38ba0a3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -129,7 +129,13 @@ struct LinalgOpTilingInterface Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); offsetIndices(b, cast(tiledOp), offsets); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; + SmallVector sliceOps; + for (Value operand : tiledOperands) + if (auto sliceOp = operand.getDefiningOp()) + sliceOps.push_back(sliceOp); + + return TilingResult{ + {tiledOp}, SmallVector(tiledOp->getResults()), sliceOps}; } /// Utility to fetch the offsets and sizes when applied as per the indexing @@ -247,7 +253,8 @@ struct LinalgOpTilingInterface return TilingResult{ tilingResult->tiledOps, - SmallVector{tilingResult->tiledValues[resultNumber]}}; + SmallVector{tilingResult->tiledValues[resultNumber]}, + tilingResult->extractSliceOps}; } /// Method to generate the tiled implementation of an operation from the tile diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f3d6b7a530117..fb3ec2a5fa0a8 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -619,7 +619,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, if (llvm::all_of(tileSizes, isZeroIndex)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = - TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()}; + TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), + /*extractSliceOps=*/{}}; return success(); } @@ -675,12 +676,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // op. if (loops.empty()) { return scf::SCFTilingResult{tilingResult->tiledOps, loops, - tilingResult->tiledValues}; + tilingResult->tiledValues, + tilingResult->extractSliceOps}; } SmallVector replacements = llvm::map_to_vector( loops.front()->getResults(), [](OpResult r) -> Value { return r; }); - return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements}; + return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements, + tilingResult->extractSliceOps}; } FailureOr @@ -931,9 +934,9 @@ mlir::scf::tileAndFuseProducerOfSlice( ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] .set(origDestinationTensors[resultNumber]); } - return scf::SCFFuseProducerOfSliceResult{fusableProducer, - tileAndFuseResult->tiledValues[0], - tileAndFuseResult->tiledOps}; + return scf::SCFFuseProducerOfSliceResult{ + fusableProducer, tileAndFuseResult->tiledValues[0], + tileAndFuseResult->tiledOps, tileAndFuseResult->extractSliceOps}; } /// Reconstruct the fused producer from within the tiled-and-fused code. @@ -962,13 +965,12 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( .getDefiningOp()) { rewriter.setInsertionPoint(tiledDestStyleOp); Value newRegionArg = newRegionIterArgs.back(); - auto destSlice = rewriter.create( - sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); unsigned resultNumber = fusableProducer.getResultNumber(); - rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { - tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); - }); + auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber] + .getDefiningOp(); + if (origSlice) { + origSlice.getSourceMutable().set(newRegionArg); + } } Block *block = rewriter.getInsertionPoint()->getBlock(); rewriter.setInsertionPoint(block->getTerminator()); @@ -1036,15 +1038,14 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( // operations. If the producers of the source of the `tensor.extract_slice` // can be tiled such that the tiled value is generated in-place, that // effectively tiles + fuses the operations. - auto addCandidateSlices = [](Operation *fusedOp, + auto addCandidateSlices = [](const SmallVector &newSliceOps, std::deque &candidates) { - for (Value operand : fusedOp->getOperands()) - if (auto sliceOp = operand.getDefiningOp()) - candidates.push_back(sliceOp); + for (auto *op : newSliceOps) + candidates.push_back(llvm::cast(op)); }; std::deque candidates; - addCandidateSlices(tiledAndFusedOps.back(), candidates); + addCandidateSlices(tilingResult->extractSliceOps, candidates); OpBuilder::InsertionGuard g(rewriter); while (!candidates.empty()) { // Traverse the slices in BFS fashion. @@ -1086,7 +1087,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( fusedResult->tiledAndFusedProducer.getDefiningOp()) { fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); tiledAndFusedOps.insert(tiledAndFusedOp); - addCandidateSlices(tiledAndFusedOp, candidates); + addCandidateSlices(fusedResult->extractSliceOps, candidates); } } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 9b2a97eb2b006..33db5a5f043f3 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -99,6 +99,16 @@ static void applyPermToRange(SmallVector &offsets, applyPermutationToVector(sizes, permutation); } +static SmallVector sliceOperandsOf(Operation *op) { + SmallVector sliceOps; + for (auto operand : op->getOperands()) { + if (auto sliceOp = operand.getDefiningOp()) { + sliceOps.push_back(sliceOp); + } + } + return sliceOps; +} + struct PackOpTiling : public TilingInterface::ExternalModel { @@ -192,7 +202,8 @@ struct PackOpTiling loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); return TilingResult{{tiledPackOp}, - SmallVector(tiledPackOp->getResults())}; + SmallVector(tiledPackOp->getResults()), + sliceOperandsOf(tiledPackOp)}; } LogicalResult @@ -440,12 +451,16 @@ struct UnPackOpTiling if (isPerfectTilingCase) return TilingResult{{tiledUnpackOp}, - SmallVector(tiledUnpackOp->getResults())}; + SmallVector(tiledUnpackOp->getResults()), + sliceOperandsOf(tiledUnpackOp)}; auto extractSlice = b.create(loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); - return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}}; + + return TilingResult{{tiledUnpackOp}, + {extractSlice.getResult()}, + sliceOperandsOf(tiledUnpackOp)}; } LogicalResult @@ -567,7 +582,8 @@ struct UnPackOpTiling tiledOperands, op->getAttrs()); return TilingResult{{tiledUnPackOp}, - SmallVector(tiledUnPackOp->getResults())}; + SmallVector(tiledUnPackOp->getResults()), + sliceOperandsOf(tiledUnPackOp)}; } }; @@ -756,7 +772,9 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, // the original data source x is not used. if (hasZeroLen) { Operation *generateOp = createGenerateOp(); - return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}}; + return TilingResult{{generateOp}, + {castResult(generateOp->getResult(0))}, + /*extractSliceOps=*/{}}; } // If there are dynamic dimensions: Generate an scf.if check to avoid @@ -776,11 +794,15 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, elseOp = createPadOfExtractSlice(); b.create(loc, castResult(elseOp->getResult(0))); }); - return TilingResult{{elseOp}, SmallVector(result->getResults())}; + return TilingResult{{elseOp}, + SmallVector(result->getResults()), + sliceOperandsOf(elseOp)}; } Operation *newPadOp = createPadOfExtractSlice(); - return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}}; + return TilingResult{{newPadOp}, + {castResult(newPadOp->getResult(0))}, + sliceOperandsOf(newPadOp)}; } void mlir::tensor::registerTilingInterfaceExternalModels(