[mlir][scf] Extend consumer fusion to multiple tilable users #111955
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: None (Yun-Fly)
Changes
Currently, consumer fusion expects a single use (any other uses must be terminator ops). This patch supports fusing multiple consumers, i.e. fusing the following input
into
The key step is ensuring (or adapting the IR so) that the `lastDefOfConsumer` is before the `firstUserOfLoop`. To enable this feature in the TEST mlir, specify `num_consumer_to_fuse = 2` (the default value is 1).
Full diff: https://github.com/llvm/llvm-project/pull/111955.diff
4 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..a758db6c68cf81 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1585,26 +1585,27 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
/// failure otherwise.
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
Block *containingOpBlock) {
- // Check that the value has exactly one use which isn't a scf.yield or a
- // tensor.parallel_insert_slice op.
OpOperand *operand = nullptr;
for (OpOperand &opOperand : val.getUses()) {
Operation *consumerOp = opOperand.getOwner();
- if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
+ // Step 1. Check if the user is tilable.
+ if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now. Add
+ // support for other op such as op has InferTypeOpInterface.
continue;
- if (operand)
- return failure();
- // TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp))
- return failure();
- if (containingOpBlock != consumerOp->getBlock())
- return failure();
- operand = &opOperand;
+ } else {
+ // Step 2. Check if user stay in the same block.
+ if (containingOpBlock != consumerOp->getBlock())
+ continue;
+ // Step 3. Check if user has succeeding user. Otherwise, it usually
+ // represents already tiled.
+ if (consumerOp->use_empty())
+ continue;
+ operand = &opOperand;
+ break;
+ }
}
-
if (operand)
return operand;
return failure();
@@ -1699,28 +1700,123 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
return getConsumerFromUses(resultingValue, containingOp->getBlock());
}
-/// This utility currently checks whether the loop either :-
-/// 1. Yields exactly one result.
-/// 2. Has consumer op as its first user and other users to be in the same
-/// containing block as that of consumer op's. Currently we clone the loop op
-/// right before the consumer op in order to maintain a valid def-use chain.
-/// This utility thus helps ensuring that no invalid IR is formed due to the
-/// same.
+/// This utility currently checks whether the first userOp of loop is NOT before
+/// the last defineOp of consumer. Currently we need to move the loop op right
+/// before a certain op in order to maintain a valid def-use chain. This utility
+/// thus helps ensuring that no invalid IR is formed. E.g.
+///
+/// ```
+/// %0 = scf.for() {
+/// ...
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop`is before `lastDefOfConsumer`, then it would be
+/// invalid to move the loop op right before the `firstUserOfLoop`:
+///
+/// ```
+/// %0:2 = scf.for() {
+/// %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ```
+///
+/// To address this issue, this utility would double-check there is no user of
+/// `firstUserOfLoop` before `lastDefOfConsumer`. If so, move `firstUserOfLoop`
+/// after `lastDefOfConsumer`. Then, it turns out valid as follow:
+///
+/// ```
+/// %2 = lastDefOfConsumer
+/// %0:2 = scf.for() {
+/// %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ```
+///
+/// Besides, `consumerOp` should not be the user of `firstUserOfLoop`.
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param toMoveLoopOpBefore: the operation we move the looOp right before
static LogicalResult checkAssumptionForLoop(Operation *loopOp,
- Operation *consumerOp) {
- // Check if the loop op yields one result.
- if (loopOp->getNumResults() == 1)
- return success();
- // Check if the consumerOp is the first user of the loopOp and if other users
- // are in the same containing block as that of consumer op's.
+ Operation *consumerOp,
+ Operation **toMoveLoopOpBefore) {
Block *parentBlock = consumerOp->getBlock();
- for (Operation *userOp : loopOp->getUsers()) {
- if (userOp == consumerOp)
- continue;
- if (parentBlock != userOp->getBlock() ||
- !consumerOp->isBeforeInBlock(userOp))
- return failure();
- }
+ // loopOp and consumerOp should stay in the same block.
+ if (loopOp->getBlock() != parentBlock)
+ return failure();
+
+ *toMoveLoopOpBefore = nullptr;
+ do {
+ Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp;
+ // Find the first user of loopOp
+ for (Operation *userOp : loopOp->getUsers()) {
+ if (userOp == consumerOp)
+ continue;
+ // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
+ // block with any other types of operation. Thus, just redirecting to its
+ // parent `InParallelOp`.
+ if (isa<tensor::ParallelInsertSliceOp>(userOp))
+ userOp = userOp->getParentOfType<scf::InParallelOp>();
+
+ if (parentBlock != userOp->getBlock())
+ return failure();
+
+ if (userOp->isBeforeInBlock(firstUserOfLoop))
+ firstUserOfLoop = userOp;
+ }
+
+ // Find the last define of consumer
+ for (Value operand : consumerOp->getOperands()) {
+ // If the operand is `BlockArgument`, auto skip.
+ if (isa<BlockArgument>(operand))
+ continue;
+ auto defineOp = operand.getDefiningOp();
+ if (!defineOp)
+ return failure();
+ // If defineOp is not in the same block with loopOp, it must dominate the
+ // loopOp as well. I.e.
+ // ```
+ // %a = ...
+ // {
+ // %looOp = scf.for
+ // %b = consumerOp ins(%loopOp, %a)
+ // }
+ // ```
+ if (defineOp == loopOp || parentBlock != defineOp->getBlock())
+ continue;
+ if (lastDefOfConsumer->isBeforeInBlock(defineOp))
+ lastDefOfConsumer = defineOp;
+ }
+ if (firstUserOfLoop->isBeforeInBlock(lastDefOfConsumer)) {
+ // Try to move if possible
+ if (llvm::all_of(firstUserOfLoop->getUsers(),
+ [&lastDefOfConsumer, &parentBlock](Operation *userOp) {
+ return userOp->getBlock() == parentBlock &&
+ lastDefOfConsumer->isBeforeInBlock(userOp);
+ })) {
+ // Safely moving
+ firstUserOfLoop->moveAfter(lastDefOfConsumer);
+ } else {
+ return failure();
+ }
+ } else {
+ // Check consumerOp is not the user of firstUserOfLoop
+ if (firstUserOfLoop == lastDefOfConsumer)
+ return failure();
+ // Set InsertPoint
+ *toMoveLoopOpBefore = firstUserOfLoop;
+ }
+ } while (!(*toMoveLoopOpBefore));
+
return success();
}
@@ -1787,7 +1883,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
LoopLikeOpInterface outerMostLoop = nestedLoops.front();
- if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
+ // Find suitable insertPointOp to move the whole loop structure later.
+ Operation *toMoveLoopOpBefore = nullptr;
+ if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp,
+ &toMoveLoopOpBefore))) {
return rewriter.notifyMatchFailure(
outerMostLoop,
"containing loop op should either yield just one value or "
@@ -1812,9 +1911,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
Location loc = outerMostLoop->getLoc();
- // 3. Move the whole loop structure right before consumer Op, the dominance
+ // 3. Move the whole loop structure right before insertPoint, the dominance
// should be already ensured by `checkAssumptionForLoop`.
- rewriter.moveOpBefore(outerMostLoop, consumerOp);
+ rewriter.moveOpBefore(outerMostLoop, toMoveLoopOpBefore);
// 4. Set insertion point before terminator op of the loop and create a new
// tensor.insert_slice. In the scf.for case this is a clone of the
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index f5f703d95e2d5b..af836d18e8c028 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -508,3 +508,65 @@ module {
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
// CHECK: }
// CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
+
+// -----
+
+module {
+ func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+ %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+ %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+ %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+ scf.yield %insert_slice : tensor<256x256xf32>
+ }
+ %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_add_multiple_tilable_consumers(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
+// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
+// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
+// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul
+// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
+// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] :
+// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
+// CHECK: }
+// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index b6da47977cb4cf..5e903e378daf82 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -171,24 +171,27 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
template <typename Range>
static LogicalResult
applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
- Range &&payloadOps, TransformResults &transformResults) {
+ Range &&payloadOps, uint32_t numConsumerToFuse,
+ TransformResults &transformResults) {
SmallVector<Operation *> originalConsumerOps;
SmallVector<Operation *> fusedConsumerOps;
for (Operation *target : payloadOps) {
rewriter.setInsertionPoint(target);
- FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
- scf::tileAndFuseConsumerOfSlice(rewriter, target);
+ while (numConsumerToFuse--) {
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+ scf::tileAndFuseConsumerOfSlice(rewriter, target);
- if (failed(fuseConsumerResults))
- return failure();
+ if (failed(fuseConsumerResults))
+ return failure();
- // Report back the relevant handles to the transform op.
- originalConsumerOps.push_back(
- fuseConsumerResults->origConsumerOperand->getOwner());
- fusedConsumerOps.push_back(
- fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+ // Report back the relevant handles to the transform op.
+ originalConsumerOps.push_back(
+ fuseConsumerResults->origConsumerOperand->getOwner());
+ fusedConsumerOps.push_back(
+ fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+ }
}
transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
@@ -200,9 +203,9 @@ DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) {
- LogicalResult result =
- applyFuseConsumer(rewriter, getOperation(),
- state.getPayloadOps(getTarget()), transformResults);
+ LogicalResult result = applyFuseConsumer(
+ rewriter, getOperation(), state.getPayloadOps(getTarget()),
+ getNumConsumerToFuse(), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index d55d746bd6aa90..34b075a5c17f9e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -59,12 +59,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
}];
let arguments =
- (ins TransformHandleTypeInterface:$target);
+ (ins TransformHandleTypeInterface:$target,
+ DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
let results = (outs TransformHandleTypeInterface:$consumer,
TransformHandleTypeInterface:$fused_consumer);
let assemblyFormat = [{
- $target attr-dict `:` functional-type(operands, results)
+ $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
+ attr-dict `:` functional-type(operands, results)
}];
}
Cc @pashu123
I see what you are trying to do here. But checking and bailing for this is very fragile. Essentially, just based on the ordering of operations, some fusion won't kick in, which is very hard to manage in general. Ideally the transformation itself should be able to move operations around to allow the fusion to kick in.
The issue is that
... = firstUserOfLoop
... = defnOfConsumer
During transformation, the loop is moved just before the `defnOfConsumer` (note that it is unnecessary to say "lastDefOfConsumer", there is only one def), and this movement causes a use-def chain violation.
I think a more robust approach is to move the `defOfConsumer` before `firstUserOfLoop`. It's fairly simple to do by
- Check that `firstUserOfLoop` and `defnOfConsumer` are in the same block.
- Compute a backward slice of `defnOfConsumer` (with the op included), but cut the slice to not include any operation that `firstUserOfLoop` already dominates.
- The first user of the loop should not be in the slice computed. If it is, then bail/look for the next consumer.
- Move the slice computed (including `defnOfConsumer`) before `firstUserOfLoop`.
The check that is included in this change is pretty complicated, and we probably don't want such checks. (In any case, I thought a simple dominance check would work.)
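The four steps above can be sketched on a toy model. This is a minimal illustration, not the MLIR implementation: `ToyBlock`, `collectSlice`, and `hoistDefBeforeFirstUser` are hypothetical names, ops are plain integer ids inside a single block, and the slice cut is assumed to mean "skip ops that already sit before `firstUserOfLoop`", since those need no movement.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Toy stand-in for one basic block (hypothetical, not an MLIR type).
// `order` lists op ids in program order; `operands[id]` lists the ids of the
// ops that define the operands of op `id`.
struct ToyBlock {
  std::vector<int> order;
  std::vector<std::vector<int>> operands;

  int positionOf(int op) const {
    auto it = std::find(order.begin(), order.end(), op);
    assert(it != order.end() && "op must live in this block");
    return static_cast<int>(it - order.begin());
  }
};

// Step 2: backward slice of `root` (root included). Ops positioned before
// `firstUserPos` are cut, because they already precede the first user.
static void collectSlice(const ToyBlock &block, int root, int firstUserPos,
                         std::set<int> &slice) {
  if (block.positionOf(root) < firstUserPos)
    return;
  if (!slice.insert(root).second)
    return; // already visited
  for (int def : block.operands[root])
    collectSlice(block, def, firstUserPos, slice);
}

// Steps 3 and 4: bail if the first user is part of the slice, otherwise move
// the slice (in its original relative order) right before the first user.
// Step 1 (same block) holds trivially in this single-block toy.
static bool hoistDefBeforeFirstUser(ToyBlock &block, int defOfConsumer,
                                    int firstUser) {
  std::set<int> slice;
  collectSlice(block, defOfConsumer, block.positionOf(firstUser), slice);
  if (slice.count(firstUser))
    return false; // the def depends on the first user; nothing can be moved

  std::vector<int> moved, rest;
  for (int op : block.order) {
    if (slice.count(op))
      moved.push_back(op);
    else
      rest.push_back(op);
  }
  auto insertAt = std::find(rest.begin(), rest.end(), firstUser);
  rest.insert(insertAt, moved.begin(), moved.end());
  block.order = rest;
  return true;
}
```

With ops `0` (loop), `1` (first user of the loop), `2` and `3` (a def chain feeding the consumer), and `4` (consumer), hoisting `3` before `1` reorders the block from `0 1 2 3 4` to `0 2 3 1 4`. If `3` itself depended on `1`, the function returns `false` and leaves the order untouched.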
Thanks for your time! This formalization is fairly constructive and makes the solution clearer. I have some questions before reimplementing.
As for statement
I am sorry that I need to figure out what
There exists one corner case. As I have already commented in the current change, it is possible that one of defOfConsumer is outside the block of
where
Based on the current consumer dispatching, i.e. which one of multiple consumers would be fused, it would always grab the first user in the use list. So, could we re-sort the use list order by
BTW, during the transformation, the order of operations generated before fusion may be changed. I am afraid that developers may have some strong assumption or expectation on this order. How about adding another option like
Looking forward to your opinions, thanks!
Ok, so that's what you mean by last defn of consumer...
I am not sure I follow this example. The
I'd push back against anything that relies on program ordering of operations (unless there are explicit barrier instructions etc., which we might not support for the first implementation cause you might be really able to fuse across that anyway). The only valid ordering is the use-def chain, but I am fine with having a flag that gates changes of instruction order.
I have refactored the code based on a backward slice, which is similar to this. Please take another look :)
Ping
Hey @Yun-Fly. I am traveling this week at the LLVM Dev meeting. Review bandwidth is limited. But I will review this early next week. Btw, if you are around, it would be great to catch up.
Hi @MaheshRavishankar, glad to know that this PR is still on your review list. Btw, I have received your greetings forwarded by my colleague. I have also learnt a lot from the community and your comments. Greatly appreciated!
Left a few comments. I seem to have had some comments that I hadn't clicked the green button on, sorry about that. But I think it is better to move the logic for moving the operands into a separate function that can be called when needed. Interspersing that with existing methods makes this hard to read.
c2e8c73 to cf287cb
This is really nice @Yun-Fly. I reviewed this and it is very clear and well done. Thanks!
Thanks for your review!
Currently, consumer fusion expects a single use (any other uses must be terminator ops). This patch supports fusing multiple consumers, i.e. fusing the following input
into
The key step is ensuring (or adapting the IR so) that the `lastDefOfConsumer` is before the `firstUserOfLoop`.
To enable this feature in the TEST mlir, specify `num_consumer_to_fuse = 2` (the default value is 1). For downstream use, on the other hand, developers can customize the behavior. The fusing order of multiple consumers is decided by the use-list order of the corresponding `Value`.
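How the use-list order drives the fusing order can be illustrated with a small toy model of the selection loop in `getConsumerFromUses`. This is only a sketch under stated assumptions: `ToyUser`, `pickConsumer`, and `fuseUpTo` are hypothetical names, the three boolean fields stand in for the interface, block, and `use_empty()` checks in the patch, and the use-list order in the usage note is assumed rather than taken from MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for one entry of a value's use list (hypothetical type).
struct ToyUser {
  std::string name;
  bool tilable;   // models the tilable-interface check (Step 1)
  bool sameBlock; // models the containing-block check (Step 2)
  bool hasUses;   // false models consumerOp->use_empty() (Step 3)
};

// Return the index of the first use that passes all three checks, if any.
static std::optional<std::size_t>
pickConsumer(const std::vector<ToyUser> &uses) {
  for (std::size_t i = 0; i < uses.size(); ++i) {
    const ToyUser &user = uses[i];
    if (!user.tilable || !user.sameBlock || !user.hasUses)
      continue;
    return i;
  }
  return std::nullopt;
}

// Fuse up to `numConsumerToFuse` consumers, one per iteration. After a
// consumer is fused, its original results are assumed to have no users left,
// so the next iteration skips it and picks the following candidate.
// Unlike the test transform op, this toy simply stops when nothing is left.
static std::vector<std::string> fuseUpTo(std::vector<ToyUser> uses,
                                         unsigned numConsumerToFuse) {
  std::vector<std::string> fused;
  for (unsigned i = 0; i < numConsumerToFuse; ++i) {
    std::optional<std::size_t> idx = pickConsumer(uses);
    if (!idx)
      break;
    assert(*idx < uses.size());
    fused.push_back(uses[*idx].name);
    uses[*idx].hasUses = false;
  }
  return fused;
}
```

For an assumed use list of `other.op` (not tilable), `linalg.exp`, and `linalg.mul`, calling `fuseUpTo(uses, 2)` selects `linalg.exp` first and `linalg.mul` second; with a count of 1 only `linalg.exp` is selected.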