diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index dac79111af3c9..781da1b4ef8a2 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter, /// where `%0` had other uses as well. If not reconstructed from within the loop /// body, uses of `%0` could not be replaced, making it still live and the /// fusion immaterial. +/// +/// The @param `yieldResultNumber` decides which result would be yield. If not +/// given, yield all `opResult` of fused producer. LogicalResult yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, - MutableArrayRef loops); + MutableArrayRef loops, + ArrayRef yieldResultNumber = ArrayRef{}); /// Transformation information returned after tile and fuse. struct SCFTileAndFuseResult { diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 8865aba3b4ef0..3cd9c8ccce075 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> { For an operation to be "tiled and fused" with its (already tiled) consumer, an operation has to implement the following additional method (see description below): - - `generateResultTileValue + - `generateResultTileValue` + - `getIterationDomainTileFromResultTile` For an operation to be "tiled and fused" with its (already tiled) producer, an operation has to implement the following additional methods (see @@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> { return failure(); }] >, + InterfaceMethod< + /*desc=*/[{ + Method to return the tile of the iteration domain based + on the given tile of the certain result. + + This method is required to allow operations to be "tiled and fused" + with an (already tiled) consumer. Given a tile of an result, + returns the tile of the iteration space that uses this tile. + - `resultNumber` is the result of the producer used by the consumer. + - `offsets` is the offset of the slice of the producer result used by + the tiled implementation of the consumer. + - `sizes` is the size of the slice of the producer result used by the + consumer. + If fusion of the producer with the consumer is not legal for the + result, or if this mapping cannot be computed, the implementation + should return a failure. + + For most cases `generateResultTileValue` could be a implemented using + `getIterationDomainTileFromResultTile` + `getTiledImplementation` + methods. + }], + /*retType=*/"::mlir::LogicalResult", + /*methodName=*/"getIterationDomainTileFromResultTile", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$resultNumber, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVectorImpl &":$iterDomainOffsets, + "SmallVectorImpl &":$iterDomainSizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] + >, InterfaceMethod< /*desc=*/[{ Generates the scalar implementation of the operation. diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index c3ab3cecfada7..424f29e787215 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -215,10 +215,11 @@ struct LinalgOpTilingInterface return success(); } - FailureOr - generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) const { + LogicalResult getIterationDomainTileFromResultTile( + Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl &iterDomainOffsets, + SmallVectorImpl &iterDomainSizes) const { auto linalgOp = cast(op); // Check that the indexing map used for the output is a projected @@ -232,9 +233,21 @@ struct LinalgOpTilingInterface "unhandled tiled implementation generation when result is not " "accessed using a permuted projection"); } - SmallVector mappedOffsets, mappedSizes; + getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, - mappedOffsets, mappedSizes); + iterDomainOffsets, iterDomainSizes); + return success(); + } + + FailureOr + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) const { + SmallVector mappedOffsets, mappedSizes; + if (failed(getIterationDomainTileFromResultTile( + op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) { + return failure(); + } auto tilingInterfaceOp = cast(op); FailureOr tilingResult = tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f3d6b7a530117..2efa8149f52ba 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -940,49 +940,122 @@ mlir::scf::tileAndFuseProducerOfSlice( LogicalResult mlir::scf::yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, - MutableArrayRef loops) { + MutableArrayRef loops, + ArrayRef yieldResultNumber) { if (loops.empty()) return success(); - OpResult fusableProducer = fusedProducerInfo.origProducer; - Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer; - FailureOr initValue = tensor::getOrCreateDestination( - rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); - if (succeeded(initValue)) { - - YieldTiledValuesFn newYieldValuesFn = - [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, - ValueRange newRegionIterArgs, SmallVector &tiledResult, - SmallVector> &tiledOffset, - SmallVector> &tiledSizes) - -> LogicalResult { - OpBuilder::InsertionGuard g(innerRewriter); - if (auto tiledDestStyleOp = - tiledAndFusedProducer - .getDefiningOp()) { - rewriter.setInsertionPoint(tiledDestStyleOp); - Value newRegionArg = newRegionIterArgs.back(); + Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), + *tiledOwner = fusedProducerInfo.tiledOps[0]; + + Location loc = originalOwner->getLoc(); + // a. collect all init Value to be appended + SmallVector initNumberList = + yieldResultNumber.empty() ? llvm::to_vector(llvm::seq( + 0, originalOwner->getNumResults())) + : llvm::to_vector(yieldResultNumber); + SmallVector initValueList; + for (const auto &resultNumber : initNumberList) { + FailureOr initValue = tensor::getOrCreateDestination( + rewriter, loc, originalOwner->getResult(resultNumber)); + if (succeeded(initValue)) { + initValueList.push_back(initValue.value()); + } else { + return failure(); + } + } + + YieldTiledValuesFn newYieldValuesFn = + [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, + ValueRange newRegionIterArgs, SmallVector &tiledResult, + SmallVector> &tiledOffset, + SmallVector> &tiledSizes) -> LogicalResult { + OpBuilder::InsertionGuard g(innerRewriter); + + // get sliceOp tile information + SmallVector sliceOffset = sliceOp.getMixedOffsets(), + sliceSizes = sliceOp.getMixedSizes(); + + // expect all strides of sliceOp being 1 + if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 1); + })) + return failure(); + + unsigned sliceResultNumber = + fusedProducerInfo.origProducer.getResultNumber(); + + auto tilableOp = cast(originalOwner); + // b. get iterDomain Offset and Sizes based on sliceOp tile + SmallVector iterDomainOffset, iterDomainSizes; + // skip tensor.pack/unpack/pad, which expects single opResult + if (tilableOp->getNumResults() > 1 && + failed(tilableOp.getIterationDomainTileFromResultTile( + rewriter, sliceResultNumber, sliceOffset, sliceSizes, + iterDomainOffset, iterDomainSizes))) { + // In theory, it is unnecessary to raise an error here. Actually although + // it fails to reconstruct the result tensor, it should not broke current + // fusion anyway. The reason why we must return failure currently is that + // the callback function `newYieldValuesFn` will be called after new init + // operand(s) has already been appended. It will take more refactoring to + // make sure the init operands are added consistently in the future. For + // more details, please refer to: + // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 + return failure(); + } + + // c. calculate offsets and sizes info of all OpResults respectively based + // on iteration Domain Tile + SmallVector> offsetList, sizesList; + for (const auto &resultNumber : initNumberList) { + if (resultNumber == sliceResultNumber) { + offsetList.push_back(sliceOffset); + sizesList.push_back(sliceSizes); + } else { + assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); + // infer result tile according to the iteration domain tile + SmallVector offset, sizes; + if (failed(tilableOp.getResultTilePosition( + rewriter, resultNumber, iterDomainOffset, iterDomainSizes, + offset, sizes))) { + return failure(); + } + offsetList.push_back(offset); + sizesList.push_back(sizes); + } + } + + // d. create `extract_slice` for `iter_args` for DPS operation if necessary + if (auto tiledDestStyleOp = + dyn_cast(tiledOwner)) { + rewriter.setInsertionPoint(tiledDestStyleOp); + for (const auto &&[index, newRegionArg] : + llvm::enumerate(newRegionIterArgs)) { auto destSlice = rewriter.create( - sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); - unsigned resultNumber = fusableProducer.getResultNumber(); + loc, newRegionArg, offsetList[index], sizesList[index], + SmallVector(offsetList[index].size(), + rewriter.getIndexAttr(1))); + unsigned resultNumber = initNumberList[index]; rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); }); } - Block *block = rewriter.getInsertionPoint()->getBlock(); - rewriter.setInsertionPoint(block->getTerminator()); - tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer); - tiledOffset.emplace_back(sliceOp.getMixedOffsets()); - tiledSizes.emplace_back(sliceOp.getMixedSizes()); - return success(); - }; + } - return addInitOperandsToLoopNest(rewriter, loops, - SmallVector{initValue.value()}, - newYieldValuesFn); - } - return success(); + // e. prepare tiled offset and sizes for later `insert_slice` creation by + // caller + Block *block = rewriter.getInsertionPoint()->getBlock(); + rewriter.setInsertionPoint(block->getTerminator()); + for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { + tiledResult.push_back(tiledOwner->getResult(resultNumber)); + tiledOffset.emplace_back(offsetList[index]); + tiledSizes.emplace_back(sizesList[index]); + } + return success(); + }; + + return addInitOperandsToLoopNest(rewriter, loops, initValueList, + newYieldValuesFn); } /// Implementation of tile consumer and fuse producer greedily. @@ -1072,14 +1145,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( continue; if (yieldReplacement) { + // Reconstruct and yield all opResult of fusableProducerOp by default. The + // caller can specific which one to yield by designating optional argument + // named `yieldResultNumber` of `yieldReplacementForFusedProducer`. + Operation *fusableProducerOp = fusableProducer.getOwner(); if (failed(yieldReplacementForFusedProducer( rewriter, candidateSliceOp, fusedResult.value(), loops))) { return rewriter.notifyMatchFailure( - fusableProducer.getOwner(), "failed to replacement value for this " - "oepration from within the tiled loop"); + fusableProducerOp, "failed to replacement value for this " + "operation from within the tiled loop"); + } + for (auto [index, result] : + llvm::enumerate(fusableProducerOp->getResults())) { + origValToResultNumber[result] = loops.front()->getNumResults() - + fusableProducerOp->getNumResults() + + index; } - origValToResultNumber[fusableProducer] = - loops.front()->getNumResults() - 1; } if (Operation *tiledAndFusedOp = diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir index 7356c11e85ac0..3c0ada9d2cabc 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir @@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0] // CHECK: scf.yield %[[INSERT0]], %[[INSERT1]] // CHECK: return %[[RESULT]]#1, %[[RESULT]]#0 + +// ----- + +func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>, + %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>, + %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>) + -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) { + %out0, %out1 = linalg.generic { + indexing_maps = [affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j, i)>], + iterator_types = ["parallel", "parallel"] + } + ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>) + outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) { + ^bb0(%0: f32, %1: f32, %2: f32, %3: f32): + %4 = arith.mulf %0, %1 : f32 + %5 = arith.addf %0, %1 : f32 + linalg.yield %4, %5: f32, f32 + } -> (tensor<32x32xf32>, tensor<32x32xf32>) + + %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32> + + return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %add = transform.structured.match ops{["linalg.add"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_and_yield %add [16] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @multiple_outputs_fusion_yield_all( +// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>) +// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] +// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]] +// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : +// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] : +// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0] +// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[ADD_TILE:.+]] = linalg.add +// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] : +// CHECK-SAME: outs(%[[INIT2_TILE]] : +// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0] +// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]] +// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]] +// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0