//===-- MatmulConfigAnalysis.h - Matmul config analysis ---------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H

#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include <cstdlib>
#include <llvm/Support/Debug.h>
#include <memory>
#include <numeric>
#include <string>

namespace mlir {
namespace gc {

using namespace mlir;

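// Description of the target system used by the analysis: the thread count and
// per-level cache sizes (read from environment variables) and the maximum
// vector length for contraction operations.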
struct SystemDesc {
  // Get the runtime thread count from OMP_NUM_THREADS; default to 1 if unset.
  uint32_t getNumThreads() {
    char *numThreads = getenv("OMP_NUM_THREADS");
    if (numThreads) {
      return std::stoi(numThreads);
    }
    return 1;
  }
  // Get the cache size for the given cache level from the L1/L2/L3_CACHE_SIZE
  // environment variables; return 0 if unknown.
  size_t getCacheSize(uint8_t cacheLevel) {
    if (cacheLevel == 1) {
      char *cacheSize = getenv("L1_CACHE_SIZE");
      if (cacheSize) {
        return std::stoi(cacheSize);
      }
    } else if (cacheLevel == 2) {
      char *cacheSize = getenv("L2_CACHE_SIZE");
      if (cacheSize) {
        return std::stoi(cacheSize);
      }
    } else if (cacheLevel == 3) {
      char *cacheSize = getenv("L3_CACHE_SIZE");
      if (cacheSize) {
        return std::stoi(cacheSize);
      }
    }
    return 0;
  }

  // Maximum vector length for contraction operations.
  SmallVector<size_t> getContractionOperationMaxVectorLength() {
    return {512UL, 512UL};
  }
};

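// Tiling configuration for a matmul: block sizes, thread counts and innermost
// block sizes along the M, N and K dimensions.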
struct MatmulConfig {
  uint32_t MBlock, NBlock, KBlock;
  uint32_t MThreads, NThreads, KThreads;
  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
                                       const MatmulConfig &config);
};

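// Role of an iteration dimension in a matmul-like contraction: batch, M, N or
// the K (reduction) dimension.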
enum DimType { Batch, M, N, K };

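// Return the indices of all entries in `tyList` whose dim type equals `ty`.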
[[maybe_unused]] static SmallVector<unsigned>
extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
  SmallVector<unsigned> idxList;
  for (auto [idx, type] : llvm::enumerate(tyList)) {
    if (type == ty) {
      idxList.push_back(idx);
    }
  }
  return idxList;
}

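// Return, for each operand of a supported matmul-like op (inputs first, then
// the output), the DimType of every dimension, including the packed VNNI
// layouts; fail for unsupported ops.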
static FailureOr<SmallVector<SmallVector<DimType>>>
getOprandDimType(linalg::LinalgOp &linalgOp) {
  if (isa<linalg::MatmulOp>(linalgOp)) {
    return SmallVector<SmallVector<DimType>>{
        SmallVector<DimType>{DimType::M, DimType::K},
        SmallVector<DimType>{DimType::K, DimType::N},
        SmallVector<DimType>{DimType::M, DimType::N}};
  } else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
    return SmallVector<SmallVector<DimType>>{
        SmallVector<DimType>{DimType::M, DimType::K},
        SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
                             DimType::K},
        SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
  } else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
    return SmallVector<SmallVector<DimType>>{
        SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
        SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
                             DimType::K},
        SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
  } else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
    return SmallVector<SmallVector<DimType>>{
        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
        SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
  }
  return failure();
}

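// Analysis that derives a MatmulConfig for a matmul-like operation. Minimal
// usage sketch (assuming `op` is an Operation* of a supported matmul form):
//   MatmulConfig cfg = MatmulConfigAnalysis(op).getConfig();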
struct MatmulConfigAnalysis {
public:
  explicit MatmulConfigAnalysis(Operation *root);
  MatmulConfig getConfig() { return config; }

private:
  MatmulConfig config;
};

} // namespace gc
} // namespace mlir

#endif // MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H