diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a32..04624638e14c0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1481,6 +1481,50 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
   return &operand;
 }
 
+/// Find the perfectly nested loops outside of the given loop (inclusive),
+/// sorted from outer to inner.
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+///         %3 = ...
+///         yield %3
+///      yield %2
+///    yield %1
+/// ```
+///
+/// This function will return three perfectly nested loops: %0 + %1 + %2, when
+/// the target inner loop is %2.
+static SmallVector<scf::ForOp>
+getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
+  SmallVector<scf::ForOp> nestLoops = {loop};
+  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
+
+  // Check if it is the ForOp that yields the result of the inner loop.
+  auto isForOpYieldResultOfInnerLoop =
+      [](scf::ForOp outerLoop) -> LogicalResult {
+    Block *body = outerLoop.getBody();
+    if (!llvm::hasSingleElement(body->without_terminator()))
+      return failure();
+    auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
+    auto innerForOp = dyn_cast<scf::ForOp>(body->front());
+    if (!innerForOp)
+      return failure();
+    // All of the inner loop's results should be yielded.
+    return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
+  };
+
+  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
+  }
+  // Sorted from outer to inner.
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
 /// tensor.insert_slice. This function makes the following assumptions :
 /// 1. tensor.insert_slice has scf.yield as its only user.
@@ -1498,9 +1542,10 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
   auto forOp = dyn_cast<scf::ForOp>(containingOp);
   if (!forOp)
     return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
+  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+  Value resultingValue = topLevelForOp->getResult(resultNumber);
 
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
+  return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
 }
 
 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
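Note: with `getPerfectlyNestedLoopsOutsideOf` in place, the untiled consumer is now looked up on the result of the outer-most loop of a perfectly nested `scf.for` chain, rather than on the loop that immediately contains the `tensor.insert_slice`. A hand-written sketch of the shape this matches (illustrative names, not part of the patch):

```mlir
%0 = scf.for ... iter_args(%arg0 = %init) -> (tensor<256x256xf32>) {
  %1 = scf.for ... iter_args(%arg1 = %arg0) -> (tensor<256x256xf32>) {
    %tile = "producer.op"() : () -> tensor<64x64xf32>
    %ins = tensor.insert_slice %tile into %arg1[0, 0] [64, 64] [1, 1]
        : tensor<64x64xf32> into tensor<256x256xf32>
    scf.yield %ins : tensor<256x256xf32>
  }
  scf.yield %1 : tensor<256x256xf32>
}
// The consumer uses %0, the outer loop's result, so the lookup has to walk
// the whole nest starting from the loop that contains the insert_slice.
%2 = linalg.add ins(%0, %other : tensor<256x256xf32>, tensor<256x256xf32>)
    outs(%dest : tensor<256x256xf32>) -> tensor<256x256xf32>
```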
@@ -1563,59 +1608,6 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
   }
 }
 
-/// After fusing consumer into scf.for we want to modify the scf.yield operation
-/// to reflect the same by returning the values yielded by the tiled consumer.
-static void
-fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                      ArrayRef<BlockArgument> bbArgs) {
-  scf::YieldOp oldTerminatorOp =
-      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
-  unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
-  SmallVector<Value> newYieldOperands;
-  newYieldOperands.reserve(totalOldResults + totalTiledResults);
-  for (auto oldResult : oldTerminatorOp.getResults()) {
-    newYieldOperands.push_back(oldResult);
-  }
-  rewriter.setInsertionPointAfter(oldTerminatorOp);
-  Location loc = newForOp.getLoc();
-  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
-    SmallVector<OpFoldResult> strides(resultOffset.size(),
-                                      rewriter.getIndexAttr(1));
-    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
-    newYieldOperands.push_back(newInsertSliceOp);
-  }
-  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
-  rewriter.eraseOp(oldTerminatorOp);
-}
-
-/// After fusing consumer into scf.forall we want to yield each of the resulting
-/// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                           ArrayRef<BlockArgument> bbArgs) {
-  scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
-  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
-  Location firstYieldOpLoc =
-      (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
-  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
-    SmallVector<OpFoldResult> strides(resultOffset.size(),
-                                      rewriter.getIndexAttr(1));
-    rewriter.create<tensor::ParallelInsertSliceOp>(
-        firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
-  }
-}
-
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
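The two terminator-fixup helpers removed above are subsumed by the `YieldTiledValuesFn` callback introduced further down: rather than patching the terminator after the fact, the callback hands the tiled results plus their offsets/sizes to `addInitOperandsToLoopNest`, which reconstructs each loop with the extra inits and emits the equivalent of the following at every level (hand-written sketch, illustrative names):

```mlir
// scf.for case: the tiled consumer result is inserted into the new iter_arg
// and appended to the existing yield.
%ins = tensor.insert_slice %tiled into %new_iter_arg[%o0, %o1] [%s0, %s1] [1, 1]
    : tensor<64x64xf32> into tensor<256x256xf32>
scf.yield %old_yielded, %ins : tensor<256x256xf32>, tensor<256x256xf32>
```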
@@ -1646,81 +1638,63 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
-  Operation *oldLoopOp = nullptr;
-  SmallVector<Value> newOuts;
-  Block *oldLoopBody = nullptr;
-  unsigned initSize = 0;
-  unsigned rank = 1;
+  // There are two possible cases for the loop nest containing
+  // `candidateSliceOp`:
+  // 1. a single `scf.forall` or `scf.for`;
+  // 2. the inner-most `scf.for` of a perfectly nested `scf.for` structure,
+  //    where the top-level loop is the outer-most one of these nested loops.
+  LoopLikeOpInterface innerMostLoop =
+      candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
+  SmallVector<LoopLikeOpInterface> nestedLoops;
   if (isInsertSliceOp) {
-    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
-    oldLoopOp = forOp;
-    llvm::append_range(newOuts, forOp.getInits());
-    oldLoopBody = forOp.getBody();
-    initSize = forOp.getInits().size();
+    nestedLoops = llvm::map_to_vector(
+        getPerfectlyNestedLoopsOutsideOf(
+            cast<scf::ForOp>(innerMostLoop.getOperation())),
+        [](scf::ForOp forOp) {
+          return cast<LoopLikeOpInterface>(forOp.getOperation());
+        });
   } else {
-    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
-    oldLoopOp = forallOp;
-    llvm::append_range(newOuts, forallOp.getOutputs());
-    oldLoopBody = forallOp.getBody();
-    initSize = forallOp.getOutputs().size();
-    rank = forallOp.getRank();
+    nestedLoops = {innerMostLoop};
   }
 
-  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+
+  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
     return rewriter.notifyMatchFailure(
-        oldLoopOp, "containing loop op should either yield just one value or "
-                   "have the consumer op as its first user");
+        outerMostLoop,
+        "containing loop op should either yield just one value or "
+        "have the consumer op as its first user");
   }
 
   OpBuilder::InsertionGuard g(rewriter);
 
   // 2. Check consumer is not using scf loop's output as init.
-  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp)
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer op is not a DPS operation");
   SmallVector<Value> dpsInits =
       llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
-  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
     return rewriter.notifyMatchFailure(
         consumerOp,
         "consumer op taking the result of scf.for as init is not supported");
   }
-  newOuts.append(dpsInits);
-
-  Location loc = oldLoopOp->getLoc();
+  SmallVector<Value> newInits = dpsInits;
 
-  // 3. Create new scf loop op.
-  rewriter.setInsertionPoint(consumerOp);
-  Operation *newLoopOp = nullptr;
-  Block *newLoopBody = nullptr;
-  if (isInsertSliceOp) {
-    auto forOp = cast<scf::ForOp>(oldLoopOp);
-    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
-                                                forOp.getUpperBound(),
-                                                forOp.getStep(), newOuts);
-    newLoopOp = newForOp;
-    newLoopBody = newForOp.getBody();
-  } else {
-    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
-    auto newForallOp = rewriter.create<scf::ForallOp>(
-        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
-        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
-    newLoopOp = newForallOp;
-    rewriter.eraseOp(newForallOp.getTerminator());
-    newLoopBody = newForallOp.getBody();
-  }
+  Location loc = outerMostLoop->getLoc();
 
-  // 4. Move the loop body to the new op.
-  unsigned oldNumArguments = oldLoopBody->getNumArguments();
-  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
-                       newLoopBody->getArguments().take_front(oldNumArguments));
+  // 3. Move the whole loop structure right before the consumer op; dominance
+  // is already ensured by `checkAssumptionForLoop`.
+  rewriter.moveOpBefore(outerMostLoop, consumerOp);
 
-  // 5. Set insertion point before terminator op of the loop and create a new
+  // 4. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
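+  // In the scf.forall case the tensor.parallel_insert_slice is re-created as
+  // an equivalent tensor.insert_slice so that both loop kinds can then be
+  // handled uniformly through OffsetSizeAndStrideOpInterface below.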
   tensor::InsertSliceOp clonedInsertSliceOp;
   if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
     rewriter.setInsertionPoint(newForallOp.getTerminator());
     clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
         loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
@@ -1731,20 +1705,17 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
   }
 
-  // 6.a. Clone consumer op.
-  auto newForOpBlockArgsForConsumerDest =
-      newLoopBody->getArguments().drop_front(oldNumArguments);
-  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
-      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+  // 5.a. Clone consumer op.
+  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
 
-  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // 5.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
   OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
     operandToReplace.set(clonedInsertSliceOp.getResult());
   });
 
-  // 7 - Perform tiling of the cloned consumer and replace the operand at
+  // 6. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
   auto ossSliceOp =
       cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
@@ -1754,79 +1725,105 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   if (failed(tileAndFuseResult)) {
     return failure();
   }
-  rewriter.replaceAllUsesWith(
-      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
-      clonedInsertSliceOp.getSource());
-
-  // 8 - Extract offset/sizes/strides required to create the
-  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
-  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
-  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
-  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
-
-  // 9. Check all insert stride is 1.
-  if (llvm::any_of(strides, [](OpFoldResult stride) {
-        return !isConstantIntValue(stride, 1);
-      })) {
-    return rewriter.notifyMatchFailure(
-        candidateSliceOp, "containingOp's result yield with stride");
-  }
+  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
+  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
+                              clonedInsertSliceOp.getSource());
 
-  // 10. Try to get iter domain position from input position.
-  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
-          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
-          iterDomainSizes))) {
-    return rewriter.notifyMatchFailure(
-        clonedConsumerOp, "can't get iter domain position from input position");
-  }
+  // 7. Reconstruct the [nested] loop with new inits.
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+    // 8. Set the inner insertion point right before the tiled consumer op.
+    innerRewriter.setInsertionPoint(tiledConsumerOp);
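+
+    // Note: `tiledResult`, `tiledOffset` and `tiledSizes` are out-parameters
+    // of this callback; they are filled in step 13 below so that
+    // `addInitOperandsToLoopNest` can build the matching insert_slice/yield
+    // chain at every level of the loop nest.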
-  // 11. Try to fetch the offset and size for all results of the cloned
-  // consumer. This would then be used to form the corresponding
-  // tensor.insert_slice/parallel_insert_slice later.
-  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
-  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
-      totalNumResultsOfConsumer);
-  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
-  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
-    if (failed(clonedConsumerOp.getResultTilePosition(
-            rewriter, idx, iterDomainOffsets, iterDomainSizes,
-            resultOffsets[idx], resultSizes[idx]))) {
+    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+    // 9. Check that all insert strides are 1.
+    if (llvm::any_of(strides, [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
       return rewriter.notifyMatchFailure(
-          clonedConsumerOp,
-          "can't get result domain position from iter domain position");
+          candidateSliceOp, "containingOp's result yield with stride");
     }
-  }
 
-  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
-  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
-  if (isInsertSliceOp) {
-    auto newForOp = cast<scf::ForOp>(newLoopOp);
-    fixTerminatorSCFYield(
-        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
-        newForOp.getBody()->getArguments().drop_front(1 + initSize));
-  } else {
-    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
-    fixTerminatorSCFInParallel(
-        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
-        arrayRefOffsets, arrayRefSizes,
-        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
-  }
+    // 10. Try to get the iteration domain position from the input position.
+    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+    if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
+            rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+            iterDomainSizes))) {
+      return rewriter.notifyMatchFailure(
+          tiledConsumerOp,
+          "can't get iter domain position from input position");
+    }
 
-  // 12. Replace the result of scf loop and consumer op with new loop's results.
-  for (auto &&[oldResult, newResult] :
-       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
-    rewriter.replaceAllUsesWith(oldResult, newResult);
+    // 11. Try to fetch the offset and size for all results of the cloned
+    // consumer. This would then be used to form the corresponding
+    // tensor.insert_slice/parallel_insert_slice later.
+    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
+    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+        totalNumResultsOfConsumer);
+    SmallVector<SmallVector<OpFoldResult>> resultSizes(
+        totalNumResultsOfConsumer);
+    for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
+      if (failed(tiledConsumerOp.getResultTilePosition(
+              rewriter, idx, iterDomainOffsets, iterDomainSizes,
+              resultOffsets[idx], resultSizes[idx]))) {
+        return rewriter.notifyMatchFailure(
+            tiledConsumerOp,
+            "can't get result domain position from iter domain position");
+      }
+    }
+
+    // 12. Create `extract_slice` ops on the `iter_args` for the DPS operation
+    // if necessary.
+    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
+            tiledConsumerOp.getOperation())) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
+        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+            loc, newRegionArg, resultOffsets[index], resultSizes[index],
+            SmallVector<OpFoldResult>(resultOffsets[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
+          tiledDestStyleOp.getDpsInitsMutable()[index].set(destSlice);
+        });
+      }
+    }
+
+    // 13. Prepare the tiled offsets and sizes for the later `insert_slice`
+    // creation by the caller.
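+    // Re-anchor the insertion point at the innermost block's terminator so
+    // that the caller can emit the new `tensor.insert_slice` ops right
+    // before the `scf.yield` it is about to extend.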
+    Block *block = rewriter.getInsertionPoint()->getBlock();
+    rewriter.setInsertionPoint(block->getTerminator());
+    for (const auto &&[index, result] :
+         llvm::enumerate(tiledConsumerOp->getResults())) {
+      tiledResult.push_back(result);
+      tiledOffset.emplace_back(resultOffsets[index]);
+      tiledSizes.emplace_back(resultSizes[index]);
+    }
+    return success();
+  };
+  // 14. Add new inits to the [nested] loops.
+  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+                                       newYieldValuesFn))) {
+    return rewriter.notifyMatchFailure(
+        tiledConsumerOp, "unable to add new inits to nested loops");
   }
 
-  for (auto &&[oldResult, newResult] :
-       llvm::zip(consumerOp->getResults(),
-                 newLoopOp->getResults().drop_front(initSize))) {
+  // 15. Replace the results of the scf loop and the consumer op with the new
+  // loop's results.
+  for (auto &&[oldResult, newResult] : llvm::zip(
+           consumerOp->getResults(),
+           nestedLoops.front()->getResults().take_back(newInits.size()))) {
     rewriter.replaceAllUsesWith(oldResult, newResult);
   }
 
-  // 13. Need to erase the old scf loop and the cloned consumer op.
-  rewriter.eraseOp(oldLoopOp);
+  // 16. Erase the cloned consumer op.
   rewriter.eraseOp(clonedConsumerOp);
 
   return scf::SCFFuseConsumerOfSliceResult{
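For the nested `scf.for` case, the net effect (which the new test below verifies) is that the consumer is tiled into the inner-most loop body while its destination is threaded through every loop of the nest as a fresh iter_arg. In outline (hand-written sketch with illustrative names; the CHECK lines below are authoritative):

```mlir
%res:2 = scf.for %i = %c0 to %c256 step %c64
    iter_args(%mm_dest = %dest1, %add_dest = %dest0)
    -> (tensor<256x256xf32>, tensor<256x256xf32>) {
  %inner:2 = scf.for %j = %c0 to %c256 step %c64
      iter_args(%mm_arg = %mm_dest, %add_arg = %add_dest)
      -> (tensor<256x256xf32>, tensor<256x256xf32>) {
    %mm_tile = linalg.matmul ins(%lhs_slice, %rhs_slice : tensor<64x512xf32>, tensor<512x64xf32>)
        outs(%mm_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
    %mm_ins = tensor.insert_slice %mm_tile into %mm_arg[%i, %j] [64, 64] [1, 1]
        : tensor<64x64xf32> into tensor<256x256xf32>
    // The fused consumer computes directly on the matmul tile ...
    %add_slice = tensor.extract_slice %add_arg[%i, %j] [64, 64] [1, 1]
        : tensor<256x256xf32> to tensor<64x64xf32>
    %add_tile = linalg.add ins(%mm_tile, %bias_slice : tensor<64x64xf32>, tensor<64x64xf32>)
        outs(%add_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
    // ... and both results are yielded through the nest.
    %add_ins = tensor.insert_slice %add_tile into %add_arg[%i, %j] [64, 64] [1, 1]
        : tensor<64x64xf32> into tensor<256x256xf32>
    scf.yield %mm_ins, %add_ins : tensor<256x256xf32>, tensor<256x256xf32>
  }
  scf.yield %inner#0, %inner#1 : tensor<256x256xf32>, tensor<256x256xf32>
}
```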
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 83c5ec8d7342c..fdefdcc453ae7 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -109,9 +109,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
 // CHECK-SAME: outs(%[[SLICE_OUT]] :
 // CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 // CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 // CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 // CHECK: }
 // CHECK: }
 // CHECK: return %[[FINAL_RESULT]]#2 :
@@ -248,10 +248,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
 // CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
 // CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 // CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 // CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 // CHECK: }
 // CHECK: }
 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32>
@@ -310,8 +310,8 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
 // CHECK-SAME: into %[[TILED_UNPACK_DEST]]
 // CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
 // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
 // CHECK: }
 // CHECK: }
 // CHECK: return %[[FINAL_RESULT]]#1 :
@@ -369,8 +369,71 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME: into %[[TILED_PACK_DEST]]
 // CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
 // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
-// CHECK: }
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+
+// -----
+
+module {
+  func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) {
+      %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
+        %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
+        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
+        %extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
+        %3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32>
+        %insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
+        scf.yield %insert_slice : tensor<256x256xf32>
+      }
+      scf.yield %2 : tensor<256x256xf32>
+    }
+    %4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %4 : tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: func.func @fuse_add_consumer_into_nested_scf_for(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[dest1:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[dest0]] :
+// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0:.*]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
+// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
+// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
+// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
+// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
+// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
+// CHECK: }
+// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
 // CHECK: }
-// CHECK: return %[[FINAL_RESULT]]#1 :
+// CHECK: return %[[LOOP_RESULT1]]#1 :