Skip to content

[mlir][scf] Extend consumer fusion to multiple tilable users #111955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 6, 2024

Conversation

Yun-Fly
Copy link
Contributor

@Yun-Fly Yun-Fly commented Oct 11, 2024

Currently, consumer fusion expects single usage(or others are terminator op). This patch support multiple consumers fusion. I.e. fusing following input

%0 = scf.for {
  ...
  %p = tiledProducer
  ...
}
%1 = tilableOp1 ins(%0 :
%2 = tilableOp2 ins(%0 :

into

%0:3 = scf.for {
  ...
  %p = tiledProducer
  %1 = tilableOp1 ins(%p :
  %2 = tilableOp2 ins(%p :
  ...
}

The key process is ensuring(or adapting) that the lastDefOfConsumer must be before the firstUserOfLoop.

To enable this feature in TEST mlir, we need to specify num_consumer_to_fuse = 2 with default value equal to 1. On the other hand, as for downstream use, developers can customize the behavior. The fusing order of multiple consumers is decided by the use list order of the according Value.

@llvmbot
Copy link
Member

llvmbot commented Oct 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (Yun-Fly)

Changes

Currently, consumer fusion expects single usage(or others are terminator op). This patch support multiple consumers fusion. I.e. fusing following input

%0 = scf.for {
  ...
  %p = tiledProducer
  ...
}
%1 = tilableOp1 ins(%0 :
%2 = tilableOp2 ins(%0 :

into

%0:3 = scf.for {
  ...
  %p = tiledProducer
  %1 = tilableOp1 ins(%p :
  %2 = tilableOp2 ins(%p :
  ...
}

The key process is ensuring(or adapting) that the lastDefOfConsumer must be before the firstUserOfLoop.

To enable this feature in TEST mlir, we need to specify num_consumer_to_fuse = 2 with default value equal to 1. On the other hand, as for downstream use, developers can customize the behavior. The fusing order of multiple consumers is decided by the use list order of the according Value.


Full diff: https://github.com/llvm/llvm-project/pull/111955.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+137-38)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+62)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+16-13)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+4-2)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e2feb10b314540..a758db6c68cf81 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1585,26 +1585,27 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
 /// failure otherwise.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Check that the value has exactly one use which isn't a scf.yield or a
-  // tensor.parallel_insert_slice op.
   OpOperand *operand = nullptr;
   for (OpOperand &opOperand : val.getUses()) {
     Operation *consumerOp = opOperand.getOwner();
-    if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
+    // Step 1. Check if the user is tilable.
+    if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
+      // TODO: We have to init result of consumer before scf.for, use
+      // DestinationStyleOpInterface to get result shape from init for now. Add
+      // support for other op such as op has InferTypeOpInterface.
       continue;
-    if (operand)
-      return failure();
-    // TODO: We have to init result of consumer before scf.for, use
-    //       DestinationStyleOpInterface to get result shape from init for now.
-    //       Add support for other op such as op has InferTypeOpInterface.
-    if (!isa<TilingInterface>(consumerOp) ||
-        !isa<DestinationStyleOpInterface>(consumerOp))
-      return failure();
-    if (containingOpBlock != consumerOp->getBlock())
-      return failure();
-    operand = &opOperand;
+    } else {
+      // Step 2. Check if user stay in the same block.
+      if (containingOpBlock != consumerOp->getBlock())
+        continue;
+      // Step 3. Check if user has succeeding user. Otherwise, it usually
+      // represents already tiled.
+      if (consumerOp->use_empty())
+        continue;
+      operand = &opOperand;
+      break;
+    }
   }
-
   if (operand)
     return operand;
   return failure();
@@ -1699,28 +1700,123 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
   return getConsumerFromUses(resultingValue, containingOp->getBlock());
 }
 
-/// This utility currently checks whether the loop either :-
-/// 1. Yields exactly one result.
-/// 2. Has consumer op as its first user and other users to be in the same
-/// containing block as that of consumer op's. Currently we clone the loop op
-/// right before the consumer op in order to maintain a valid def-use chain.
-/// This utility thus helps ensuring that no invalid IR is formed due to the
-/// same.
+/// This utility currently checks whether the first userOp of loop is NOT before
+/// the last defineOp of consumer. Currently we need to move the loop op right
+/// before a certain op in order to maintain a valid def-use chain. This utility
+/// thus helps ensuring that no invalid IR is formed. E.g.
+///
+/// ```
+/// %0 = scf.for() {
+///   ...
+/// }
+/// ...
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ...
+/// %3 = consumerOp(%2)
+/// ```
+///
+/// If the `firstUserOfLoop`is before `lastDefOfConsumer`, then it would be
+/// invalid to move the loop op right before the `firstUserOfLoop`:
+///
+/// ```
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ...
+/// %2 = lastDefOfConsumer
+/// ```
+///
+/// To address this issue, this utility would double-check there is no user of
+/// `firstUserOfLoop` before `lastDefOfConsumer`. If so, move `firstUserOfLoop`
+/// after `lastDefOfConsumer`. Then, it turns out valid as follow:
+///
+/// ```
+/// %2 = lastDefOfConsumer
+/// %0:2 = scf.for() {
+///    %3 = tiledConsumerOp(%2)
+/// }
+/// %1 = firstUserOfLoop(%0)
+/// ```
+///
+/// Besides, `consumerOp` should not be the user of `firstUserOfLoop`.
+///
+/// @param loopOp: loop operation
+/// @param consumerOp: consumer operation
+/// @param toMoveLoopOpBefore: the operation we move the looOp right before
 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
-                                            Operation *consumerOp) {
-  // Check if the loop op yields one result.
-  if (loopOp->getNumResults() == 1)
-    return success();
-  // Check if the consumerOp is the first user of the loopOp and if other users
-  // are in the same containing block as that of consumer op's.
+                                            Operation *consumerOp,
+                                            Operation **toMoveLoopOpBefore) {
   Block *parentBlock = consumerOp->getBlock();
-  for (Operation *userOp : loopOp->getUsers()) {
-    if (userOp == consumerOp)
-      continue;
-    if (parentBlock != userOp->getBlock() ||
-        !consumerOp->isBeforeInBlock(userOp))
-      return failure();
-  }
+  // loopOp and consumerOp should stay in the same block.
+  if (loopOp->getBlock() != parentBlock)
+    return failure();
+
+  *toMoveLoopOpBefore = nullptr;
+  do {
+    Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp;
+    // Find the first user of loopOp
+    for (Operation *userOp : loopOp->getUsers()) {
+      if (userOp == consumerOp)
+        continue;
+      // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
+      // block with any other types of operation. Thus, just redirecting to its
+      // parent `InParallelOp`.
+      if (isa<tensor::ParallelInsertSliceOp>(userOp))
+        userOp = userOp->getParentOfType<scf::InParallelOp>();
+
+      if (parentBlock != userOp->getBlock())
+        return failure();
+
+      if (userOp->isBeforeInBlock(firstUserOfLoop))
+        firstUserOfLoop = userOp;
+    }
+
+    // Find the last define of consumer
+    for (Value operand : consumerOp->getOperands()) {
+      // If the operand is `BlockArgument`, auto skip.
+      if (isa<BlockArgument>(operand))
+        continue;
+      auto defineOp = operand.getDefiningOp();
+      if (!defineOp)
+        return failure();
+      // If defineOp is not in the same block with loopOp, it must dominate the
+      // loopOp as well. I.e.
+      // ```
+      //  %a = ...
+      //  {
+      //     %looOp = scf.for
+      //     %b = consumerOp ins(%loopOp, %a)
+      //   }
+      // ```
+      if (defineOp == loopOp || parentBlock != defineOp->getBlock())
+        continue;
+      if (lastDefOfConsumer->isBeforeInBlock(defineOp))
+        lastDefOfConsumer = defineOp;
+    }
+    if (firstUserOfLoop->isBeforeInBlock(lastDefOfConsumer)) {
+      // Try to move if possible
+      if (llvm::all_of(firstUserOfLoop->getUsers(),
+                       [&lastDefOfConsumer, &parentBlock](Operation *userOp) {
+                         return userOp->getBlock() == parentBlock &&
+                                lastDefOfConsumer->isBeforeInBlock(userOp);
+                       })) {
+        // Safely moving
+        firstUserOfLoop->moveAfter(lastDefOfConsumer);
+      } else {
+        return failure();
+      }
+    } else {
+      // Check consumerOp is not the user of firstUserOfLoop
+      if (firstUserOfLoop == lastDefOfConsumer)
+        return failure();
+      // Set InsertPoint
+      *toMoveLoopOpBefore = firstUserOfLoop;
+    }
+  } while (!(*toMoveLoopOpBefore));
+
   return success();
 }
 
@@ -1787,7 +1883,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   LoopLikeOpInterface outerMostLoop = nestedLoops.front();
 
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
+  // Find suitable insertPointOp to move the whole loop structure later.
+  Operation *toMoveLoopOpBefore = nullptr;
+  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp,
+                                    &toMoveLoopOpBefore))) {
     return rewriter.notifyMatchFailure(
         outerMostLoop,
         "containing loop op should either yield just one value or "
@@ -1812,9 +1911,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   Location loc = outerMostLoop->getLoc();
 
-  // 3. Move the whole loop structure right before consumer Op, the dominance
+  // 3. Move the whole loop structure right before insertPoint, the dominance
   // should be already ensured by `checkAssumptionForLoop`.
-  rewriter.moveOpBefore(outerMostLoop, consumerOp);
+  rewriter.moveOpBefore(outerMostLoop, toMoveLoopOpBefore);
 
   // 4. Set insertion point before terminator op of the loop and create a new
   // tensor.insert_slice. In the scf.for case this is a clone of the
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index f5f703d95e2d5b..af836d18e8c028 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -508,3 +508,65 @@ module {
 //      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
 //      CHECK:   }
 //      CHECK:   return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
+
+// -----
+
+module {
+  func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
+        %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+        %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+        %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+        scf.yield %insert_slice : tensor<256x256xf32>
+    }
+    %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @fuse_add_multiple_tilable_consumers(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
+// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) 
+// CHECK-SAME:   {
+//      CHECK:          %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_ADD_OUT:.*]] = linalg.add
+// CHECK-SAME:                ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
+// CHECK-SAME:                outs(%[[ADD_OUT_SLICE]] :
+//      CHECK:          %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_EXP_OUT:.*]] = linalg.exp
+// CHECK-SAME:                ins(%[[TILED_ADD_OUT]] :
+// CHECK-SAME:                outs(%[[EXP_OUT_SLICE]] :
+//      CHECK:          %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[TILED_MUL_OUT:.*]] = linalg.mul
+// CHECK-SAME:                ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
+// CHECK-SAME:                outs(%[[MUL_OUT_SLICE]] :
+//      CHECK:          %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
+//      CHECK:          scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
+//      CHECK:   }
+//      CHECK:   return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index b6da47977cb4cf..5e903e378daf82 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -171,24 +171,27 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
 template <typename Range>
 static LogicalResult
 applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
-                  Range &&payloadOps, TransformResults &transformResults) {
+                  Range &&payloadOps, uint32_t numConsumerToFuse,
+                  TransformResults &transformResults) {
   SmallVector<Operation *> originalConsumerOps;
   SmallVector<Operation *> fusedConsumerOps;
 
   for (Operation *target : payloadOps) {
     rewriter.setInsertionPoint(target);
 
-    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
-        scf::tileAndFuseConsumerOfSlice(rewriter, target);
+    while (numConsumerToFuse--) {
+      FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+          scf::tileAndFuseConsumerOfSlice(rewriter, target);
 
-    if (failed(fuseConsumerResults))
-      return failure();
+      if (failed(fuseConsumerResults))
+        return failure();
 
-    // Report back the relevant handles to the transform op.
-    originalConsumerOps.push_back(
-        fuseConsumerResults->origConsumerOperand->getOwner());
-    fusedConsumerOps.push_back(
-        fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+      // Report back the relevant handles to the transform op.
+      originalConsumerOps.push_back(
+          fuseConsumerResults->origConsumerOperand->getOwner());
+      fusedConsumerOps.push_back(
+          fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+    }
   }
 
   transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
@@ -200,9 +203,9 @@ DiagnosedSilenceableFailure
 transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
-  LogicalResult result =
-      applyFuseConsumer(rewriter, getOperation(),
-                        state.getPayloadOps(getTarget()), transformResults);
+  LogicalResult result = applyFuseConsumer(
+      rewriter, getOperation(), state.getPayloadOps(getTarget()),
+      getNumConsumerToFuse(), transformResults);
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index d55d746bd6aa90..34b075a5c17f9e 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -59,12 +59,14 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
   }];
 
   let arguments =
-    (ins TransformHandleTypeInterface:$target);
+    (ins TransformHandleTypeInterface:$target,
+        DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
   let results = (outs TransformHandleTypeInterface:$consumer,
                       TransformHandleTypeInterface:$fused_consumer);
 
   let assemblyFormat = [{
-    $target attr-dict `:` functional-type(operands, results)
+    $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? 
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 

@MaheshRavishankar
Copy link
Contributor

Cc @pashu123

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you are trying to do here. But checking and bailing for this is very fragile. Essentially just based on the ordering of operations some fusion wont kick in, which is very hard to manage in general. Ideally the transformation itself should be able to move around operations to allow for the fusion to kick in.

The issue is that

... = firstUserOfLoop

... = defnOfConsumer

During transformation, the loop is moved just before the defnOfConsumer (note that it is unnecessary to say "lastDefOfConsumer", there is only one def). and this movement causes use-def chain violation.

I think a more robust approach is to move the defOfConsumer before firstUserOfLoop. Its fairly simple to do by

  1. Check that firstUserOfLoop and defnOfConsumer are in the same block
  2. Compute a backward slice of defnOfConsumer (with the op included) but cut the slice to not include any operation that firstUserOfLoop already dominates.
  3. The first user of loop should not be in the slice computed. If it does, then bail/look for next consumer.
  4. Move the slice computed (including defnOfConsumer before firstUserOfLoop.

The check that is included in this change is pretty complicated, but we probably dont want such checks. (In any case I thought a simple dominance check would work).

@Yun-Fly
Copy link
Contributor Author

Yun-Fly commented Oct 14, 2024

Thanks for your time! This formalization is fairly constructive and makes the solution more clear. I have some question before reimplementation.

The issue is that

... = firstUserOfLoop

... = defnOfConsumer

During transformation, the loop is moved just before the defnOfConsumer (note that it is unnecessary to say "lastDefOfConsumer", there is only one def). and this movement causes use-def chain violation.

As for statement there is only one def, don't we need to check each defineOp of all operands of consumerOp? Please correct me if I have misunderstood..

  1. Compute a backward slice of defnOfConsumer (with the op included) but cut the slice to not include any operation that firstUserOfLoop already dominates.
  2. Move the slice computed (including defnOfConsumer before firstUserOfLoop.

I am sorry that I need to figure out what defnOfConsumer exactly represents here before grasping the details here.

  1. Check that firstUserOfLoop and defnOfConsumer are in the same block

There exists one corner case. As I have already commented in the current change, it is possible that one of defOfConsumer is outside the block of firstUserOfLoop. E.g.

%a = ...
{
%loopOp = scf.for
....
%b = consumerOp ins(%loopOp, %a)
}

where %a dominates the block of firstUserOfLoop ==> it must dominate firstUserOfLoop ==> we can skip subsequent check instead of bailing.

  1. The first user of loop should not be in the slice computed. If it does, then bail/look for next consumer.

Based on current consumer dispatching i.e. which one of multiple consumers would be fused. it would always grab the first user among the use list. So, could we resort the use list order by shuffleUseList to put the next consumer ahead?

BTW, during the transformation, the order of operations generated before fusion maybe changed. I am afraid if developer have some strong assumption or expectation on these order. How about add another option like bool intrusive(or any other name) to control this behavior. If the user does not want to modify any order caused by fusion, we just bail rather than move those def.

Looking forward your opinions, thanks!

@MaheshRavishankar
Copy link
Contributor

Thanks for your time! This formalization is fairly constructive and makes the solution more clear. I have some question before reimplementation.

The issue is that

... = firstUserOfLoop

... = defnOfConsumer

During transformation, the loop is moved just before the defnOfConsumer (note that it is unnecessary to say "lastDefOfConsumer", there is only one def). and this movement causes use-def chain violation.

As for statement there is only one def, don't we need to check each defineOp of all operands of consumerOp? Please correct me if I have misunderstood..

Ok, so thats what you mean by last defn of consumer... defnOfConsumer dominates all definitions of. You typically get it through program slice.

  1. Compute a backward slice of defnOfConsumer (with the op included) but cut the slice to not include any operation that firstUserOfLoop already dominates.
  2. Move the slice computed (including defnOfConsumer before firstUserOfLoop.

I am sorry that I need to figure out what defnOfConsumer exactly represents here before grasping the details here.

  1. Check that firstUserOfLoop and defnOfConsumer are in the same block

There exists one corner case. As I have already commented in the current change, it is possible that one of defOfConsumer is outside the block of firstUserOfLoop. E.g.

%a = ...
{
%loopOp = scf.for
....
%b = consumerOp ins(%loopOp, %a)
}

where %a dominates the block of firstUserOfLoop ==> it must dominate firstUserOfLoop ==> we can skip subsequent check instead of bailing.

I am not sure I follow this example. The %loopOps first use dominates the consumer. We dont have to have scf.for dominate all definitions of the operands of consumerOp, i.e we dont need the scf.for to dominate the entire slice of consumerOp. What I meant was that if you take a program slice of consumerOp which is cut-off at scf.for, the only use should be consumerOp.

  1. The first user of loop should not be in the slice computed. If it does, then bail/look for next consumer.

Based on current consumer dispatching i.e. which one of multiple consumers would be fused. it would always grab the first user among the use list. So, could we resort the use list order by shuffleUseList to put the next consumer ahead?

BTW, during the transformation, the order of operations generated before fusion maybe changed. I am afraid if developer have some strong assumption or expectation on these order. How about add another option like bool intrusive(or any other name) to control this behavior. If the user does not want to modify any order caused by fusion, we just bail rather than move those def.

I'd push back against anything that relies on program ordering of operations (unless there are explicit barrier instructions etc. which we might not support for first implementation cause you might be really able to fuse across that anyway). Only valid ordering is use-def chain, but I am fine with having a flag that gates changes of instructions order.

Looking forward your opinions, thanks!

@Yun-Fly
Copy link
Contributor Author

Yun-Fly commented Oct 15, 2024

We dont have to have scf.for dominate all definitions of the operands of consumerOp, i.e we dont need the scf.for to dominate the entire slice of consumerOp. What I meant was that if you take a program slice of consumerOp which is cut-off at scf.for, the only use should be consumerOp.

I have refactored code based on backward slice, which is similar to this. Please help to take look again :)

@Yun-Fly
Copy link
Contributor Author

Yun-Fly commented Oct 21, 2024

Ping

@MaheshRavishankar
Copy link
Contributor

Ping

Hey @Yun-Fly . I am traveling this week at the LLVM Dev meeting. Review bandwidth is limited. But I will review this early next week. Btw, if you are around, would be great to catch up

@Yun-Fly
Copy link
Contributor Author

Yun-Fly commented Oct 24, 2024

Ping

Hey @Yun-Fly . I am traveling this week at the LLVM Dev meeting. Review bandwidth is limited. But I will review this early next week. Btw, if you are around, would be great to catch up

Hi @MaheshRavishankar, Glad to know that this PR is still on you review list. Btw, I have received your greetings forwarded by my colleague. I also learnt a lot from community and your comments. Greatly appreciated!

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few comments. I seemed to have some comments that I hadnt clicked the green button on, sorry about that. But I think it is better to have a logic to move the operands into a separate function that can be called when needed. Interspersing that with existing methods make this hard to read.

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_multiple_consumers branch from c2e8c73 to cf287cb Compare October 29, 2024 03:30
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THis is really nice @Yun-Fly . I reviewed this and this is very clear and well done. Thanks!

@Yun-Fly
Copy link
Contributor Author

Yun-Fly commented Nov 6, 2024

THis is really nice @Yun-Fly . I reviewed this and this is very clear and well done. Thanks!

Thanks for your review!

@Yun-Fly Yun-Fly merged commit 9bc3102 into llvm:main Nov 6, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants