From 4a786b2a7adbe92197890bf3c60361dafec2f3ca Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Wed, 22 May 2024 23:15:36 -0700 Subject: [PATCH 1/5] yield replacement for multiple results --- .../SCF/Transforms/TileUsingInterface.h | 6 +- .../mlir/Interfaces/TilingInterface.td | 38 ++++- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 25 ++- .../SCF/Transforms/TileUsingInterface.cpp | 148 +++++++++++++----- .../tile-fuse-and-yield-using-interface.mlir | 62 ++++++++ 5 files changed, 233 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index dac79111af3c9..807379d99b599 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter, /// where `%0` had other uses as well. If not reconstructed from within the loop /// body, uses of `%0` could not be replaced, making it still live and the /// fusion immaterial. +/// +/// The @param `yieldResultNumber` decides which results are yielded. If not +/// given, all `opResult`s of the fused producer are yielded. LogicalResult yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, - MutableArrayRef loops); + MutableArrayRef loops, + std::optional> yieldResultNumber = std::nullopt); /// Transformation information returned after tile and fuse. 
struct SCFTileAndFuseResult { diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index 8865aba3b4ef0..3cd9c8ccce075 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> { For an operation to be "tiled and fused" with its (already tiled) consumer, an operation has to implement the following additional method (see description below): - - `generateResultTileValue + - `generateResultTileValue` + - `getIterationDomainTileFromResultTile` For an operation to be "tiled and fused" with its (already tiled) producer, an operation has to implement the following additional methods (see @@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> { return failure(); }] >, + InterfaceMethod< + /*desc=*/[{ + Method to return the tile of the iteration domain based + on the given tile of a particular result. + + This method is required to allow operations to be "tiled and fused" + with an (already tiled) consumer. Given a tile of a result, + returns the tile of the iteration space that uses this tile. + - `resultNumber` is the result of the producer used by the consumer. + - `offsets` is the offset of the slice of the producer result used by + the tiled implementation of the consumer. + - `sizes` is the size of the slice of the producer result used by the + consumer. + If fusion of the producer with the consumer is not legal for the + result, or if this mapping cannot be computed, the implementation + should return a failure. + + For most cases `generateResultTileValue` could be implemented using + `getIterationDomainTileFromResultTile` + `getTiledImplementation` + methods. 
+ }], + /*retType=*/"::mlir::LogicalResult", + /*methodName=*/"getIterationDomainTileFromResultTile", + /*args=*/(ins + "OpBuilder &":$b, + "unsigned":$resultNumber, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVectorImpl &":$iterDomainOffsets, + "SmallVectorImpl &":$iterDomainSizes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return failure(); + }] + >, InterfaceMethod< /*desc=*/[{ Generates the scalar implementation of the operation. diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index c3ab3cecfada7..424f29e787215 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -215,10 +215,11 @@ struct LinalgOpTilingInterface return success(); } - FailureOr - generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) const { + LogicalResult getIterationDomainTileFromResultTile( + Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl &iterDomainOffsets, + SmallVectorImpl &iterDomainSizes) const { auto linalgOp = cast(op); // Check that the indexing map used for the output is a projected @@ -232,9 +233,21 @@ struct LinalgOpTilingInterface "unhandled tiled implementation generation when result is not " "accessed using a permuted projection"); } - SmallVector mappedOffsets, mappedSizes; + getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, - mappedOffsets, mappedSizes); + iterDomainOffsets, iterDomainSizes); + return success(); + } + + FailureOr + generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, + ArrayRef offsets, + ArrayRef sizes) const { + SmallVector mappedOffsets, mappedSizes; + if (failed(getIterationDomainTileFromResultTile( + op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) { + return failure(); + } auto 
tilingInterfaceOp = cast(op); FailureOr tilingResult = tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f3d6b7a530117..e8cf7298be681 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice( LogicalResult mlir::scf::yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, - MutableArrayRef loops) { + MutableArrayRef loops, + std::optional> yieldResultNumber) { if (loops.empty()) return success(); - OpResult fusableProducer = fusedProducerInfo.origProducer; - Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer; - FailureOr initValue = tensor::getOrCreateDestination( - rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); - if (succeeded(initValue)) { - - YieldTiledValuesFn newYieldValuesFn = - [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, - ValueRange newRegionIterArgs, SmallVector &tiledResult, - SmallVector> &tiledOffset, - SmallVector> &tiledSizes) - -> LogicalResult { - OpBuilder::InsertionGuard g(innerRewriter); - if (auto tiledDestStyleOp = - tiledAndFusedProducer - .getDefiningOp()) { - rewriter.setInsertionPoint(tiledDestStyleOp); - Value newRegionArg = newRegionIterArgs.back(); + Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), + *tiledOwner = fusedProducerInfo.tiledOps[0]; + + Location loc = originalOwner->getLoc(); + // a. collect all init Value to be appended + ArrayRef initNumberList = + yieldResultNumber ? 
yieldResultNumber.value() + : llvm::to_vector(llvm::seq( + 0, originalOwner->getNumResults())); + SmallVector initValueList; + for (const auto &resultNumber : initNumberList) { + FailureOr initValue = tensor::getOrCreateDestination( + rewriter, loc, originalOwner->getResult(resultNumber)); + if (succeeded(initValue)) { + initValueList.push_back(initValue.value()); + } else { + return failure(); + } + } + + YieldTiledValuesFn newYieldValuesFn = + [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, + ValueRange newRegionIterArgs, SmallVector &tiledResult, + SmallVector> &tiledOffset, + SmallVector> &tiledSizes) -> LogicalResult { + OpBuilder::InsertionGuard g(innerRewriter); + + // get sliceOp tile information + SmallVector sliceOffset = sliceOp.getMixedOffsets(), + sliceSizes = sliceOp.getMixedSizes(); + + // expect all strides of sliceOp being 1 + if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 1); + })) + return failure(); + + unsigned sliceResultNumber = + fusedProducerInfo.origProducer.getResultNumber(); + + auto tilableOp = cast(originalOwner); + // b. get iterDomain Offset and Sizes based on sliceOp tile + SmallVector iterDomainOffset, iterDomainSizes; + // skip tensor.pack/unpack/pad, which expects single opResult + if (tilableOp->getNumResults() > 1 && + failed(tilableOp.getIterationDomainTileFromResultTile( + rewriter, sliceResultNumber, sliceOffset, sliceSizes, + iterDomainOffset, iterDomainSizes))) { + return failure(); + } + + // c. 
calculate offsets and sizes info of all OpResults respectively based + // on iteration Domain Tile + SmallVector> offsetList, sizesList; + for (const auto &resultNumber : initNumberList) { + if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) { + offsetList.push_back(sliceOffset); + sizesList.push_back(sliceSizes); + } else { + assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); + // infer result tile according to the iteration domain tile + SmallVector offset, sizes; + if (failed(tilableOp.getResultTilePosition( + rewriter, resultNumber, iterDomainOffset, iterDomainSizes, + offset, sizes))) { + return failure(); + } + offsetList.push_back(offset); + sizesList.push_back(sizes); + } + } + + // d. create `extract_slice` for `iter_args` for DPS operation if necessary + if (auto tiledDestStyleOp = + dyn_cast(tiledOwner)) { + rewriter.setInsertionPoint(tiledDestStyleOp); + for (const auto &&[index, newRegionArg] : + llvm::enumerate(newRegionIterArgs)) { auto destSlice = rewriter.create( - sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); - unsigned resultNumber = fusableProducer.getResultNumber(); + loc, newRegionArg, offsetList[index], sizesList[index], + SmallVector(offsetList[index].size(), + rewriter.getIndexAttr(1))); + unsigned resultNumber = initNumberList[index]; rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); }); } - Block *block = rewriter.getInsertionPoint()->getBlock(); - rewriter.setInsertionPoint(block->getTerminator()); - tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer); - tiledOffset.emplace_back(sliceOp.getMixedOffsets()); - tiledSizes.emplace_back(sliceOp.getMixedSizes()); - return success(); - }; + } - return addInitOperandsToLoopNest(rewriter, loops, - SmallVector{initValue.value()}, - newYieldValuesFn); - } - return success(); + // e. 
prepare tiled offset and sizes for later `insert_slice` creation by + // caller + Block *block = rewriter.getInsertionPoint()->getBlock(); + rewriter.setInsertionPoint(block->getTerminator()); + for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { + tiledResult.push_back(tiledOwner->getResult(resultNumber)); + tiledOffset.emplace_back(offsetList[index]); + tiledSizes.emplace_back(sizesList[index]); + } + return success(); + }; + + return addInitOperandsToLoopNest(rewriter, loops, initValueList, + newYieldValuesFn); } /// Implementation of tile consumer and fuse producer greedily. @@ -1072,14 +1137,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( continue; if (yieldReplacement) { + // Reconstruct and yield all opResult of fusableProducerOp by default. The + // caller can specify which one to yield by designating the optional argument + // named `yieldResultNumber` of `yieldReplacementForFusedProducer`. + Operation *fusableProducerOp = fusableProducer.getOwner(); if (failed(yieldReplacementForFusedProducer( rewriter, candidateSliceOp, fusedResult.value(), loops))) { return rewriter.notifyMatchFailure( - fusableProducer.getOwner(), "failed to replacement value for this " - "oepration from within the tiled loop"); + fusableProducerOp, "failed to replacement value for this " + "operation from within the tiled loop"); + } + for (const auto &result : fusableProducerOp->getResults()) { + origValToResultNumber[result] = + loops.front()->getNumResults() - + (fusableProducerOp->getNumResults() - result.getResultNumber()); } - origValToResultNumber[fusableProducer] = - loops.front()->getNumResults() - 1; } if (Operation *tiledAndFusedOp = diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir index 7356c11e85ac0..3c0ada9d2cabc 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir +++ 
b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir @@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0] // CHECK: scf.yield %[[INSERT0]], %[[INSERT1]] // CHECK: return %[[RESULT]]#1, %[[RESULT]]#0 + +// ----- + +func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>, + %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>, + %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>) + -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) { + %out0, %out1 = linalg.generic { + indexing_maps = [affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j, i)>], + iterator_types = ["parallel", "parallel"] + } + ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>) + outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) { + ^bb0(%0: f32, %1: f32, %2: f32, %3: f32): + %4 = arith.mulf %0, %1 : f32 + %5 = arith.addf %0, %1 : f32 + linalg.yield %4, %5: f32, f32 + } -> (tensor<32x32xf32>, tensor<32x32xf32>) + + %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32> + + return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %add = transform.structured.match ops{["linalg.add"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_and_yield %add [16] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @multiple_outputs_fusion_yield_all( +// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: 
%[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>, +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>) +// CHECK: %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] +// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]] +// CHECK: %[[GENERIC_TILE:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : +// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] : +// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0] +// CHECK-DAG: %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[ADD_TILE:.+]] = linalg.add +// CHECK-SAME: ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] : +// CHECK-SAME: outs(%[[INIT2_TILE]] : +// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0] +// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]] +// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]] +// CHECK: return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0 From b9450f784ec5a56e3e000bc4879321f2fed4b260 Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Wed, 5 Jun 2024 07:00:01 -0700 Subject: [PATCH 2/5] change default arguments --- .../mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 2 +- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 807379d99b599..781da1b4ef8a2 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -198,7 +198,7 @@ LogicalResult yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef loops, - std::optional> yieldResultNumber = std::nullopt); + ArrayRef yieldResultNumber = ArrayRef{}); /// Transformation information returned after tile and fuse. struct SCFTileAndFuseResult { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index e8cf7298be681..33142e61750d2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -941,7 +941,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef loops, - std::optional> yieldResultNumber) { + ArrayRef yieldResultNumber) { if (loops.empty()) return success(); @@ -951,9 +951,9 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( Location loc = originalOwner->getLoc(); // a. collect all init Value to be appended ArrayRef initNumberList = - yieldResultNumber ? yieldResultNumber.value() - : llvm::to_vector(llvm::seq( - 0, originalOwner->getNumResults())); + yieldResultNumber.empty() ? 
llvm::to_vector(llvm::seq( + 0, originalOwner->getNumResults())) + : yieldResultNumber; SmallVector initValueList; for (const auto &resultNumber : initNumberList) { FailureOr initValue = tensor::getOrCreateDestination( From eda4bf35b0535cc248b8555d378176a86831df0e Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Wed, 5 Jun 2024 18:47:54 -0700 Subject: [PATCH 3/5] fix CI --- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 33142e61750d2..a6677393c73dc 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -950,10 +950,10 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( Location loc = originalOwner->getLoc(); // a. collect all init Value to be appended - ArrayRef initNumberList = + SmallVector initNumberList = yieldResultNumber.empty() ? 
llvm::to_vector(llvm::seq( 0, originalOwner->getNumResults())) - : yieldResultNumber; + : llvm::to_vector(yieldResultNumber); SmallVector initValueList; for (const auto &resultNumber : initNumberList) { FailureOr initValue = tensor::getOrCreateDestination( @@ -1000,7 +1000,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( // on iteration Domain Tile SmallVector> offsetList, sizesList; for (const auto &resultNumber : initNumberList) { - if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) { + if (resultNumber == sliceResultNumber) { offsetList.push_back(sliceOffset); sizesList.push_back(sliceSizes); } else { From 46e3751412ba00281912cb41965b109bc4521ee8 Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Thu, 20 Jun 2024 00:18:51 -0700 Subject: [PATCH 4/5] fix comment --- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index a6677393c73dc..dab66ce97a6b5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1147,10 +1147,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( fusableProducerOp, "failed to replacement value for this " "operation from within the tiled loop"); } - for (const auto &result : fusableProducerOp->getResults()) { - origValToResultNumber[result] = - loops.front()->getNumResults() - - (fusableProducerOp->getNumResults() - result.getResultNumber()); + for (auto [index, result] : + llvm::enumerate(fusableProducerOp->getResults())) { + origValToResultNumber[result] = loops.front()->getNumResults() - + fusableProducerOp->getNumResults() + + index; } } From 1f0cbdbaf9cda930030fd2a5b9b6feba1b2840d6 Mon Sep 17 00:00:00 2001 From: "Song, Yunfei" Date: Thu, 27 Jun 2024 22:58:03 -0700 Subject: [PATCH 5/5] add a comment on why return failure --- 
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index dab66ce97a6b5..2efa8149f52ba 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -993,6 +993,14 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( failed(tilableOp.getIterationDomainTileFromResultTile( rewriter, sliceResultNumber, sliceOffset, sliceSizes, iterDomainOffset, iterDomainSizes))) { + // In theory, it is unnecessary to raise an error here. Even though + // it fails to reconstruct the result tensor, it should not break the current + // fusion anyway. The reason why we must return failure currently is that + // the callback function `newYieldValuesFn` will be called after new init + // operand(s) have already been appended. It will take more refactoring to + // make sure the init operands are added consistently in the future. For + // more details, please refer to: + // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 return failure(); }