77// ===----------------------------------------------------------------------===//
88
99#include "gc/Analysis/MatmulConfigAnalysis.h"
10+ #include "gc/Analysis/TargetDescriptionAnalysis.h"
1011#include <limits>
1112#include <llvm/Support/Debug.h>
1213
@@ -64,7 +65,8 @@ getCandidate(uint32_t num, uint32_t floor,
6465}
6566
6667// check if the threads are valid
67- bool validateThreads (ArrayRef<uint32_t > threads, SystemDesc &sysDesc) {
68+ bool validateThreads (ArrayRef<uint32_t > threads,
69+ CPUTargetDescriptionAnalysis &sysDesc) {
6870 uint32_t numThreads = sysDesc.getNumThreads ();
6971 uint32_t actualThreads = 1U ;
7072 for (uint32_t t : threads)
@@ -77,24 +79,25 @@ bool validateThreads(ArrayRef<uint32_t> threads, SystemDesc &sysDesc) {
7779double vectorRegEfficiencyCost (linalg::LinalgOp &linalgOp,
7880 ArrayRef<uint32_t > shape,
7981 const MatmulConfig &config,
80- SystemDesc &sysDesc) {
82+ CPUTargetDescriptionAnalysis &sysDesc) {
8183 size_t dtypeSize = DataLayout ().getTypeSizeInBits (
8284 ShapeAdaptor (linalgOp.getDpsInputs ()[1 ].getType ()).getElementType ());
83- size_t maxVectorLength = sysDesc.getMaxVectorLength () / dtypeSize;
85+ size_t maxVectorWidth = sysDesc.getMaxVectorWidth () / dtypeSize;
8486 // TODO: take matrix register like amx into account
85- double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength ) %
86- maxVectorLength * 1.0 / config.innerMostMBlock +
87- (maxVectorLength - config.innerMostKBlock % maxVectorLength ) %
88- maxVectorLength * 1.0 / config.innerMostKBlock +
89- (maxVectorLength - config.innerMostNBlock % maxVectorLength ) %
90- maxVectorLength * 1.0 / config.innerMostNBlock ;
87+ double cost = (maxVectorWidth - config.innerMostMBlock % maxVectorWidth ) %
88+ maxVectorWidth * 1.0 / config.innerMostMBlock +
89+ (maxVectorWidth - config.innerMostKBlock % maxVectorWidth ) %
90+ maxVectorWidth * 1.0 / config.innerMostKBlock +
91+ (maxVectorWidth - config.innerMostNBlock % maxVectorWidth ) %
92+ maxVectorWidth * 1.0 / config.innerMostNBlock ;
9193 return cost;
9294}
9395
9496// calculate the cost of the workload balance
9597double workloadBalancedCost (linalg::LinalgOp &linalgOp,
9698 ArrayRef<uint32_t > shape,
97- const MatmulConfig &config, SystemDesc &sysDesc) {
99+ const MatmulConfig &config,
100+ CPUTargetDescriptionAnalysis &sysDesc) {
98101 if (shape.size () < 3 ) {
99102 // Has an invalid shape
100103 return 0 ;
@@ -118,7 +121,7 @@ double workloadBalancedCost(linalg::LinalgOp &linalgOp,
118121double memoryConsumptionOnThreadCost (linalg::LinalgOp &linalgOp,
119122 ArrayRef<uint32_t > shape,
120123 const MatmulConfig &config,
121- SystemDesc &sysDesc) {
124+ CPUTargetDescriptionAnalysis &sysDesc) {
122125 if (shape.size () < 3 ) {
123126 // Has an invalid shape
124127 return 0 ;
@@ -141,7 +144,7 @@ double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp,
141144double computationIntensityOnL2Cache (linalg::LinalgOp &linalgOp,
142145 ArrayRef<uint32_t > shape,
143146 const MatmulConfig &config,
144- SystemDesc &sysDesc) {
147+ CPUTargetDescriptionAnalysis &sysDesc) {
145148 double fullLoadRatio = 0.7 ;
146149 uint32_t L2Cache = sysDesc.getCacheSize (2 );
147150 size_t dtypeSize = DataLayout ().getTypeSize (
@@ -157,16 +160,17 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
157160 return 1 / computationIntensity;
158161}
159162
160- using CostModelFn =
161- std::function< double ( linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape,
162- MatmulConfig cfg, SystemDesc &sysDesc)>;
163+ using CostModelFn = std::function< double (
164+ linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape, MatmulConfig cfg ,
165+ CPUTargetDescriptionAnalysis &sysDesc)>;
163166
164167// filter the config by the cost model
165168std::vector<MatmulConfig>
166169filterConfigByCostModel (ArrayRef<MatmulConfig> configs,
167170 linalg::LinalgOp &linalgOp, ArrayRef<uint32_t > shape,
168- SystemDesc &sysDesc, const CostModelFn &costModel,
169- float preserveRatio = 0.5 , float threshold = -1 ) {
171+ CPUTargetDescriptionAnalysis &sysDesc,
172+ const CostModelFn &costModel, float preserveRatio = 0.5 ,
173+ float threshold = -1 ) {
170174 std::vector<MatmulConfig> result;
171175 std::vector<float > costs;
172176 std::vector<size_t > idx;
@@ -196,7 +200,7 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
196200
197201// prepare the config candidates
198202std::vector<MatmulConfig>
199- prepareConfigCandidates (Operation *root, SystemDesc &sysDesc,
203+ prepareConfigCandidates (Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
200204 ArrayRef<uint32_t > shape,
201205 ArrayRef<uint32_t > givenInnermostBlock) {
202206 if (shape.size () < 3 ) {
@@ -347,7 +351,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
347351// previous matmul
348352MatmulConfigAnalysis::MatmulConfigAnalysis (Operation *root) {
349353 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
350- SystemDesc sysDesc (root-> getParentOfType <ModuleOp>() );
354+ CPUTargetDescriptionAnalysis sysDesc (root);
351355 SmallVector<SmallVector<DimType>> oprandDimType =
352356 *getOprandDimType (linalgOp);
353357 // get the origin M,N,K size
0 commit comments