Implemented tiling and fusion path for GPU #383
Conversation
There are currently two tests failing:
The first one can be fixed by reducing the work-group size down to 16 (a hacky one, but it will work for now until we figure out a proper fix):

diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
index cb3f5972..4e0f1265 100644
--- a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
+++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096.mlir
@@ -1,6 +1,6 @@
// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s
-module {
+module attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"GPU" : #dlti.target_device_spec<#dlti.dl_entry<"max_work_group_size", 16 : i64>>>} {
func.func @linalg_mlp(%arg0: tensor<32x4096xf16>, %arg1: tensor<4096x4096xf16>, %arg2 : tensor<32x4096xf16>,
%arg3: tensor<4096x4096xf16>, %arg4 : tensor<32x4096xf16>) {
 %cst = arith.constant 0.000000e+00 : f16

The second one fails because one of the inputs of the second …

f16_mlp_32x4096x4096x4096_transpose.mlir with NEW tiling
Both matmuls are placed into a single gpu kernel:

func.func @linalg_mlp(%arg0: memref<32x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<32x4096xf16>, %arg3: memref<4096x4096xf16>, %arg4: memref<32x4096xf16>) {
%c128 = arith.constant 128 : index
%c4096 = arith.constant 4096 : index
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = memref.get_global @__constant_8x4096xf16 : memref<8x4096xf16>
%1 = memref.get_global @__constant_8x128xf16 : memref<8x128xf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
gpu.launch blocks(%arg5, %arg6, %arg7) in (%arg11 = %c2, %arg12 = %c1, %arg13 = %c1) threads(%arg8, %arg9, %arg10) in (%arg14 = %c2, %arg15 = %c32, %arg16 = %c1) {
%2 = affine.apply #map(%arg5)
%subview_0 = memref.subview %alloc[%2, 0] [16, 4096] [1, 1] : memref<32x4096xf16> to memref<16x4096xf16, strided<[4096, 1], offset: ?>>
%3 = affine.apply #map1(%arg8)
%4 = affine.apply #map2(%arg9)
%5 = arith.addi %3, %2 : index
%subview_1 = memref.subview %arg4[%5, %4] [8, 128] [1, 1] : memref<32x4096xf16> to memref<8x128xf16, strided<[4096, 1], offset: ?>>
%subview_2 = memref.subview %arg2[%5, 0] [8, 4096] [1, 1] : memref<32x4096xf16> to memref<8x4096xf16, strided<[4096, 1], offset: ?>>
%subview_3 = memref.subview %arg0[%5, 0] [8, 4096] [1, 1] : memref<32x4096xf16> to memref<8x4096xf16, strided<[4096, 1], offset: ?>>
%alloc_4 = memref.alloc() : memref<16x131072xf16, 3>
%6 = arith.muli %arg8, %c8 : index
%7 = arith.muli %arg9, %c4096 : index
%subview_5 = memref.subview %alloc_4[%6, %7] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
linalg.fill ins(%cst : f16) outs(%subview_5 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
linalg.matmul_transpose_b ins(%subview_3, %arg1 : memref<8x4096xf16, strided<[4096, 1], offset: ?>>, memref<4096x4096xf16>) outs(%subview_5 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
%alloc_6 = memref.alloc() : memref<16x131072xf16, 3>
%8 = arith.muli %arg8, %c8 : index
%9 = arith.muli %arg9, %c4096 : index
%subview_7 = memref.subview %alloc_6[%8, %9] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
linalg.add ins(%subview_2, %subview_5 : memref<8x4096xf16, strided<[4096, 1], offset: ?>>, memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>) outs(%subview_7 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
%alloc_8 = memref.alloc() : memref<16x131072xf16, 3>
%10 = arith.muli %arg8, %c8 : index
%11 = arith.muli %arg9, %c4096 : index
%subview_9 = memref.subview %alloc_8[%10, %11] [8, 4096] [1, 1] : memref<16x131072xf16, 3> to memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>
linalg.max ins(%0, %subview_7 : memref<8x4096xf16>, memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>) outs(%subview_9 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>)
%subview_10 = memref.subview %arg3[%4, 0] [128, 4096] [1, 1] : memref<4096x4096xf16> to memref<128x4096xf16, strided<[4096, 1], offset: ?>>
%alloc_11 = memref.alloc() : memref<16x4096xf16, 3>
%12 = arith.muli %arg8, %c8 : index
%13 = arith.muli %arg9, %c128 : index
%subview_12 = memref.subview %alloc_11[%12, %13] [8, 128] [1, 1] : memref<16x4096xf16, 3> to memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>
linalg.fill ins(%cst : f16) outs(%subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
linalg.matmul_transpose_b ins(%subview_9, %subview_10 : memref<8x4096xf16, strided<[131072, 1], offset: ?>, 3>, memref<128x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
%alloc_13 = memref.alloc() : memref<16x4096xf16, 3>
%14 = arith.muli %arg8, %c8 : index
%15 = arith.muli %arg9, %c128 : index
%subview_14 = memref.subview %alloc_13[%14, %15] [8, 128] [1, 1] : memref<16x4096xf16, 3> to memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>
linalg.add ins(%subview_1, %subview_12 : memref<8x128xf16, strided<[4096, 1], offset: ?>>, memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>) outs(%subview_14 : memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>)
%subview_15 = memref.subview %subview_0[%3, %4] [8, 128] [1, 1] : memref<16x4096xf16, strided<[4096, 1], offset: ?>> to memref<8x128xf16, strided<[4096, 1], offset: ?>>
linalg.max ins(%1, %subview_14 : memref<8x128xf16>, memref<8x128xf16, strided<[4096, 1], offset: ?>, 3>) outs(%subview_15 : memref<8x128xf16, strided<[4096, 1], offset: ?>>)
gpu.terminator
} {SCFToGPU_visited}
%subview = memref.subview %alloc[0, 0] [32, 2] [1, 1] : memref<32x4096xf16> to memref<32x2xf16, strided<[4096, 1]>>
%cast = memref.cast %subview : memref<32x2xf16, strided<[4096, 1]>> to memref<*xf16>
call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
memref.dealloc %alloc : memref<32x4096xf16>
return
}

f16_mlp_32x4096x4096x4096_transpose.mlir with OLD tiling
Two dependent matmuls are placed into separate kernels:

func.func @linalg_mlp(%arg0: memref<32x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<32x4096xf16>, %arg3: memref<4096x4096xf16>, %arg4: memref<32x4096xf16>, %arg5: memref<i8>) {
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = memref.get_global @__constant_32x32xf16 : memref<32x32xf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
gpu.launch blocks(%arg6, %arg7, %arg8) in (%arg12 = %c128, %arg13 = %c1, %arg14 = %c1) threads(%arg9, %arg10, %arg11) in (%arg15 = %c1, %arg16 = %c1, %arg17 = %c1) {
%1 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6)
%subview_5 = memref.subview %arg1[%1, 0] [32, 4096] [1, 1] : memref<4096x4096xf16> to memref<32x4096xf16, strided<[4096, 1], offset: ?>>
%subview_6 = memref.subview %alloc[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.fill ins(%cst : f16) outs(%subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
linalg.matmul_transpose_b ins(%arg0, %subview_5 : memref<32x4096xf16>, memref<32x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_7 = memref.subview %arg2[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
%subview_8 = memref.subview %alloc_0[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.add ins(%subview_7, %subview_6 : memref<32x32xf16, strided<[4096, 1], offset: ?>>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_9 = memref.subview %alloc_1[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.max ins(%0, %subview_8 : memref<32x32xf16>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_9 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
gpu.terminator
} {SCFToGPU_visited}
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
%alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x4096xf16>
gpu.launch blocks(%arg6, %arg7, %arg8) in (%arg12 = %c128, %arg13 = %c1, %arg14 = %c1) threads(%arg9, %arg10, %arg11) in (%arg15 = %c1, %arg16 = %c1, %arg17 = %c1) {
%1 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6)
%subview_5 = memref.subview %arg3[%1, 0] [32, 4096] [1, 1] : memref<4096x4096xf16> to memref<32x4096xf16, strided<[4096, 1], offset: ?>>
%alloc_6 = memref.alloc() : memref<4096x32xf16, 3>
%2 = arith.muli %arg9, %c4096 : index
%3 = arith.muli %arg10, %c32 : index
%subview_7 = memref.subview %alloc_6[%2, %3] [4096, 32] [1, 1] : memref<4096x32xf16, 3> to memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>
linalg.transpose ins(%subview_5 : memref<32x4096xf16, strided<[4096, 1], offset: ?>>) outs(%subview_7 : memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>) permutation = [1, 0]
%subview_8 = memref.subview %alloc_2[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.fill ins(%cst : f16) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
linalg.matmul ins(%alloc_1, %subview_7 : memref<32x4096xf16>, memref<4096x32xf16, strided<[32, 1], offset: ?>, 3>) outs(%subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_9 = memref.subview %arg4[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
%subview_10 = memref.subview %alloc_3[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.add ins(%subview_9, %subview_8 : memref<32x32xf16, strided<[4096, 1], offset: ?>>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_10 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
%subview_11 = memref.subview %alloc_4[0, %1] [32, 32] [1, 1] : memref<32x4096xf16> to memref<32x32xf16, strided<[4096, 1], offset: ?>>
linalg.max ins(%0, %subview_10 : memref<32x32xf16>, memref<32x32xf16, strided<[4096, 1], offset: ?>>) outs(%subview_11 : memref<32x32xf16, strided<[4096, 1], offset: ?>>)
gpu.terminator
} {SCFToGPU_visited}
%subview = memref.subview %alloc_4[0, 0] [32, 2] [1, 1] : memref<32x4096xf16> to memref<32x2xf16, strided<[4096, 1]>>
%cast = memref.cast %subview : memref<32x2xf16, strided<[4096, 1]>> to memref<*xf16>
call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
memref.dealloc %alloc : memref<32x4096xf16>
memref.dealloc %alloc_0 : memref<32x4096xf16>
memref.dealloc %alloc_1 : memref<32x4096xf16>
memref.dealloc %alloc_2 : memref<32x4096xf16>
memref.dealloc %alloc_3 : memref<32x4096xf16>
memref.dealloc %alloc_4 : memref<32x4096xf16>
return
}

The non-transposed case (f16_mlp_32x4096x4096x4096.mlir) works absolutely fine, since the two dependent matmuls are split into two separate kernels as expected; it's only the transposed one that causes the problem.
Now the matmuls are split into 2 kernels and both tests fail due to
The error |
OpRewriter rw(fn);
tileAndFuseLinalgOps(rw, fn);
tileForallOps(rw, fn);
Why separate the tiling and fusion into two phases? From the naming, the second phase seems to have no fusion inside.
The second phase just creates a nested loop; there is nothing to fuse, it just tiles the outer loop.
OK, I got it. Then why do you tile the outer loop manually instead of calling tileAndFuse twice?

linalg.op

tileAndFuse linalg.op:
scf.forall {
  linalg.op
}

tileAndFuse linalg.op:
scf.forall {
  scf.forall {
    linalg.op
  }
}
Because we need to tile a single loop just to spread computations over the work-groups; we don't need to tile and fuse all the operations inside the loop. Actually, I tried that in the first approach: it worked for simple cases, like in your example, but for some more complex cases the result was not as expected.
You could do it even when there are multiple linalg ops, as this is the upstream-recommended way. Could you show me the complex case where the result is not as expected when calling tileAndFuse twice?

linalg.op
linalg.op
linalg.opx

tileConsumerAndFuseProducersUsingSCF linalg.opx:
scf.forall {
  linalg.op
  linalg.op
  linalg.opx
}

tileConsumerAndFuseProducersUsingSCF linalg.opx:
scf.forall {
  scf.forall {
    linalg.op
    linalg.op
    linalg.opx
  }
}
For example, when dynamic shapes appear after the first tiling, the second tiling will produce a different result. Besides, tiling a single loop requires much less computation than full tiling and fusion of all the ops inside the loop.
What do you mean by a different result? As far as I know, tileConsumerAndFuseProducersUsingSCF will insert affine.min, affine.max and affine.apply to ensure the tiling bound is valid if dynamic shapes exist. The extra code is for safety. There should be no difference in the fusion, as the fusion only cares about the slice def-use chain.
If you want to tile scf.forall only, you could follow scf-parallel-loop-tiling and do it after bufferization, which is much easier (https://mlir.llvm.org/docs/Passes/#-scf-parallel-loop-tiling), as you don't need to handle things like extract_slice and parallel_insert_slice.
tileConsumerAndFuseProducersUsingSCF uses the tile-size calculation function provided in the options. There are two implementations: one for static and one for dynamic sizes. On the first run of tileConsumerAndFuseProducersUsingSCF the static version is used, but if dynamic shapes have appeared, the dynamic version is used and the result is different.
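To illustrate the point (a minimal sketch with assumed names and an assumed tile size of 32, not the exact code of this PR): the tile-size callback installed through SCFTilingOptions can take a different branch once a loop extent is no longer a compile-time constant, which is why a second run over already-tiled IR can produce a different loop structure.

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include <algorithm>
#include <optional>

using namespace mlir;

// Hypothetical tile-size callback: constant tile sizes for static extents,
// a runtime fallback for dynamic ones.
static void setTileSizes(scf::SCFTileAndFuseOptions &opts) {
  opts.tilingOptions.setTileSizeComputationFunction(
      [](OpBuilder &b, Operation *op) -> SmallVector<OpFoldResult> {
        auto ti = cast<TilingInterface>(op);
        SmallVector<Range> domain = ti.getIterationDomain(b);
        SmallVector<OpFoldResult> tiles;
        for (Range &r : domain) {
          if (std::optional<int64_t> size = getConstantIntValue(r.size)) {
            // Static extent: pick a constant tile size (32 is an assumption).
            tiles.push_back(b.getIndexAttr(std::min<int64_t>(*size, 32)));
          } else {
            // Dynamic extent: keep the runtime size, i.e. do not tile this
            // dimension, which leads to a different result.
            tiles.push_back(r.size);
          }
        }
        return tiles;
      });
}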
opts.tilingOptions.setLoopType(SCFTilingOptions::LoopType::ForallOp);

for (auto ti = findTi(rw, fn); ti; ti = findTi(rw, fn)) {
  auto result = tileConsumerAndFuseProducersUsingSCF(rw, *ti, opts);
Why use tileConsumerAndFuseProducersUsingSCF instead of tileProducerAndFuseConsumerUsingSCF here, and how is the tiling target op chosen? For the fusion matmul + add + relu, we should tile the relu, as the add and matmul are relu's producers.
tileConsumerAndFuseProducersUsingSCF() is the upstream function; we supply it with the tile-size computation and fusion control functions, passed in the options.
Both fuse consumer and fuse producer are available upstream functions (https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredfuse_into_containing_op-transformfuseintocontainingop, https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredfuse-transformfuseop, llvm/llvm-project#88712). The point is why to choose one instead of the other, and how to choose the target op. If you call fuse producer (fuse consumer), you need to find the last op (first op) in the fusion chain, otherwise you will get a suboptimal result (only part of the ops fused).
I chose tileConsumerAndFuseProducersUsingSCF instead of tileProducerAndFuseConsumerUsingSCF because there is no such function. The operations are processed in bottom-up order, so we start from the last op.
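For reference, a minimal sketch (with an assumed helper name; the actual findTi in this patch likely adds more filtering) of picking the target in bottom-up order:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/TilingInterface.h"
#include <optional>

using namespace mlir;

// Walks the function and remembers the last op implementing TilingInterface
// that is not already nested in an scf.forall, so tiling starts from the
// bottom of the fusion chain.
static std::optional<TilingInterface> findLastTilableOp(func::FuncOp fn) {
  std::optional<TilingInterface> last;
  fn.walk([&](TilingInterface op) {
    if (!op.getOperation()->getParentOfType<scf::ForallOp>())
      last = op;
  });
  return last;
}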
The name is not exactly the same; the function is tileAndFuseConsumerOfSlice in llvm/llvm-project#88712. Yes, we should start from the last op, but how do you decide which op is the last one? For example, there can be a branch in the graph, and how do you determine which op is the last (add or relu)? This is part of the reason why the fuse consumer is proposed (https://discourse.llvm.org/t/rfc-tiling-interface-supports-fuse-consumer/76286).

        -> add
       /
matmul -> relu
In this case matmul could be fused into each branch. Look at the following examples and compare the results. GpuTilingAndFusion moves all the operations into the branches, but IterativeTilingAndFusion does not fuse at all.
Example 1
func.func @test1(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %cond: i1) {
%t0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
%t1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
%tmp1 = tensor.empty() : tensor<256x256xf16>
%0 = linalg.matmul ins(%t0, %t1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp1 : tensor<256x256xf16>) -> tensor<256x256xf16>
%1 = scf.if %cond -> tensor<256x256xf16> {
%tmp2 = tensor.empty() : tensor<256x256xf16>
%2 = linalg.add ins(%0, %t1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp2 : tensor<256x256xf16>) -> tensor<256x256xf16>
scf.yield %2 : tensor<256x256xf16>
} else {
%tmp3 = tensor.empty() : tensor<256x256xf16>
%3 = linalg.mul ins(%0, %t1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp3 : tensor<256x256xf16>) -> tensor<256x256xf16>
scf.yield %3 : tensor<256x256xf16>
}
bufferization.materialize_in_destination %1 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
return
}
GpuTilingAndFusion
func.func @test1(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
%0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
%1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
%2 = scf.if %arg3 -> (tensor<256x256xf16>) {
%3 = tensor.empty() : tensor<256x256xf16>
%4 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %3) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
%5 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
%6 = arith.addi %arg7, %arg5 : index
%extracted_slice_0 = tensor.extract_slice %0[%6, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
%extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
%7 = tensor.empty() : tensor<16x32xf16>
%8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%7 : tensor<16x32xf16>) -> tensor<16x32xf16>
%extracted_slice_2 = tensor.extract_slice %1[%6, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
%extracted_slice_3 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
%9 = linalg.add ins(%8, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_3 : tensor<16x32xf16>) -> tensor<16x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %9 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
}
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
}
}
scf.yield %4 : tensor<256x256xf16>
} else {
%3 = tensor.empty() : tensor<256x256xf16>
%4 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %3) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
%5 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
%6 = arith.addi %arg7, %arg5 : index
%extracted_slice_0 = tensor.extract_slice %0[%6, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
%extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
%7 = tensor.empty() : tensor<16x32xf16>
%8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%7 : tensor<16x32xf16>) -> tensor<16x32xf16>
%extracted_slice_2 = tensor.extract_slice %1[%6, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
%extracted_slice_3 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
%9 = linalg.mul ins(%8, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_3 : tensor<16x32xf16>) -> tensor<16x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %9 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
}
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
}
}
scf.yield %4 : tensor<256x256xf16>
}
bufferization.materialize_in_destination %2 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
return
}
IterativeTilingAndFusion
func.func @test1(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
%0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
%1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
%2 = tensor.empty() : tensor<256x256xf16>
%3 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %2) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %0[%arg5, 0] [32, 256] [1, 1] : tensor<256x256xf16> to tensor<32x256xf16>
%extracted_slice_0 = tensor.extract_slice %1[0, %arg6] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
%extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%5 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x256xf16>, tensor<256x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
}
}
%4 = scf.if %arg3 -> (tensor<256x256xf16>) {
%5 = tensor.empty() : tensor<256x256xf16>
%6 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %5) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %3[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%extracted_slice_0 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%7 = linalg.add ins(%extracted_slice, %extracted_slice_0 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
}
}
scf.yield %6 : tensor<256x256xf16>
} else {
%5 = tensor.empty() : tensor<256x256xf16>
%6 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %5) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %3[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%extracted_slice_0 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%7 = linalg.mul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
}
}
scf.yield %6 : tensor<256x256xf16>
}
bufferization.materialize_in_destination %4 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
return
}
Example 2
func.func @test2(%memref0: memref<256x256xf16>, %memref1: memref<256x256xf16>, %memref2: memref<256x256xf16>, %cond: i1) {
%arg0 = bufferization.to_tensor %memref0 restrict : memref<256x256xf16>
%arg1 = bufferization.to_tensor %memref1 restrict : memref<256x256xf16>
%tmp1 = tensor.empty() : tensor<256x256xf16>
%matmul_res = linalg.matmul ins(%arg0, %arg1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp1 : tensor<256x256xf16>) -> tensor<256x256xf16>
%tmp2 = tensor.empty() : tensor<256x256xf16>
%add_res = linalg.add ins(%matmul_res, %arg1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp2 : tensor<256x256xf16>) -> tensor<256x256xf16>
%if_res = scf.if %cond -> tensor<256x256xf16> {
scf.yield %add_res : tensor<256x256xf16>
} else {
%tmp3 = tensor.empty() : tensor<256x256xf16>
%mul_res = linalg.mul ins(%add_res, %arg1 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%tmp3 : tensor<256x256xf16>) -> tensor<256x256xf16>
scf.yield %mul_res : tensor<256x256xf16>
}
bufferization.materialize_in_destination %if_res in restrict writable %memref2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
return
}
GpuTilingAndFusion
func.func @test2(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
%0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
%1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
%2 = tensor.empty() : tensor<256x256xf16>
%3 = scf.if %arg3 -> (tensor<256x256xf16>) {
%4 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %2) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
%5 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
%6 = arith.addi %arg7, %arg5 : index
%extracted_slice_0 = tensor.extract_slice %0[%6, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
%extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
%7 = tensor.empty() : tensor<16x32xf16>
%8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%7 : tensor<16x32xf16>) -> tensor<16x32xf16>
%extracted_slice_2 = tensor.extract_slice %1[%6, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
%extracted_slice_3 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
%9 = linalg.add ins(%8, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_3 : tensor<16x32xf16>) -> tensor<16x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %9 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
}
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
}
}
scf.yield %4 : tensor<256x256xf16>
} else {
%4 = tensor.empty() : tensor<256x256xf16>
%5 = scf.forall (%arg5) = (0) to (256) step (128) shared_outs(%arg6 = %4) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<256x256xf16> to tensor<128x256xf16>
%6 = scf.forall (%arg7, %arg8) = (0, 0) to (128, 256) step (16, 32) shared_outs(%arg9 = %extracted_slice) -> (tensor<128x256xf16>) {
%7 = arith.addi %arg7, %arg5 : index
%extracted_slice_0 = tensor.extract_slice %0[%7, 0] [16, 256] [1, 1] : tensor<256x256xf16> to tensor<16x256xf16>
%extracted_slice_1 = tensor.extract_slice %1[0, %arg8] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
%8 = tensor.empty() : tensor<16x32xf16>
%9 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x256xf16>, tensor<256x32xf16>) outs(%8 : tensor<16x32xf16>) -> tensor<16x32xf16>
%extracted_slice_2 = tensor.extract_slice %1[%7, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
%10 = tensor.empty() : tensor<16x32xf16>
%11 = linalg.add ins(%9, %extracted_slice_2 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%10 : tensor<16x32xf16>) -> tensor<16x32xf16>
%extracted_slice_3 = tensor.extract_slice %1[%7, %arg8] [16, 32] [1, 1] : tensor<256x256xf16> to tensor<16x32xf16>
%extracted_slice_4 = tensor.extract_slice %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<128x256xf16> to tensor<16x32xf16>
%12 = linalg.mul ins(%11, %extracted_slice_3 : tensor<16x32xf16>, tensor<16x32xf16>) outs(%extracted_slice_4 : tensor<16x32xf16>) -> tensor<16x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %12 into %arg9[%arg7, %arg8] [16, 32] [1, 1] : tensor<16x32xf16> into tensor<128x256xf16>
}
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %6 into %arg6[%arg5, 0] [128, 256] [1, 1] : tensor<128x256xf16> into tensor<256x256xf16>
}
}
scf.yield %5 : tensor<256x256xf16>
}
bufferization.materialize_in_destination %3 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
return
}
IterativeTilingAndFusion
func.func @test2(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf16>, %arg3: i1, %arg4: memref<i8>) {
%0 = bufferization.to_tensor %arg0 restrict : memref<256x256xf16>
%1 = bufferization.to_tensor %arg1 restrict : memref<256x256xf16>
%2 = tensor.empty() : tensor<256x256xf16>
%3 = tensor.empty() : tensor<256x256xf16>
%4 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %3) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %0[%arg5, 0] [32, 256] [1, 1] : tensor<256x256xf16> to tensor<32x256xf16>
%extracted_slice_0 = tensor.extract_slice %1[0, %arg6] [256, 32] [1, 1] : tensor<256x256xf16> to tensor<256x32xf16>
%extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%6 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x256xf16>, tensor<256x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
%extracted_slice_2 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%extracted_slice_3 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%7 = linalg.add ins(%6, %extracted_slice_2 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_3 : tensor<32x32xf16>) -> tensor<32x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
}
}
%5 = scf.if %arg3 -> (tensor<256x256xf16>) {
scf.yield %4 : tensor<256x256xf16>
} else {
%6 = tensor.empty() : tensor<256x256xf16>
%7 = scf.forall (%arg5, %arg6) = (0, 0) to (256, 256) step (32, 32) shared_outs(%arg7 = %6) -> (tensor<256x256xf16>) {
%extracted_slice = tensor.extract_slice %4[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%extracted_slice_0 = tensor.extract_slice %1[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%extracted_slice_1 = tensor.extract_slice %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<256x256xf16> to tensor<32x32xf16>
%8 = linalg.mul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32xf16>, tensor<32x32xf16>) outs(%extracted_slice_1 : tensor<32x32xf16>) -> tensor<32x32xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %8 into %arg7[%arg5, %arg6] [32, 32] [1, 1] : tensor<32x32xf16> into tensor<256x256xf16>
}
}
scf.yield %7 : tensor<256x256xf16>
}
bufferization.materialize_in_destination %5 in restrict writable %arg2 : (tensor<256x256xf16>, memref<256x256xf16>) -> ()
return
}
We are not talking about the same thing.
- I am asking why use tileConsumerAndFuseProducersUsingSCF only and not use tileAndFuseConsumerOfSlice, instead of IterativeTilingAndFusion (this is the upstream method proposed by the IREE guys and not the downstream one).
- The branch I mentioned is not an if-else branch but more like the residual branch in ResNet. But indeed, your implementation performs better in the if-else branch case.

      C
     / \
    D   E

C = matmul(A, B)
D = ReLU(C)
E = LayerNorm(C)
Didn't thoroughly look into the logic, but overall looks good, modulo some comments.
lib/gc/Transforms/GPU/GpuUtils.h
// Round to the largest power of 2 that is <= value.
template <typename T> static T floorPow2(T value) {
  auto v = static_cast<std::make_unsigned_t<T>>(value);
  return T(1) << (llvm::bit_width(v) - 1);
}

// Round to the smallest power of 2 that is >= value.
template <typename T> static T ceilPow2(T value) {
  auto v = static_cast<std::make_unsigned_t<T>>(value);
  return llvm::bit_ceil(v);
}
Aren't these covered by llvm's MathExtras.h?
They are, but all these functions operate on unsigned integers and we often need to cast. This is just a shorthand for cast + llvm::bit_ceil().
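For illustration only (not code from this PR), the same thing written out directly with the LLVM helpers shows the casts the shorthands hide:

#include "llvm/ADT/bit.h"
#include <cstdint>

// Equivalent of ceilPow2/floorPow2 spelled out with llvm::bit_ceil/bit_width.
static int64_t ceilPow2Verbose(int64_t value) {
  return static_cast<int64_t>(llvm::bit_ceil(static_cast<uint64_t>(value)));
}
static int64_t floorPow2Verbose(int64_t value) {
  return int64_t(1) << (llvm::bit_width(static_cast<uint64_t>(value)) - 1);
}
// e.g. ceilPow2Verbose(48) == 64, floorPow2Verbose(48) == 32.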
// Check recursively if the specified operation has an operand that
// depends on a result of a previous operation, matching the predicate.
template <unsigned MaxDepth = std::numeric_limits<unsigned>::max()>
bool isOperandDependsOnOp(bool (*predicate)(Operation *), Operation *operation,
                          unsigned depth = 0) {
  for (auto operand : operation->getOperands()) {
    if (auto op = operand.getDefiningOp();
        op && (predicate(op) ||
               (depth < MaxDepth &&
                isOperandDependsOnOp(predicate, op, depth + 1)))) {
      return true;
    }
  }
  return false;
}

// Check recursively if there are any operations, matching the predicate, that
// depend on the result of the specified operation.
template <unsigned MaxDepth = std::numeric_limits<unsigned>::max()>
bool isOpDependsOnResult(bool (*predicate)(Operation *), Operation *operation,
                         unsigned depth = 0) {
  for (auto res : operation->getResults()) {
    for (auto u : res.getUsers()) {
      if (predicate(u) ||
          (depth < MaxDepth && isOpDependsOnResult(predicate, u, depth + 1))) {
        return true;
      }
    }
  }
  return false;
}
This looks like it can get pretty expensive. Can it be replaced by some standard data flow analysis?
Yeah, it could be expensive. These functions are called from the fusion control function to prevent the matmuls from being fused. Do you have an idea of how to use a standard data flow analysis for this case?
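For context, a rough sketch (hypothetical names and op set, not the exact control function of this PR) of how these predicates can be used from the fusion control callback to keep two dependent matmuls out of a single kernel:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
// plus the GpuUtils.h helpers shown above (isOpDependsOnResult).

using namespace mlir;

// Hypothetical predicate: ops that should never be fused into one kernel.
static bool isMatmulLike(Operation *op) {
  return isa<linalg::MatmulOp, linalg::MatmulTransposeBOp>(op);
}

// Hypothetical fusion-control check: refuse to fuse a matmul-like producer
// whose result (transitively, up to depth 8) already feeds another matmul.
static bool allowFusion(Operation *producer) {
  if (isMatmulLike(producer) && isOpDependsOnResult<8>(isMatmulLike, producer))
    return false;
  return true;
}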
Good to go, once comments are resolved.
This path creates 2 nested loops for linalg operations, which are later converted to gpu.launch.
The outer loop is mapped to the grid sizes and the inner loop is mapped to the block sizes.
The tile calculation is based on the device information retrieved either from the module DLTI attributes or from the path options.
Depends on #406