Skip to content

Commit d0c456f

Browse files
committed
add default tileSize option to pass
1 parent 111d276 commit d0c456f

File tree

2 files changed

+102
-50
lines changed

2 files changed

+102
-50
lines changed

include/gc/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
8585
Option<"useCostModel", "use-cost-model", "bool",
8686
/*default=*/"false",
8787
"Decide if enable cost model to control iterative fusion.">,
88+
ListOption<"defaultTileSize", "default-tile-size", "std::string",
89+
"Set default TileSize for the certain type of op, saying matmul:{32,32}">,
8890
];
8991
}
9092

lib/gc/Transforms/IterativeTilingAndFusion.cpp

Lines changed: 100 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Transforms/RegionUtils.h"
2929
#include <llvm/Support/Debug.h>
3030
#include <memory>
31+
#include <unordered_map>
3132

3233
#include "TilingUsingInterfaceX.h"
3334

@@ -601,45 +602,6 @@ static LogicalResult isSingleTiledOpInLoop(Operation *targetOp) {
601602
return success(walkResult.wasInterrupted());
602603
}
603604

604-
template <typename OpTy>
605-
static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op) {
606-
// a. Check <OpTy>
607-
if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
608-
return false;
609-
auto tilingInterfaceOp = cast<TilingInterface>(op);
610-
611-
scf::SCFTilingOptions options;
612-
// b. Get default tiling size
613-
SmallVector<utils::IteratorType> iteratorTypes =
614-
tilingInterfaceOp.getLoopIteratorTypes();
615-
616-
SmallVector<OpFoldResult> defaultTileSize(iteratorTypes.size(),
617-
rewriter.getIndexAttr(0));
618-
619-
for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) {
620-
// All outer non reduction loop should contribute parallelism. In another
621-
// word, all reduction dimensions should not be tiled.
622-
if (iterType == utils::IteratorType::parallel &&
623-
(en != iteratorTypes.size() - 1 ||
624-
llvm::count(iteratorTypes, utils::IteratorType::reduction))) {
625-
defaultTileSize[en] = rewriter.getIndexAttr(1);
626-
}
627-
}
628-
629-
options.setTileSizes(defaultTileSize);
630-
// c. Set loop type
631-
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
632-
// d. Use builtin tiling interface
633-
FailureOr<scf::SCFTilingResult> tilingResult =
634-
scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
635-
if (succeeded(tilingResult)) {
636-
rewriter.replaceOp(op, tilingResult->replacements);
637-
return true;
638-
} else {
639-
return false;
640-
}
641-
}
642-
643605
struct SystemDesc {
644606
// get runtime OMP_NUM_THREADS
645607
uint32_t getNumThreads() {
@@ -696,9 +658,64 @@ struct SystemDesc {
696658
MLIRContext *ctx;
697659
};
698660

661+
using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;
662+
663+
template <typename OpTy>
664+
static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op,
665+
const OpTileSizeMap &tsMap) {
666+
// a. Check <OpTy>
667+
if (!isa<TilingInterface>(op) || !isa<OpTy>(op))
668+
return false;
669+
auto tilingInterfaceOp = cast<TilingInterface>(op);
670+
671+
scf::SCFTilingOptions options;
672+
// b. Get default tiling size
673+
SmallVector<utils::IteratorType> iteratorTypes =
674+
tilingInterfaceOp.getLoopIteratorTypes();
675+
676+
SmallVector<OpFoldResult> defaultTileSize;
677+
678+
std::string opName = op->getName().getStringRef().str();
679+
// Erase dialect name, such as Linalg or Tensor.
680+
opName.erase(0, opName.find(".") + 1);
681+
682+
if (tsMap.count(opName)) {
683+
SmallVector<int64_t> userDefaultTileSize = tsMap.find(opName)->second;
684+
defaultTileSize =
685+
getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize));
686+
} else {
687+
defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0));
688+
for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) {
689+
// All outer non reduction loop should contribute parallelism. In another
690+
// word, all reduction dimensions should not be tiled.
691+
if (iterType == utils::IteratorType::parallel &&
692+
(en != iteratorTypes.size() - 1 ||
693+
llvm::count(iteratorTypes, utils::IteratorType::reduction))) {
694+
defaultTileSize[en] = rewriter.getIndexAttr(1);
695+
}
696+
}
697+
}
698+
// If the tile sizes are all zero, no tiling would happen.
699+
if (llvm::all_of(defaultTileSize, isZeroIndex))
700+
return false;
701+
702+
options.setTileSizes(defaultTileSize);
703+
// c. Set loop type
704+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
705+
// d. Use builtin tiling interface
706+
FailureOr<scf::SCFTilingResult> tilingResult =
707+
scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
708+
if (succeeded(tilingResult)) {
709+
rewriter.replaceOp(op, tilingResult->replacements);
710+
return true;
711+
} else {
712+
return false;
713+
}
714+
}
715+
699716
void iterativeTilingAndFusionUntilExhaustion(
700717
RewriterBase &rewriter, func::FuncOp &f,
701-
const CandidateSliceOptions &sliceOptions) {
718+
const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) {
702719
// Collect untiled and tiled ops respectively
703720
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
704721

@@ -756,26 +773,57 @@ void iterativeTilingAndFusionUntilExhaustion(
756773
} else {
757774
// Auto tiling with default tile size if no tiled op found. Follow tiling
758775
// priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
759-
SmallVector<std::function<bool(RewriterBase &, Operation *)>>
776+
SmallVector<std::function<bool(RewriterBase &, Operation *,
777+
const OpTileSizeMap &)>>
760778
priorityTilingPipeLine = {
761779
defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
762780
defaultTilingOfType<mlir::linalg::ReduceOp>,
763781
defaultTilingOfType<mlir::linalg::LinalgOp>};
764-
if (llvm::all_of(
765-
priorityTilingPipeLine,
766-
[&rewriter, &unTiledOps](
767-
function_ref<bool(RewriterBase &, Operation *)> tilingFn) {
768-
return !llvm::any_of(unTiledOps,
769-
std::bind(tilingFn, std::ref(rewriter),
770-
std::placeholders::_1));
771-
})) {
782+
if (llvm::all_of(priorityTilingPipeLine,
783+
[&rewriter, &tsMap, &unTiledOps](
784+
function_ref<bool(RewriterBase &, Operation *,
785+
const OpTileSizeMap &)>
786+
tilingFn) {
787+
return !llvm::any_of(
788+
unTiledOps, std::bind(tilingFn, std::ref(rewriter),
789+
std::placeholders::_1,
790+
std::cref(tsMap)));
791+
})) {
772792
// If no op can be tiled
773793
break;
774794
}
775795
}
776796
}
777797
}
778798

799+
/// Parse the `default-tile-size` pass option into an `OpTileSizeMap`.
///
/// Each entry of `strArgs` must have the form `opType:{ts1,ts2,...}`, for
/// example `matmul:{32,32}`. Whitespace anywhere in an entry is ignored. If an
/// op type appears more than once, the last entry wins.
///
/// A malformed entry is a user error, so it is reported through
/// `llvm::report_fatal_error`, which terminates in every build mode.
/// `llvm_unreachable` is not suitable here because it may turn into an
/// optimizer hint (undefined behaviour) in optimized builds. Integers are
/// parsed with `StringRef::getAsInteger`, which does not throw and rejects
/// empty tokens and trailing junk such as `32abc`.
///
/// \param strArgs Raw option values, one per op type.
/// \returns Map from op type to its list of tile sizes.
static OpTileSizeMap defaultTileSizeParser(ArrayRef<std::string> strArgs) {
  static constexpr char formatHint[] =
      "Please follow correct argument format: opType:{ts1,ts2,...}";

  OpTileSizeMap tsMap;
  // Take each argument by value: whitespace is stripped in place below.
  for (std::string arg : strArgs) {
    arg.erase(llvm::remove_if(arg, llvm::isSpace), arg.end());

    // `spec` and everything sliced from it are views into `arg`.
    llvm::StringRef spec(arg);
    const size_t colonPos = spec.find(':');
    if (colonPos == llvm::StringRef::npos)
      llvm::report_fatal_error(formatHint);

    llvm::StringRef opType = spec.take_front(colonPos);
    llvm::StringRef sizeList = spec.drop_front(colonPos + 1);
    // Require a non-empty, brace-enclosed list: at least `{x}`.
    if (sizeList.size() <= 2 || sizeList.front() != '{' ||
        sizeList.back() != '}')
      llvm::report_fatal_error(formatHint);
    sizeList = sizeList.drop_front().drop_back();

    // Empty tokens are kept so that `{32,}` or `{32,,32}` is rejected below.
    SmallVector<llvm::StringRef> tokens;
    sizeList.split(tokens, ',');

    SmallVector<int64_t> tileSizes;
    tileSizes.reserve(tokens.size());
    for (llvm::StringRef token : tokens) {
      int64_t tileSize = 0;
      // `getAsInteger` returns true on failure.
      if (token.getAsInteger(/*Radix=*/10, tileSize))
        llvm::report_fatal_error(formatHint);
      tileSizes.push_back(tileSize);
    }
    tsMap[opType.str()] = std::move(tileSizes);
  }
  return tsMap;
}
826+
779827
struct IterativeTilingAndFusion
780828
: public impl::IterativeTilingAndFusionBase<IterativeTilingAndFusion> {
781829
using IterativeTilingAndFusionBase::IterativeTilingAndFusionBase;
@@ -808,10 +856,12 @@ struct IterativeTilingAndFusion
808856
};
809857
sliceOptions.addFilter(costModelFilter);
810858
}
859+
OpTileSizeMap tsMap = defaultTileSizeParser(defaultTileSize);
811860
// Get rewriter
812861
IRRewriter rewriter(&ctx);
813862
// Run iterative fusion
814-
iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions);
863+
iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions,
864+
tsMap);
815865
}
816866
};
817867

0 commit comments

Comments
 (0)