Skip to content

Commit 0478401

Browse files
author
Mahesh Ravishankar
committed
[mlir][TilingInterface] Add test for tile + fuse of sequence of reductions.
This just adds a test. With CSE of single block ops, and other previously landed changes, this works at HEAD. Just adding a test that triggered this line of work that I missed adding. Differential Revision: https://reviews.llvm.org/D139385
1 parent cfd7318 commit 0478401

File tree

2 files changed

+72
-9
lines changed

2 files changed

+72
-9
lines changed

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -cse -split-input-file %s | FileCheck %s
22

33
func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
44
%c0 = arith.constant 0 : index
@@ -271,18 +271,12 @@ func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
271271
// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
272272
// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
273273
// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
274-
// CHECK: %[[LHS:.+]] = linalg.matmul
274+
// CHECK: %[[MATMUL:.+]] = linalg.matmul
275275
// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] :
276276
// CHECK-SAME: outs(%[[ST_ARG2]] :
277-
// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
278-
// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
279-
// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
280-
// CHECK: %[[RHS:.+]] = linalg.matmul
281-
// CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] :
282-
// CHECK-SAME: outs(%[[ST_ARG2_1]] :
283277
// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
284278
// CHECK: %[[ST_RESULT:.+]] = linalg.generic
285-
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
279+
// CHECK-SAME: ins(%[[MATMUL]], %[[MATMUL]] :
286280
// CHECK-SAME: outs(%[[ST_ARG6]] :
287281
// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
288282
// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
@@ -401,3 +395,69 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
401395
// CHECK-SAME: outs(%[[SLICE_ARG6]] :
402396
// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
403397
// CHECK: scf.yield %[[UPDATE]]
398+
399+
// -----
400+
401+
func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
402+
%cst = arith.constant 0.000000e+00 : f32
403+
%cst_0 = arith.constant 0xFF800000 : f32
404+
%0 = tensor.empty() : tensor<30xf32>
405+
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
406+
%2 = linalg.generic {
407+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
408+
iterator_types = ["parallel", "reduction"]}
409+
ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) {
410+
^bb0(%arg1: f32, %arg2: f32):
411+
%8 = arith.maxf %arg2, %arg1 : f32
412+
linalg.yield %8 : f32
413+
} -> tensor<30xf32>
414+
%3 = tensor.empty() : tensor<30x3xf32>
415+
%4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
416+
%5:2 = linalg.generic {
417+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
418+
affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
419+
iterator_types = ["parallel", "reduction"]}
420+
ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) {
421+
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
422+
%8 = arith.subf %arg1, %arg2 : f32
423+
%9 = math.exp %8 : f32
424+
%10 = arith.addf %arg3, %9 : f32
425+
linalg.yield %10, %9 : f32, f32
426+
} -> (tensor<30xf32>, tensor<30x3xf32>)
427+
%6 = linalg.generic {
428+
__internal_linalg_transform__ = "reduction_sequence_fusion",
429+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
430+
affine_map<(d0, d1) -> (d0, d1)>],
431+
iterator_types = ["parallel", "parallel"]}
432+
ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) {
433+
^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
434+
%8 = arith.divf %arg1, %arg2 : f32
435+
linalg.yield %8 : f32
436+
} -> tensor<30x3xf32>
437+
return %6 : tensor<30x3xf32>
438+
}
439+
// CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>)
440+
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<30xf32>
441+
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32>
442+
// CHECK: %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV:[a-zA-Z0-9]+]]
443+
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]])
444+
// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0]
445+
// CHECK-DAG: %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]]
446+
// CHECK: %[[FILL0:.+]] = linalg.fill
447+
// CHECK-SAME: outs(%[[INIT0_SLICE]] :
448+
// CHECK: %[[GENERIC0:.+]] = linalg.generic
449+
// CHECK-SAME: ins(%[[ARG0_SLICE]] :
450+
// CHECK-SAME: outs(%[[FILL0]] :
451+
// CHECK: %[[FILL1:.+]] = linalg.fill
452+
// CHECK-SAME: outs(%[[INIT0_SLICE]] :
453+
// CHECK: %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
454+
// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic
455+
// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[GENERIC0]] :
456+
// CHECK-SAME: outs(%[[FILL1]], %[[INIT1_SLICE]] :
457+
// CHECK: %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
458+
// CHECK: %[[GENERIC2:.+]] = linalg.generic
459+
// CHECK-SAME: ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 :
460+
// CHECK-SAME: outs(%[[ITERARG0_SLICE]] :
461+
// CHECK-DAG: %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
462+
// CHECK: scf.yield %[[INSERTSLICE]]
463+
// CHECK: return %[[RESULT]]

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,9 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
368368
// 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M
369369
// dimension.
370370
addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10});
371+
// 6. Fusion of back-to-back-reduction ops
372+
addPatternForTileAndFuse(context, patterns, "reduction_sequence_fusion",
373+
{10});
371374
return;
372375
}
373376
if (testLoweringToScalar) {

0 commit comments

Comments
 (0)