Skip to content

Commit 111d276

Browse files
committed
fix comment and rename pass name
1 parent b548633 commit 111d276

File tree

5 files changed

+57
-66
lines changed

5 files changed

+57
-66
lines changed

include/gc/Transforms/Passes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
5858
];
5959
}
6060

61-
def FineGrainedFusion : Pass<"fine-grained-fusion",
61+
def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
6262
"func::FuncOp"> {
63-
let summary = "Fine Grained Fusion for any tilable operation";
63+
let summary = "Iterative tiling and fusion for any tilable operation";
6464
let description = [{
6565
The pass tries to fuse any MLIR operation which can be tiled. Moreover, this pass aims to support:
6666
1. Matmul fusion with element-wise/reduce/broadcast ops.
@@ -72,7 +72,7 @@ def FineGrainedFusion : Pass<"fine-grained-fusion",
7272

7373
The granularity of fusion is controlled by `fusion-level`, e.g.:
7474
* `0`: disable any fusion.
75-
* `1`:[Default] enable pre-op fusion + post-op fusion covering any tilable operation including tensor.pack/tensor.fill/linalg.reduce etc but excluding branches forked by multiple uses.
75+
* `1`: [Default] enable both producer and consumer fusion, covering any tilable operation including tensor.pack/tensor.fill/linalg.reduce etc., but excluding branches forked by multiple uses.
7676
* `2`: `LEVEL 1` + extend to any topology including branches.
7777
}];
7878
let dependentDialects = ["func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ add_mlir_library(GCPasses
1313
OneDNNGraphToLinalg.cpp
1414
Pipeline.cpp
1515
TileNamed.cpp
16-
FineGrainedFusion.cpp
16+
IterativeTilingAndFusion.cpp
1717
TilingUsingInterfaceX.cpp
1818

1919
ADDITIONAL_HEADER_DIRS

lib/gc/Transforms/FineGrainedFusion.cpp renamed to lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 51 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- FineGrainedFusion.cpp - Fine-Grained Fusion -------------*- C++ -*-===//
1+
//===-- IterativeTilingAndFusion.cpp - Iterative Tiling+Fusion --*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -33,7 +33,7 @@
3333

3434
namespace mlir {
3535
namespace gc {
36-
#define GEN_PASS_DEF_FINEGRAINEDFUSION
36+
#define GEN_PASS_DEF_ITERATIVETILINGANDFUSION
3737
#include "gc/Transforms/Passes.h.inc"
3838

3939
static FailureOr<tensor::ExtractSliceOp>
@@ -215,6 +215,15 @@ alreadyTiledOpFilter(RewriterBase &rewriter,
215215
return failure(defOrUse.ownerOp->use_empty());
216216
}
217217

218+
static LogicalResult
219+
NonContractionOpFilter(RewriterBase &rewriter,
220+
OffsetSizeAndStrideOpInterface candidate,
221+
CandidateDefOrUse defOrUse) {
222+
// Currently this pass focuses on fine-grained fusion, which does not expect
223+
// two consecutive contraction ops.
224+
return failure(isa<mlir::linalg::ContractionOpInterface>(defOrUse.ownerOp));
225+
}
226+
218227
static LogicalResult
219228
SingleCandidateInBlockFilter(RewriterBase &rewriter,
220229
OffsetSizeAndStrideOpInterface candidate,
@@ -289,7 +298,7 @@ struct CandidateSliceFilterPipeLine
289298

290299
SmallVector<CandidateSliceFilter> getDefaultPipeLine() {
291300
return SmallVector<CandidateSliceFilter>{
292-
alreadyTiledOpFilter, noTilingOnReductionFilter,
301+
alreadyTiledOpFilter, NonContractionOpFilter, noTilingOnReductionFilter,
293302
exactTilingOnPackUnPackFilter, SingleCandidateInBlockFilter};
294303
}
295304

@@ -499,9 +508,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
499508
// candidate slice, to avoid conflicts when the subsequent
500509
// `tileAndFuseConsumerOfSlice` get nest loops between next candidate
501510
// sliceOp and tiled producer.
502-
auto region = outerLoops.front()->getParentRegion();
503-
(void)mlir::eraseUnreachableBlocks(rewriter, {*region});
504-
(void)mlir::runRegionDCE(rewriter, {*region});
511+
(void)mlir::simplifyRegions(rewriter,
512+
{*outerLoops.front()->getParentRegion()});
505513
}
506514
}
507515
if (fusedResultList.empty()) {
@@ -575,7 +583,7 @@ static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
575583
// 2. check single one tiling interface in loop body
576584
auto walkResult = forOp->walk([&targetOp](TilingInterface op) {
577585
// some special ops may already be dealt with in the template
578-
if (isa<linalg::FillOp>(op))
586+
if (isa<linalg::FillOp, linalg::CopyOp>(op))
579587
return WalkResult::skip();
580588
return op != targetOp ? WalkResult::interrupt() : WalkResult::advance();
581589
});
@@ -632,10 +640,6 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op) {
632640
}
633641
}
634642

635-
struct IterativeFusionOptions {
636-
bool useCostModel = false;
637-
};
638-
639643
struct SystemDesc {
640644
// get runtime OMP_NUM_THREADS
641645
uint32_t getNumThreads() {
@@ -692,8 +696,9 @@ struct SystemDesc {
692696
MLIRContext *ctx;
693697
};
694698

695-
void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
696-
const IterativeFusionOptions &fuseOptions) {
699+
void iterativeTilingAndFusionUntilExhaustion(
700+
RewriterBase &rewriter, func::FuncOp &f,
701+
const CandidateSliceOptions &sliceOptions) {
697702
// Collect untiled and tiled ops respectively
698703
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
699704

@@ -731,29 +736,6 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
731736
return !singleTiledOpInLoop.empty();
732737
};
733738

734-
SystemDesc sysDesc(f->getParentOfType<ModuleOp>());
735-
// Flexible options to control which candidate slice would be selected from
736-
// the view of both validity and performance.
737-
CandidateSliceOptions sliceOptions;
738-
// Since most filters regarding to validity have already been built-in
739-
// enabled. Users could focus on performance related filters, a.k.a. cost
740-
// model.
741-
if (fuseOptions.useCostModel) {
742-
// Customized filter by cost model.
743-
CandidateSliceFilter costModelFilter =
744-
[&sysDesc](RewriterBase &rewriter,
745-
OffsetSizeAndStrideOpInterface candidate,
746-
CandidateDefOrUse defOrUse) -> LogicalResult {
747-
// Get cache size
748-
size_t l2CacheSize = sysDesc.getCacheSize(2);
749-
FailureOr<int64_t> tileSizeProduct =
750-
computeTileSizeProductOfCandidate(candidate);
751-
return success(succeeded(tileSizeProduct) &&
752-
(*tileSizeProduct <= (int64_t)l2CacheSize));
753-
};
754-
sliceOptions.addFilter(costModelFilter);
755-
}
756-
757739
// Iterative tiling and fusion until exhaustion.
758740
while (collectUnTiledOps()) {
759741
// If tiled ops already exist before tiling.
@@ -768,12 +750,8 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
768750
changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp(
769751
rewriter, tiledOp, sliceOptions));
770752
});
771-
if (!changed) {
772-
// If no new fusion happens, terminate iteration.
773-
break;
774-
} else {
775-
(void)mlir::eraseUnreachableBlocks(rewriter, {f.getRegion()});
776-
(void)mlir::runRegionDCE(rewriter, {f.getRegion()});
753+
if (changed) {
754+
(void)mlir::simplifyRegions(rewriter, {f.getRegion()});
777755
}
778756
} else {
779757
// Auto tiling with default tile size if no tiled op found. Follow tiling
@@ -798,29 +776,42 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
798776
}
799777
}
800778

801-
struct FineGrainedFusion
802-
: public impl::FineGrainedFusionBase<FineGrainedFusion> {
803-
using FineGrainedFusionBase::FineGrainedFusionBase;
779+
struct IterativeTilingAndFusion
780+
: public impl::IterativeTilingAndFusionBase<IterativeTilingAndFusion> {
781+
using IterativeTilingAndFusionBase::IterativeTilingAndFusionBase;
804782

805783
public:
806784
void runOnOperation() final {
807785
auto &ctx = getContext();
808-
{
809-
// Get funcOp
810-
func::FuncOp func = getOperation();
811-
// Get rewriter
812-
IRRewriter rewriter(&ctx);
813-
// Run iterative fusion
814-
iterativeTilingAndFusion(rewriter, func,
815-
IterativeFusionOptions{useCostModel});
816-
}
817-
818-
{
819-
RewritePatternSet patternSet(&ctx);
820-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
821-
std::move(patternSet))))
822-
signalPassFailure();
786+
// Get funcOp
787+
func::FuncOp func = getOperation();
788+
// Get system descriptor
789+
SystemDesc sysDesc(func->getParentOfType<ModuleOp>());
790+
// Flexible options to control which candidate slice would be selected from
791+
// the view of both validity and performance.
792+
CandidateSliceOptions sliceOptions;
793+
// Since most validity-related filters are already built in and enabled,
794+
// users can focus on performance-related filters, a.k.a. the cost
795+
// model. E.g.
796+
if (useCostModel) {
797+
// Customized filter by cost model.
798+
CandidateSliceFilter costModelFilter =
799+
[&sysDesc](RewriterBase &rewriter,
800+
OffsetSizeAndStrideOpInterface candidate,
801+
CandidateDefOrUse defOrUse) -> LogicalResult {
802+
// Get cache size
803+
size_t l2CacheSize = sysDesc.getCacheSize(2);
804+
FailureOr<int64_t> tileSizeProduct =
805+
computeTileSizeProductOfCandidate(candidate);
806+
return success(succeeded(tileSizeProduct) &&
807+
(*tileSizeProduct <= (int64_t)l2CacheSize));
808+
};
809+
sliceOptions.addFilter(costModelFilter);
823810
}
811+
// Get rewriter
812+
IRRewriter rewriter(&ctx);
813+
// Run iterative fusion
814+
iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions);
824815
}
825816
};
826817

lib/gc/Transforms/Pipeline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
4444
// todo: tensor constant propagation pass
4545
// todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
4646
// Iterative tiling and fusion pass (fine-grained fusion)
47-
pm.addNestedPass<func::FuncOp>(createFineGrainedFusion());
47+
pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
4848
// todo: lower linalg to arith/math on virtual vector pass
4949

5050
// REMOVE this pass after the above passes are added. Currently we add this

test/gc/Transform/fine-grained-fusion.mlir renamed to test/gc/Transform/iterative-tiling-and-fusion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: gc-opt --split-input-file -fine-grained-fusion %s --cse
1+
// RUN: gc-opt --split-input-file -iterative-tiling-and-fusion %s --cse
22

33
module attributes {
44
dlti.target_system_spec = #dlti.target_system_spec<

0 commit comments

Comments
 (0)