Implemented tiling and fusion path for GPU #383

Merged
merged 1 commit into intel:main from gpu-tile on Nov 26, 2024

Conversation

Contributor

@AndreyPavlenko commented Oct 17, 2024

This path creates two nested loops for linalg operations, which are later converted to gpu.launch.
The outer loop is mapped to the grid sizes and the inner loop is mapped to the block sizes.
The tile sizes are calculated from the device information, retrieved either from the module DLTI attributes or from the path options.
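
Schematically, the generated loop nest looks roughly like this before the gpu.launch conversion (an illustrative sketch; the sizes are placeholders matching the examples later in this thread, and the loop bodies are omitted):

// Outer loop over work-group tiles: mapped to the grid sizes.
scf.forall (%wg) = (0) to (256) step (128) {
  // Inner loop over the tiles within one work-group: mapped to the block sizes.
  scf.forall (%i, %j) = (0, 0) to (128, 256) step (16, 32) {
    // tiled (and fused) linalg ops operating on extracted slices
  }
}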

Depends on #406

@AndreyPavlenko force-pushed the gpu-tile branch 3 times, most recently from 04326d8 to dbaea96 on October 24, 2024 02:45
@AndreyPavlenko force-pushed the gpu-tile branch 7 times, most recently from 7656d44 to db5fbc1 on October 31, 2024 14:25
@AndreyPavlenko force-pushed the gpu-tile branch 3 times, most recently from 7466a14 to acc2dec on November 11, 2024 16:31
@AndreyPavlenko force-pushed the gpu-tile branch 2 times, most recently from 8faf654 to def5e26 on November 12, 2024 22:55
@AndreyPavlenko force-pushed the gpu-tile branch 4 times, most recently from 3e29f3b to 3597a30 on November 15, 2024 21:40
@dchigarev
Contributor

There are currently two tests failing:

 1. GRAPH_COMPILER_OPT :: gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
                  Error: not enough SLM
 2. GRAPH_COMPILER_OPT :: gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096_transpose.mlir
                  Error: Assertion `Index < Length && "Invalid index!"' failed

The first one can be fixed by reducing the maximum work-group size down to 16 (a hacky fix, but it will work for now until we figure out a proper one):

diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
index cb3f5972..4e0f1265 100644
--- a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
+++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
@@ -1,6 +1,6 @@
 // RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s
 
-module {
+module attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"GPU" : #dlti.target_device_spec<#dlti.dl_entry<"max_work_group_size", 16 : i64>>>} {
   func.func @linalg_mlp(%arg0: tensor<32x4096xf16>, %arg1: tensor<4096x4096xf16>, %arg2 : tensor<32x4096xf16>,
                         %arg3: tensor<4096x4096xf16>, %arg4 : tensor<32x4096xf16>) {
     %cst = arith.constant 0.000000e+00 : f16

The second one fails because one of the inputs of the second linalg.matmul appears to be in SLM (and I haven't supported this case in my PR #409). The reason linalg.matmul takes SLM as an input is that the two dependent matmuls were placed into a single GPU kernel, so the second one shares its input buffer with the output buffer (SLM) of the first one:

f16_mlp_32x4096x4096x4096_transpose.mlir with NEW tiling

Both matmuls are placed into a single gpu kernel:

func.func @linalg_mlp(%arg0: memref<32x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<32x4096xf16>, %arg3: memref<4096x4096xf16>, %arg4: memref<32x4096xf16>) {
  %c128 = arith.constant 128 : index
  %c4096 = arith.constant 4096 : index
  %c8 = arith.constant 8 : index
  %c32 = arith.constant 32 : index
  %c2 = arith.constant 2 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f16
  %0 = memref.get_global @__constant_8x4096xf16 : memref<8x4096xf16>
  %1 = memref.get_global @__constant_8x128xf16 : memref<8x128xf16>
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
  gpu.launch blocks(%arg5, %arg6, %arg7) in (%arg11 = %c2, %arg12 = %c1, %arg13 = %c1) threads(%arg8, %arg9, %arg10) in (%arg14 = %c2, %arg15 = %c32, %arg16 = %c1) {
    %2 = affine.apply #map(%arg5)
    %subview_0 = memref.subview %alloc[%2, 0] [16, 4096] [1, 1] : memref<32x4096xf16> to memref<16x4096xf16, strided<[4096, 1], offset: ?>>
    %3 = affine.apply #map1(%arg8)
    %4 = affine.apply #map2(%arg9)
    %5 = arith.addi %3, %2 : index
    %subview_1 = memref.subview %arg4[%5, %4] [8, 128] [1, 1] : memref<32x4096xf16> to memref<8x128xf16, strided<[4096, 1], offset: ?>>
    %subview_2 = memref.subview %arg2[%5, 0] [8, 4096] [1, 1] : memref<32x4096xf16> to memref<8x4096xf16, strided<[4096, 1], offset: ?>>
    %subview_3 = memref.subview %arg0[%5, 0] [8, 4096] [1, 1] : memref<32x4096xf16> to memref<8x4096xf16, strided<[4096, 1], offset: ?>>
    %alloc_4 = memref.alloc() : memref<16x131072xf16, 3>
    %6 = arith.muli %arg8, %c8 : index
    %7 = arith.muli %arg9, %c4096 : index
    %subview_5 = memref.subview %alloc_4[%6, %7] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
    linalg.fill ins(%cst : f16) outs(%subview_5 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
    linalg.matmul_transpose_b ins(%subview_3, %arg1 : memref<8x4096xf16, strided<[4096, 1], offset: ?>>, memref<4096x4096xf16>) outs(%subview_5 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
    %alloc_6 = memref.alloc() : memref<16x131072xf16, 3>
    %8 = arith.muli %arg8, %c8 : index
    %9 = arith.muli %arg9, %c4096 : index
    %subview_7 = memref.subview %alloc_6[%8, %9] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
    linalg.add ins(%subview_2, %subview_5 : memref<8x4096xf16, strided<[4096, 1], offset: ?>>, memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>) outs(%subview_7 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
    %alloc_8 = memref.alloc() : memref<16x131072xf16, 3>
    %10 = arith.muli %arg8, %c8 : index
    %11 = arith.muli %arg9, %c4096 : index
    %subview_9 = memref.subview %alloc_8[%10, %11] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
    linalg.max ins(%0, %subview_7 : memref<8x4096xf16>, memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>) outs(%subview_9 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
    %subview_10 = memref.subview %arg3[%4, 0] [128, 4096] [1, 1] : memref<4096x4096xf16> to memref<128x4096xf16, strided<[4096, 1], offset: ?>>
    %alloc_11 = memref.alloc() : memref<16x4096xf16, 3>
    %12 = arith.muli %arg8, %c8 : index
    %13 = arith.muli %arg9, %c128 : index
    %subview_12 = memref.subview %alloc_11[%12, %13] [8, 128] [1, 1] : memref<16x4096xf16, 3> to memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>
    linalg.fill ins(%cst : f16) outs(%subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
    linalg.matmul_transpose_b ins(%subview_9, %subview_10 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>, memref<128x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
    %alloc_13 = memref.alloc() : memref<16x4096xf16, 3>
    %14 = arith.muli %arg8, %c8 : index
    %15 = arith.muli %arg9, %c128 : index
    %subview_14 = memref.subview %alloc_13[%14, %15] [8, 128] [1, 1] : memref<16x4096xf16, 3> to memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>
    linalg.add ins(%subview_1, %subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>>, memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>) outs(%subview_14 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
    %subview_15 = memref.subview %subview_0[%3, %4] [8, 128] [1, 1] : memref<16x4096xf16, strided<[4096, 1], offset: ?>> to memref<8x128xf16, strided<[4096, 1], offset: ?>>
    linalg.max ins(%1, %subview_14 : memref<8x128xf16>, memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>) outs(%subview_15 : memref<8x128xf16, strided<[4096, 1], offset: ?>>)
    gpu.terminator
  } {SCFToGPU_visited}
  %subview = memref.subview %alloc[0, 0] [32, 2] [1, 1] : memref<32x4096xf16> to memref<32x2xf16, strided<[4096, 1]>>
  %cast = memref.cast %subview : memref<32x2xf16, strided<[4096, 1]>> to memref<*xf16>
  call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
  memref.dealloc %alloc : memref<32x4096xf16>
  return
}
f16_mlp_32x4096x4096x4096_transpose.mlir with OLD tiling

Two dependent matmuls are placed into separate kernels

func.func @linalg_mlp(%arg0: memref<32x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<32x4096xf16>, %arg3: memref<4096x4096xf16>, %arg4: memref<32x4096xf16>, %arg5: memref<i8>) {
  %c32 = arith.constant 32 : index
  %c4096 = arith.constant 4096 : index
  %c128 = arith.constant 128 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f16
  %0 = memref.get_global @__constant_32x32xf16 : memref<32x32xf16>
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
  gpu.launch blocks(%arg6, %arg7, %arg8) in (%arg12 = %c128, %arg13 = %c1, %arg14 = %c1) threads(%arg9, %arg10, %arg11) in (%arg15 = %c1, %arg16 = %c1, %arg17 = %c1) {
    %1 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6)
    %subview_5 = memref.subview %arg1[%1, 0] [32, 4096] [1, 1] : memref<4096x4096xf16> to memref<32x4096xf16, strided<[4096, 1], offset: ?>>
    %subview_6 = memref.subview %alloc[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    linalg.fill ins(%cst : f16) outs(%subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    linalg.matmul_transpose_b ins(%arg0, %subview_5 : memref<32x4096xf16>, memref<32x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    %subview_7 = memref.subview %arg2[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    %subview_8 = memref.subview %alloc_0[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    linalg.add ins(%subview_7, %subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    %subview_9 = memref.subview %alloc_1[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    linalg.max ins(%0, %subview_8 : memref<32x32xf16>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_9 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    gpu.terminator
  } {SCFToGPU_visited}
  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
  %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
  %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
  gpu.launch blocks(%arg6, %arg7, %arg8) in (%arg12 = %c128, %arg13 = %c1, %arg14 = %c1) threads(%arg9, %arg10, %arg11) in (%arg15 = %c1, %arg16 = %c1, %arg17 = %c1) {
    %1 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6)
    %subview_5 = memref.subview %arg3[%1, 0] [32, 4096] [1, 1] : memref<4096x4096xf16> to memref<32x4096xf16, strided<[4096, 1], offset: ?>>
    %alloc_6 = memref.alloc() : memref<4096x32xf16, 3>
    %2 = arith.muli %arg9, %c4096 : index
    %3 = arith.muli %arg10, %c32 : index
    %subview_7 = memref.subview %alloc_6[%2, %3] [4096, 32] [1, 1] : memref<4096x32xf16, 3> to memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>
    linalg.transpose ins(%subview_5 : memref<32x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_7 : memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>) permutation = [1, 0] 
    %subview_8 = memref.subview %alloc_2[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    linalg.fill ins(%cst : f16) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    linalg.matmul ins(%alloc_1, %subview_7 : memref<32x4096xf16>, memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    %subview_9 = memref.subview %arg4[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    %subview_10 = memref.subview %alloc_3[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    linalg.add ins(%subview_9, %subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_10 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    %subview_11 = memref.subview %alloc_4[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
    linalg.max ins(%0, %subview_10 : memref<32x32xf16>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_11 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
    gpu.terminator
  } {SCFToGPU_visited}
  %subview = memref.subview %alloc_4[0, 0] [32, 2] [1, 1] : memref<32x4096xf16> to memref<32x2xf16, strided<[4096, 1]>>
  %cast = memref.cast %subview : memref<32x2xf16, strided<[4096, 1]>> to memref<*xf16>
  call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
  memref.dealloc %alloc : memref<32x4096xf16>
  memref.dealloc %alloc_0 : memref<32x4096xf16>
  memref.dealloc %alloc_1 : memref<32x4096xf16>
  memref.dealloc %alloc_2 : memref<32x4096xf16>
  memref.dealloc %alloc_3 : memref<32x4096xf16>
  memref.dealloc %alloc_4 : memref<32x4096xf16>
  return
}

The non-transposed case (f16_mlp_32x4096x4096x4096.mlir) works absolutely fine, since the two dependent matmuls are split into two separate kernels as expected; it's only the transposed one that causes the problem.

@AndreyPavlenko
Contributor Author

Now the matmuls are split into two kernels, and both tests fail because the SLM size exceeds the target limits.

@AndreyPavlenko force-pushed the gpu-tile branch 2 times, most recently from 522f3e7 to bddcd97 on November 18, 2024 18:15
@AndreyPavlenko
Contributor Author

The SLM size exceeds target limits error is fixed in #406 by reducing the work-group size on kernel compilation failure.

@AndreyPavlenko marked this pull request as ready for review November 18, 2024 18:34

OpRewriter rw(fn);
tileAndFuseLinalgOps(rw, fn);
tileForallOps(rw, fn);
Member

Why separate the tiling and fusion into two phases? Judging by the naming, the second phase seems to have no fusion in it.

Contributor Author

The second phase just creates a nested loop; there is nothing to fuse, it only tiles the outer loop.

Member

@zhczhong Nov 22, 2024

OK, I got it. Then why do you tile the outer loop manually instead of calling tileAndFuse twice?

linalg.op

tileAndFuse linalg.op:
scf.forall {
  linalg.op
}

tileAndFuse linalg.op:
scf.forall {
  scf.forall {
    linalg.op
  }
}

Contributor Author

Because we only need to tile a single loop to spread the computations over the work-groups; we don't need to tile and fuse all the operations inside the loop. Actually, I tried that in my first approach. It worked for simple cases, like your example, but for some more complex cases the result was not as expected.

Member

@zhczhong Nov 25, 2024

You could do it even when there are multiple linalg ops, as this is the upstream-recommended way. Could you show me a complex case where the result is not as expected when calling tileAndFuse twice?

linalg.op
linalg.op
linalg.opx

tileConsumerAndFuseProducersUsingSCF  linalg.opx:
scf.forall {
  linalg.op
  linalg.op
  linalg.opx
}

tileConsumerAndFuseProducersUsingSCF  linalg.opx:
scf.forall {
  scf.forall {
    linalg.op
    linalg.op
    linalg.opx
  }
}

Contributor Author

@AndreyPavlenko Nov 25, 2024

For example, when dynamic shapes appear after the first tiling, the second tiling will produce a different result. Besides, tiling a single loop requires much less computation than full tiling and fusion of all the ops inside the loop.

Member

@zhczhong Nov 26, 2024

What do you mean by a different result? As far as I know, tileConsumerAndFuseProducersUsingSCF will insert affine.min, affine.max and affine.apply to ensure the tiling bounds are valid if dynamic shapes exist. The extra code is there for safety. There should be no difference in the fusion, as the fusion only cares about the slice def-use chain.
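
For instance, with a dynamic bound the tiled slice is typically guarded like this (an illustrative snippet, not output from this PR):

// Clamp the tile extent so the last tile does not run past the dynamic bound %dim.
%sz = affine.min affine_map<(d0)[s0] -> (32, s0 - d0)>(%iv)[%dim]
%slice = tensor.extract_slice %t[%iv, 0] [%sz, 128] [1, 1] : tensor<?x128xf16> to tensor<?x128xf16>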

If you want to tile the scf.forall only, you could follow scf-parallel-loop-tiling (https://mlir.llvm.org/docs/Passes/#-scf-parallel-loop-tiling) and do it after bufferization, which is much easier as you don't need to handle things like extract_slice and parallel_insert_slice.

Contributor Author

tileConsumerAndFuseProducersUsingSCF uses the tile-size calculation function provided in the options. There are two implementations: one for static and one for dynamic sizes. On the first run of tileConsumerAndFuseProducersUsingSCF the static version is used, but if dynamic shapes have appeared, the dynamic version is used and the result will be different.
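
For illustration, this is roughly how such a callback is plugged into the upstream tiling options (a hedged sketch, not this PR's actual code; hasOnlyStaticShapes, computeStaticTileSizes and computeDynamicTileSizes are hypothetical stand-ins for the two implementations mentioned above):

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

using namespace mlir;

static void configureTiling(scf::SCFTileAndFuseOptions &opts) {
  opts.tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  opts.tilingOptions.setTileSizeComputationFunction(
      [](OpBuilder &b, Operation *op) -> SmallVector<OpFoldResult> {
        // Hypothetical helpers: use the static implementation while all
        // shapes are static; fall back to the dynamic one otherwise.
        if (hasOnlyStaticShapes(op))
          return computeStaticTileSizes(b, op);
        return computeDynamicTileSizes(b, op);
      });
}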

opts.tilingOptions.setLoopType(SCFTilingOptions::LoopType::ForallOp);

for (auto ti = findTi(rw, fn); ti; ti = findTi(rw, fn)) {
  auto result = tileConsumerAndFuseProducersUsingSCF(rw, *ti, opts);
Member

Why use tileConsumerAndFuseProducersUsingSCF instead of tileProducerAndFuseConsumerUsingSCF here, and how is the tiling target op chosen? For the fusion matmul + add + relu, we should tile the relu, as the add and the matmul are the relu's producers.

Contributor Author

tileConsumerAndFuseProducersUsingSCF() is the upstream function; we supply it with the tile-size computation and fusion control functions, passed in the options.

Member

@zhczhong Nov 22, 2024

The fuse-consumer and fuse-producer transforms are all available upstream (https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredfuse_into_containing_op-transformfuseintocontainingop, https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredfuse-transformfuseop, llvm/llvm-project#88712). The point is why to choose one instead of the other and how to choose the target op. If you call fuse-producer (fuse-consumer), you need to find the last op (first op) in the fusion chain, otherwise you will get a suboptimal result (only part of the ops get fused).

Contributor Author

I chose tileConsumerAndFuseProducersUsingSCF instead of tileProducerAndFuseConsumerUsingSCF because there is no such function. The operations are processed in bottom-up order, so we start from the last op.

Member

@zhczhong Nov 25, 2024

The name is not exactly the same; the function in llvm/llvm-project#88712 is tileAndFuseConsumerOfSlice. Yes, we should start from the last op, but how do you decide which op is the last one? For example, when there is a branch in the graph, how do you determine which op is the last one (add or relu)? This is part of the reason why fusing the consumer was proposed (https://discourse.llvm.org/t/rfc-tiling-interface-supports-fuse-consumer/76286).

        ->add  
      /
matmul -> relu 

Contributor Author

In this case, the matmul could be fused into each branch. Look at the following examples and compare the results: GpuTilingAndFusion moves all the operations into the branches, while IterativeTilingAndFusion does not fuse at all.

Example 1
  func.func @test1(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %cond: i1) {
    %t0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
    %t1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
    %tmp1 = tensor.empty() : tensor<256x256xf16>
    %0 = linalg.matmul ins(%t0, %t1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp1 : tensor<256x256xf16>) -> tensor<256x256xf16>
    %1 = scf.if %cond -> tensor<256x256xf16> {
      %tmp2 = tensor.empty() : tensor<256x256xf16>
      %2 = linalg.add ins(%0, %t1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp2 : tensor<256x256xf16>) -> tensor<256x256xf16>
      scf.yield %2 : tensor<256x256xf16>
    } else {
      %tmp3 = tensor.empty() : tensor<256x256xf16>
      %3 = linalg.mul ins(%0, %t1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp3 : tensor<256x256xf16>) -> tensor<256x256xf16>
      scf.yield %3 : tensor<256x256xf16>
    }
    bufferization.materialize_in_destination %1 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
    return
  }
GpuTilingAndFusion
  func.func @test1(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
    %1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
    %2 = scf.if %arg3 -> (tensor<256x256xf16>) {
      %3 = tensor.empty() : tensor<256x256xf16>
      %4 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %3) -> (tensor<256x256xf16>) {
        %extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
        %5 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
          %6 = arith.addi %arg7, %arg5 : index
          %extracted_slice_0 = tensor.extract_slice %0[%6, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
          %extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
          %7 = tensor.empty() : tensor<16x32xf16>
          %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%7 : tensor<16x32xf16>) -> tensor<16x32xf16>
          %extracted_slice_2 = tensor.extract_slice %1[%6, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
          %extracted_slice_3 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
          %9 = linalg.add ins(%8, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_3 : tensor<16x32xf16>) -> tensor<16x32xf16>
          scf.forall.in_parallel {
            tensor.parallel_insert_slice %9 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
          }
        }
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %5 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
        }
      }
      scf.yield %4 : tensor<256x256xf16>
    } else {
      %3 = tensor.empty() : tensor<256x256xf16>
      %4 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %3) -> (tensor<256x256xf16>) {
        %extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
        %5 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
          %6 = arith.addi %arg7, %arg5 : index
          %extracted_slice_0 = tensor.extract_slice %0[%6, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
          %extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
          %7 = tensor.empty() : tensor<16x32xf16>
          %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%7 : tensor<16x32xf16>) -> tensor<16x32xf16>
          %extracted_slice_2 = tensor.extract_slice %1[%6, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
          %extracted_slice_3 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
          %9 = linalg.mul ins(%8, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_3 : tensor<16x32xf16>) -> tensor<16x32xf16>
          scf.forall.in_parallel {
            tensor.parallel_insert_slice %9 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
          }
        }
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %5 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
        }
      }
      scf.yield %4 : tensor<256x256xf16>
    }
    bufferization.materialize_in_destination %2 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
    return
  }
IterativeTilingAndFusion
  func.func @test1(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
    %1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
    %2 = tensor.empty() : tensor<256x256xf16>
    %3 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %2) -> (tensor<256x256xf16>) {
      %extracted_slice = tensor.extract_slice %0[%arg5, 0] [32, 256] [1, 1] : tensor<256x256xf16> to tensor<32x256xf16>
      %extracted_slice_0 = tensor.extract_slice %1[0, %arg6] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
      %extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
      %5 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x256xf16>, tensor<256x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %5 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
      }
    }
    %4 = scf.if %arg3 -> (tensor<256x256xf16>) {
      %5 = tensor.empty() : tensor<256x256xf16>
      %6 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %5) -> (tensor<256x256xf16>) {
        %extracted_slice = tensor.extract_slice %3[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %extracted_slice_0 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %7 = linalg.add ins(%extracted_slice, %extracted_slice_0 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %7 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
        }
      }
      scf.yield %6 : tensor<256x256xf16>
    } else {
      %5 = tensor.empty() : tensor<256x256xf16>
      %6 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %5) -> (tensor<256x256xf16>) {
        %extracted_slice = tensor.extract_slice %3[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %extracted_slice_0 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %7 = linalg.mul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %7 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
        }
      }
      scf.yield %6 : tensor<256x256xf16>
    }
    bufferization.materialize_in_destination %4 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
    return
  }
Example 2
  func.func @test2(%memref0: memref<256x256xf16>, %memref1: memref<256x256xf16>, %memref2: memref<256x256xf16>, %cond: i1) {
    %arg0 = bufferization.to_tensor %memref0 restrict : memref<256x256xf16>
    %arg1 = bufferization.to_tensor %memref1 restrict : memref<256x256xf16>
    %tmp1 = tensor.empty() : tensor<256x256xf16>
    %matmul_res = linalg.matmul ins(%arg0, %arg1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp1 : tensor<256x256xf16>) -> tensor<256x256xf16>
    %tmp2 = tensor.empty() : tensor<256x256xf16>
    %add_res = linalg.add ins(%matmul_res, %arg1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp2 : tensor<256x256xf16>) -> tensor<256x256xf16>
    %if_res = scf.if %cond -> tensor<256x256xf16> {
      scf.yield %add_res : tensor<256x256xf16>
    } else {
      %tmp3 = tensor.empty() : tensor<256x256xf16>
      %mul_res = linalg.mul ins(%add_res, %arg1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp3 : tensor<256x256xf16>) -> tensor<256x256xf16>
      scf.yield %mul_res : tensor<256x256xf16>
    }
    bufferization.materialize_in_destination %if_res in restrict writable %memref2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
    return
  }
GpuTilingAndFusion
  func.func @test2(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
    %1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
    %2 = tensor.empty() : tensor<256x256xf16>
    %3 = scf.if %arg3 -> (tensor<256x256xf16>) {
      %4 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %2) -> (tensor<256x256xf16>) {
        %extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
        %5 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
          %6 = arith.addi %arg7, %arg5 : index
          %extracted_slice_0 = tensor.extract_slice %0[%6, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
          %extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
          %7 = tensor.empty() : tensor<16x32xf16>
          %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%7 : tensor<16x32xf16>) -> tensor<16x32xf16>
          %extracted_slice_2 = tensor.extract_slice %1[%6, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
          %extracted_slice_3 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
          %9 = linalg.add ins(%8, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_3 : tensor<16x32xf16>) -> tensor<16x32xf16>
          scf.forall.in_parallel {
            tensor.parallel_insert_slice %9 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
          }
        }
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %5 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
        }
      }
      scf.yield %4 : tensor<256x256xf16>
    } else {
      %4 = tensor.empty() : tensor<256x256xf16>
      %5 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %4) -> (tensor<256x256xf16>) {
        %extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
        %6 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
          %7 = arith.addi %arg7, %arg5 : index
          %extracted_slice_0 = tensor.extract_slice %0[%7, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
          %extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
          %8 = tensor.empty() : tensor<16x32xf16>
          %9 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%8 : tensor<16x32xf16>) -> tensor<16x32xf16>
          %extracted_slice_2 = tensor.extract_slice %1[%7, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
          %10 = tensor.empty() : tensor<16x32xf16>
          %11 = linalg.add ins(%9, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%10 : tensor<16x32xf16>) -> tensor<16x32xf16>
          %extracted_slice_3 = tensor.extract_slice %1[%7, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
          %extracted_slice_4 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
          %12 = linalg.mul ins(%11, %extracted_slice_3 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_4 : tensor<16x32xf16>) -> tensor<16x32xf16>
          scf.forall.in_parallel {
            tensor.parallel_insert_slice %12 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
          }
        }
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %6 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
        }
      }
      scf.yield %5 : tensor<256x256xf16>
    }
    bufferization.materialize_in_destination %3 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
    return
  }
IterativeTilingAndFusion
  func.func @test2(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
    %1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
    %2 = tensor.empty() : tensor<256x256xf16>
    %3 = tensor.empty() : tensor<256x256xf16>
    %4 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %3) -> (tensor<256x256xf16>) {
      %extracted_slice = tensor.extract_slice %0[%arg5, 0] [32, 256] [1, 1] : tensor<256x256xf16> to tensor<32x256xf16>
      %extracted_slice_0 = tensor.extract_slice %1[0, %arg6] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
      %extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
      %6 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x256xf16>, tensor<256x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
      %extracted_slice_2 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
      %extracted_slice_3 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
      %7 = linalg.add ins(%6, %extracted_slice_2 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_3 : tensor<32x32xf16>) -> tensor<32x32xf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
      }
    }
    %5 = scf.if %arg3 -> (tensor<256x256xf16>) {
      scf.yield %4 : tensor<256x256xf16>
    } else {
      %6 = tensor.empty() : tensor<256x256xf16>
      %7 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %6) -> (tensor<256x256xf16>) {
        %extracted_slice = tensor.extract_slice %4[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %extracted_slice_0 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
        %8 = linalg.mul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %8 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
        }
      }
      scf.yield %7 : tensor<256x256xf16>
    }
    bufferization.materialize_in_destination %5 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
    return
  }

Member

We are not talking about the same thing.

  1. I am asking why you use only tileConsumerAndFuseProducersUsingSCF and do not use tileAndFuseConsumerOfSlice, instead of IterativeTilingAndFusion (this is the upstream method proposed by the IREE folks, not the downstream one).
  2. The branch I mentioned is not an if-else branch but more like the residual branch in ResNet. But indeed, your implementation performs better in the if-else case.
    C
  /  \
D      E
C = matmul(A, B)
D = ReLU(C)
E = LayerNorm(C)

Contributor

@kurapov-peter left a comment

I didn't look thoroughly into the logic, but overall it looks good, modulo some comments.

// Round to the largest power of 2 that is <= value.
template <typename T> static T floorPow2(T value) {
  auto v = static_cast<std::make_unsigned_t<T>>(value);
  return T(1) << (llvm::bit_width(v) - 1);
}

// Round to the smallest power of 2 that is >= value.
template <typename T> static T ceilPow2(T value) {
  auto v = static_cast<std::make_unsigned_t<T>>(value);
  return llvm::bit_ceil(v);
}
Contributor

Aren't these covered by llvm's MathExtras.h?

Contributor Author

They are, but all those functions operate on unsigned integers and we often need to cast. This is just a shorthand for cast + llvm::bit_ceil().
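
For example (an illustrative snippet, the value is made up), the direct call needs an explicit cast, which the helper hides:

#include <cstdint>
#include "llvm/ADT/bit.h"

void example() {
  int64_t tileSize = 48;
  // llvm::bit_ceil only accepts unsigned integers, so a cast is required:
  uint64_t a = llvm::bit_ceil(static_cast<uint64_t>(tileSize)); // 64
  // The ceilPow2 helper above does the cast and keeps the signed type:
  int64_t b = ceilPow2(tileSize); // 64
}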

Comment on lines +209 to +309
// Check recursively if the specified operation has an operand that
// depends on a result of a previous operation matching the predicate.
template <unsigned MaxDepth = std::numeric_limits<unsigned>::max()>
bool isOperandDependsOnOp(bool (*predicate)(Operation *), Operation *operation,
                          unsigned depth = 0) {
  for (auto operand : operation->getOperands()) {
    if (auto op = operand.getDefiningOp();
        op && (predicate(op) ||
               (depth < MaxDepth &&
                isOperandDependsOnOp(predicate, op, depth + 1)))) {
      return true;
    }
  }
  return false;
}

// Check recursively if there is any operation matching the predicate that
// depends on a result of the specified operation.
template <unsigned MaxDepth = std::numeric_limits<unsigned>::max()>
bool isOpDependsOnResult(bool (*predicate)(Operation *), Operation *operation,
                         unsigned depth = 0) {
  for (auto res : operation->getResults()) {
    for (auto u : res.getUsers()) {
      if (predicate(u) ||
          (depth < MaxDepth && isOpDependsOnResult(predicate, u, depth + 1))) {
        return true;
      }
    }
  }
  return false;
}
Contributor

This looks like it can get pretty expensive. Can it be replaced by some standard data flow analysis?

Contributor Author

Yeah, it could be expensive. These functions are called from the fusion control function to prevent fusion of the matmuls. Do you have an idea of how to use standard data-flow analysis for this case?
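
For context, a hedged sketch of how such a check might be used from the fusion control function to keep two dependent matmuls out of one kernel (the isMatmul predicate and the callback shape are illustrative, not this PR's exact code):

#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// Illustrative predicate: treat the matmul-like linalg ops as "matmuls".
static bool isMatmul(Operation *op) {
  return isa<linalg::MatmulOp, linalg::MatmulTransposeBOp>(op);
}

// Illustrative fusion decision: do not fuse a matmul producer whose operands
// (transitively) come from another matmul, so two dependent matmuls never
// land in the same tiled loop nest / GPU kernel.
static bool shouldFuseProducer(Operation *producer) {
  return !(isMatmul(producer) && isOperandDependsOnOp(isMatmul, producer));
}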

@AndreyPavlenko marked this pull request as draft November 22, 2024 02:11
@AndreyPavlenko force-pushed the gpu-tile branch 3 times, most recently from da9463a to 534d9ff on November 23, 2024 20:13
@AndreyPavlenko marked this pull request as ready for review November 23, 2024 21:15
Contributor

@kurapov-peter left a comment

Good to go, once comments are resolved.

@kurapov-peter merged commit 3c716fe into intel:main Nov 26, 2024
6 checks passed