1
- // ===-- FineGrainedFusion .cpp - Fine-Grained Fusion ----------- --*- C++ -*-===//
1
+ // ===-- IterativeTilingAndFusion .cpp - Iterative Tiling+ Fusion --*- C++ -*-===//
2
2
//
3
3
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
33
33
34
34
namespace mlir {
35
35
namespace gc {
36
- #define GEN_PASS_DEF_FINEGRAINEDFUSION
36
+ #define GEN_PASS_DEF_ITERATIVETILINGANDFUSION
37
37
#include " gc/Transforms/Passes.h.inc"
38
38
39
39
static FailureOr<tensor::ExtractSliceOp>
@@ -215,6 +215,15 @@ alreadyTiledOpFilter(RewriterBase &rewriter,
215
215
return failure (defOrUse.ownerOp ->use_empty ());
216
216
}
217
217
218
+ static LogicalResult
219
+ NonContractionOpFilter (RewriterBase &rewriter,
220
+ OffsetSizeAndStrideOpInterface candidate,
221
+ CandidateDefOrUse defOrUse) {
222
+ // Currently this pass focuses on fine-grained fusion, which does not expect
223
+ // two consecutive contraction ops.
224
+ return failure (isa<mlir::linalg::ContractionOpInterface>(defOrUse.ownerOp ));
225
+ }
226
+
218
227
static LogicalResult
219
228
SingleCandidateInBlockFilter (RewriterBase &rewriter,
220
229
OffsetSizeAndStrideOpInterface candidate,
@@ -289,7 +298,7 @@ struct CandidateSliceFilterPipeLine
289
298
290
299
SmallVector<CandidateSliceFilter> getDefaultPipeLine () {
291
300
return SmallVector<CandidateSliceFilter>{
292
- alreadyTiledOpFilter, noTilingOnReductionFilter,
301
+ alreadyTiledOpFilter, NonContractionOpFilter, noTilingOnReductionFilter,
293
302
exactTilingOnPackUnPackFilter, SingleCandidateInBlockFilter};
294
303
}
295
304
@@ -499,9 +508,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
499
508
// candidate slice in avoid of conflict with subsequent
500
509
// `tileAndFuseConsumerOfSlice` get nest loops between next candidate
501
510
// sliceOp and tiled producer.
502
- auto region = outerLoops.front ()->getParentRegion ();
503
- (void )mlir::eraseUnreachableBlocks (rewriter, {*region});
504
- (void )mlir::runRegionDCE (rewriter, {*region});
511
+ (void )mlir::simplifyRegions (rewriter,
512
+ {*outerLoops.front ()->getParentRegion ()});
505
513
}
506
514
}
507
515
if (fusedResultList.empty ()) {
@@ -575,7 +583,7 @@ static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
575
583
// 2. check single one tiling interface in loop body
576
584
auto walkResult = forOp->walk ([&targetOp](TilingInterface op) {
577
585
// some special op maybe already deal with in template
578
- if (isa<linalg::FillOp>(op))
586
+ if (isa<linalg::FillOp, linalg::CopyOp >(op))
579
587
return WalkResult::skip ();
580
588
return op != targetOp ? WalkResult::interrupt () : WalkResult::advance ();
581
589
});
@@ -632,10 +640,6 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op) {
632
640
}
633
641
}
634
642
635
- struct IterativeFusionOptions {
636
- bool useCostModel = false ;
637
- };
638
-
639
643
struct SystemDesc {
640
644
// get runtime OMP_NUM_THREADS
641
645
uint32_t getNumThreads () {
@@ -692,8 +696,9 @@ struct SystemDesc {
692
696
MLIRContext *ctx;
693
697
};
694
698
695
- void iterativeTilingAndFusion (RewriterBase &rewriter, func::FuncOp &f,
696
- const IterativeFusionOptions &fuseOptions) {
699
+ void iterativeTilingAndFusionUntilExhaustion (
700
+ RewriterBase &rewriter, func::FuncOp &f,
701
+ const CandidateSliceOptions &sliceOptions) {
697
702
// Collect untiled and tiled ops respectively
698
703
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
699
704
@@ -731,29 +736,6 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
731
736
return !singleTiledOpInLoop.empty ();
732
737
};
733
738
734
- SystemDesc sysDesc (f->getParentOfType <ModuleOp>());
735
- // Flexible options to control which candidate slice would be selected from
736
- // the view of both validity and performance.
737
- CandidateSliceOptions sliceOptions;
738
- // Since most filters regarding to validity have already been built-in
739
- // enabled. Users could focus on performance related filters, a.k.a. cost
740
- // model.
741
- if (fuseOptions.useCostModel ) {
742
- // Customized filter by cost model.
743
- CandidateSliceFilter costModelFilter =
744
- [&sysDesc](RewriterBase &rewriter,
745
- OffsetSizeAndStrideOpInterface candidate,
746
- CandidateDefOrUse defOrUse) -> LogicalResult {
747
- // Get cache size
748
- size_t l2CacheSize = sysDesc.getCacheSize (2 );
749
- FailureOr<int64_t > tileSizeProduct =
750
- computeTileSizeProductOfCandidate (candidate);
751
- return success (succeeded (tileSizeProduct) &&
752
- (*tileSizeProduct <= (int64_t )l2CacheSize));
753
- };
754
- sliceOptions.addFilter (costModelFilter);
755
- }
756
-
757
739
// Iterative tiling and fusion until exhaustion.
758
740
while (collectUnTiledOps ()) {
759
741
// If existing tiled op before tiling.
@@ -768,12 +750,8 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
768
750
changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
769
751
rewriter, tiledOp, sliceOptions));
770
752
});
771
- if (!changed) {
772
- // If no new fusion happens, terminate iteration.
773
- break ;
774
- } else {
775
- (void )mlir::eraseUnreachableBlocks (rewriter, {f.getRegion ()});
776
- (void )mlir::runRegionDCE (rewriter, {f.getRegion ()});
753
+ if (changed) {
754
+ (void )mlir::simplifyRegions (rewriter, {f.getRegion ()});
777
755
}
778
756
} else {
779
757
// Auto tiling with default tile size if no tiled op found. Follow tiling
@@ -798,29 +776,42 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
798
776
}
799
777
}
800
778
801
- struct FineGrainedFusion
802
- : public impl::FineGrainedFusionBase<FineGrainedFusion > {
803
- using FineGrainedFusionBase::FineGrainedFusionBase ;
779
+ struct IterativeTilingAndFusion
780
+ : public impl::IterativeTilingAndFusionBase<IterativeTilingAndFusion > {
781
+ using IterativeTilingAndFusionBase::IterativeTilingAndFusionBase ;
804
782
805
783
public:
806
784
void runOnOperation () final {
807
785
auto &ctx = getContext ();
808
- {
809
- // Get funcOp
810
- func::FuncOp func = getOperation ();
811
- // Get rewriter
812
- IRRewriter rewriter (&ctx);
813
- // Run iterative fusion
814
- iterativeTilingAndFusion (rewriter, func,
815
- IterativeFusionOptions{useCostModel});
816
- }
817
-
818
- {
819
- RewritePatternSet patternSet (&ctx);
820
- if (failed (applyPatternsAndFoldGreedily (getOperation (),
821
- std::move (patternSet))))
822
- signalPassFailure ();
786
+ // Get funcOp
787
+ func::FuncOp func = getOperation ();
788
+ // Get system descriptor
789
+ SystemDesc sysDesc (func->getParentOfType <ModuleOp>());
790
+ // Flexible options to control which candidate slice would be selected from
791
+ // the view of both validity and performance.
792
+ CandidateSliceOptions sliceOptions;
793
+ // Since most filters regarding to validity have already been built-in
794
+ // enabled. Users could focus on performance related filters, a.k.a. cost
795
+ // model. E.g.
796
+ if (useCostModel) {
797
+ // Customized filter by cost model.
798
+ CandidateSliceFilter costModelFilter =
799
+ [&sysDesc](RewriterBase &rewriter,
800
+ OffsetSizeAndStrideOpInterface candidate,
801
+ CandidateDefOrUse defOrUse) -> LogicalResult {
802
+ // Get cache size
803
+ size_t l2CacheSize = sysDesc.getCacheSize (2 );
804
+ FailureOr<int64_t > tileSizeProduct =
805
+ computeTileSizeProductOfCandidate (candidate);
806
+ return success (succeeded (tileSizeProduct) &&
807
+ (*tileSizeProduct <= (int64_t )l2CacheSize));
808
+ };
809
+ sliceOptions.addFilter (costModelFilter);
823
810
}
811
+ // Get rewriter
812
+ IRRewriter rewriter (&ctx);
813
+ // Run iterative fusion
814
+ iterativeTilingAndFusionUntilExhaustion (rewriter, func, sliceOptions);
824
815
}
825
816
};
826
817
0 commit comments