Skip to content

[mlir][scf] Extend consumer fusion to multiple tilable users #111955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 181 additions & 66 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
Expand Down Expand Up @@ -1580,33 +1582,163 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
return success();
}

/// Fetches the OpOperand of the only user (and use) of the value `val` which
/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
/// failure otherwise.
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
Block *containingOpBlock) {
// Check that the value has exactly one use which isn't a scf.yield or a
// tensor.parallel_insert_slice op.
OpOperand *operand = nullptr;
for (OpOperand &opOperand : val.getUses()) {
Operation *consumerOp = opOperand.getOwner();
if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
continue;
if (operand)
return failure();
// TODO: We have to init result of consumer before scf.for, use
// DestinationStyleOpInterface to get result shape from init for now.
// Add support for other op such as op has InferTypeOpInterface.
if (!isa<TilingInterface>(consumerOp) ||
!isa<DestinationStyleOpInterface>(consumerOp))
/// An utility to get the first user of the given loopOp. If any of user stay in
/// different block of loopOp, return failure.
static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
if (!isa<LoopLikeOpInterface>(loopOp))
return failure();
Operation *firstUserOfLoop = nullptr;
for (Operation *userOp : loopOp->getUsers()) {
// `ParallelInsertSlice` located inside `InParallelOp` has no same parent
// block with any other types of operation. Thus, just redirecting to its
// parent `InParallelOp`. E.g.
//
// ```
// %1 = scf.for {
// ...
// }
// %2 = consumerOp ins(%1, ...)
// scf.forall.in_parallel {
// tensor.parallel_insert_slice %1
// }
// ```
// where `InParallelOp` but not `ParallelInsertSlice` stays in the same
// same block with `consumerOp`.
if (isa<tensor::ParallelInsertSliceOp>(userOp))
userOp = userOp->getParentOfType<scf::InParallelOp>();

if (loopOp->getBlock() != userOp->getBlock())
return failure();
if (containingOpBlock != consumerOp->getBlock())

if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
firstUserOfLoop = userOp;
}
return firstUserOfLoop;
}

/// This utility currently checks whether the first userOp of loop is NOT before
/// the last defineOp of consumer operand. Because that we need to move the
/// whole loop structure right before the `firstUserOfLoop`. This utility thus
/// helps ensuring that no invalid IR is formed, i.e. no backward slice of
/// consumerOp is dominated by the `firstUserOfLoop`. Saying that:
///
/// ```
/// %0 = scf.for() {
/// ...
/// }
/// ...
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ...
/// %3 = consumerOp(%2)
/// ```
///
/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
/// use-def chain violation:
///
/// ```
/// %0:2 = scf.for() {
/// // use before define error
/// %3 = tiledConsumerOp(%2)
/// }
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ```
///
/// @param loopOp: loop operation
/// @param consumerOp: consumer operation
/// @param reorderOperations: the flag controls whether to reorder the backward
/// slice w.r.t. the defineOp of `consumerOp` operands.
/// @return: computed backward slice of consumerOp, but excluding those already
/// dominates `firstUserOfLoop`.
static FailureOr<llvm::SetVector<Operation *>>
checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
bool reorderOperations) {
FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
if (failed(firstUserOfLoop))
return failure();

BackwardSliceOptions options;
DominanceInfo dominanceInfo;
options.inclusive = true;
options.omitBlockArguments = true;
bool includeLoopOp = false;
options.filter = [&](Operation *op) {
if (op == loopOp) {
includeLoopOp = true;
return false;
}
// Cut off the slice to not include any operation that already dominates
// firstUserOfLoop.
return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
};
llvm::SetVector<Operation *> slice;
for (auto operand : consumerOp->getOperands()) {
getBackwardSlice(operand, &slice, options);
}

if (!slice.empty()) {
// If consumerOp has one producer, which is also the user of loopOp.
// E.g.
// ```
// %0 = %loopOp
// %1 = consumerOp1 ins(%0)
// %2 = consumerOp2 ins(%0, %1)
// ```
// We can not fuse consumerOp2 into loopOp due to UD chain, unless
// consumerOp1 has already been fused into loopOp before.
if (includeLoopOp || !reorderOperations)
return failure();
operand = &opOperand;
}

if (operand)
return operand;
return slice;
}

/// Fetches the OpOperand of the first valid user (and use) of the value `val`
/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
/// Returns failure otherwise.
static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
Operation *loopOp,
unsigned resultNumber) {
if (!isa<LoopLikeOpInterface>(loopOp))
return failure();
Value val = loopOp->getResult(resultNumber);
Block *loopBlock = loopOp->getBlock();
for (OpOperand &opOperand : val.getUses()) {
Operation *consumerOp = opOperand.getOwner();
// Step 1. Check if the user is tilable.
if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
// TODO: We have to init result of consumer before scf.for, use
// DestinationStyleOpInterface to get result shape from init for now. Add
// support for other op such as op has InferTypeOpInterface.
continue;
}
// Step 2. Check if user stay in the same block.
if (loopBlock != consumerOp->getBlock())
continue;
// Step 3. Check if user has succeeding user. Otherwise, it usually
// represents already tiled.
if (consumerOp->use_empty())
continue;
// Step 4. Check assumption for loop with `reorderOperations` enabled.
FailureOr<llvm::SetVector<Operation *>> slice =
checkAssumptionForLoop(loopOp, consumerOp, true);
if (failed(slice))
continue;
// Step 5. If backward sice is not empty, move them before firstUserOfLoop.
if (!slice->empty()) {
mlir::topologicalSort(*slice);
FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
for (auto op : *slice) {
rewriter.moveOpBefore(op, *firstUserOfLoop);
}
}
return &opOperand;
}
return failure();
}

Expand Down Expand Up @@ -1659,7 +1791,8 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
getUntiledConsumerFromSlice(RewriterBase &rewriter,
tensor::InsertSliceOp candidateSliceOp) {
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
return failure();
Value sliceResult = candidateSliceOp.getResult();
Expand All @@ -1672,15 +1805,15 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
if (!forOp)
return failure();
scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
Value resultingValue = topLevelForOp->getResult(resultNumber);

return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}

/// Fetch the first untiled consumer of a scf.forall's result which is yielded
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
getUntiledConsumerFromSlice(RewriterBase &rewriter,
tensor::ParallelInsertSliceOp candidateSliceOp) {
// Step 1. Fetch the corresponding output
Value sliceDest = candidateSliceOp.getDest();
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
Expand All @@ -1693,45 +1826,22 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
if (!forallOp)
return failure();
Value resultingValue =
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));

return getConsumerFromUses(resultingValue, containingOp->getBlock());
}
unsigned resultNumber =
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
.getResultNumber();

/// This utility currently checks whether the loop either :-
/// 1. Yields exactly one result.
/// 2. Has consumer op as its first user and other users to be in the same
/// containing block as that of consumer op's. Currently we clone the loop op
/// right before the consumer op in order to maintain a valid def-use chain.
/// This utility thus helps ensuring that no invalid IR is formed due to the
/// same.
static LogicalResult checkAssumptionForLoop(Operation *loopOp,
Operation *consumerOp) {
// Check if the loop op yields one result.
if (loopOp->getNumResults() == 1)
return success();
// Check if the consumerOp is the first user of the loopOp and if other users
// are in the same containing block as that of consumer op's.
Block *parentBlock = consumerOp->getBlock();
for (Operation *userOp : loopOp->getUsers()) {
if (userOp == consumerOp)
continue;
if (parentBlock != userOp->getBlock() ||
!consumerOp->isBeforeInBlock(userOp))
return failure();
}
return success();
return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
}

/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(insertSlice);
return getUntiledConsumerFromSlice(rewriter, insertSlice);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(parallelInsertSlice);
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
} else {
return failure();
}
Expand All @@ -1751,7 +1861,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
getUntiledConsumerFromSlice(candidateSliceOp);
getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
Expand Down Expand Up @@ -1787,11 +1897,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,

LoopLikeOpInterface outerMostLoop = nestedLoops.front();

if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
// Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
return rewriter.notifyMatchFailure(
outerMostLoop,
"containing loop op should either yield just one value or "
"have the consumer op as its first user");
outerMostLoop, "the first user of loop should not dominate any define "
"of consumer operand(s)");
}

OpBuilder::InsertionGuard g(rewriter);
Expand All @@ -1812,9 +1922,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,

Location loc = outerMostLoop->getLoc();

// 3. Move the whole loop structure right before consumer Op, the dominance
// should be already ensured by `checkAssumptionForLoop`.
rewriter.moveOpBefore(outerMostLoop, consumerOp);
// 3. Move the whole loop structure right before firstUserOfLoop, the
// dominance should be already ensured by `checkAssumptionForLoop`.
FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
if (failed(firstUserOfLoop)) {
return rewriter.notifyMatchFailure(
outerMostLoop, "could not find the first user of outer most loop");
}
rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);

// 4. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
Expand Down
62 changes: 62 additions & 0 deletions mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,65 @@ module {
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
// CHECK: }
// CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :

// -----

module {
func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%c256 = arith.constant 256 : index
%cst = arith.constant 0.000000e+00 : f32
%dest0 = tensor.empty() : tensor<256x256xf32>
%1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
%extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
%extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
%extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
%3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
%insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
scf.yield %insert_slice : tensor<256x256xf32>
}
%4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
%5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK: func.func @fuse_add_multiple_tilable_consumers(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
// CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul
// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] :
// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
// CHECK: }
// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
Loading
Loading