@@ -44,11 +44,10 @@ getCandidate(uint32_t num, uint32_t floor,
4444 // factor
4545 std::vector<uint32_t > candidates;
4646 uint32_t upperbound = std::min (num, ceil);
47- for (uint32_t i = floor; i <= upperbound; i++) {
48- if (num % i == 0 ) {
47+ for (uint32_t i = floor; i <= upperbound; i++)
48+ if (num % i == 0 )
4949 candidates.push_back (i);
50- }
51- }
50+
5251 // the pow of 2
5352 uint32_t candidate = 1U ;
5453 while (candidate < floor)
@@ -68,9 +67,8 @@ getCandidate(uint32_t num, uint32_t floor,
6867bool validateThreads (ArrayRef<uint32_t > threads, SystemDesc &sysDesc) {
6968 uint32_t numThreads = sysDesc.getNumThreads ();
7069 uint32_t actualThreads = 1U ;
71- for (uint32_t t : threads) {
70+ for (uint32_t t : threads)
7271 actualThreads *= t;
73- }
7472 return actualThreads == numThreads;
7573}
7674
@@ -154,9 +152,8 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
154152 config.NBlock * config.KBlock +
155153 config.MBlock * config.KBlock ;
156154 double computationIntensity = FLOPS / memoryConsumption;
157- if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) {
155+ if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio)
158156 computationIntensity /= outOfCachePenalty;
159- }
160157 return 1 / computationIntensity;
161158}
162159
@@ -183,19 +180,17 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
183180 double thresholdCost = costs[idx[(size_t )(preserveRatio * configs.size ())]];
184181 thresholdCost =
185182 threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
186- for (const auto &i : idx) {
187- if (costs[i] <= thresholdCost) {
183+ for (const auto &i : idx)
184+ if (costs[i] <= thresholdCost)
188185 result.push_back (configs[i]);
189- }
190- }
186+
191187 LLVM_DEBUG (llvm::dbgs () << " thresholdCost is: " << thresholdCost
192188 << " \n best with cost: " << costs[idx[0 ]] << " \n "
193189 << configs[idx[0 ]] << " \n worst with cost: "
194190 << costs[idx[configs.size () - 1 ]] << " \n "
195191 << configs[idx[configs.size () - 1 ]] << " \n " );
196- if (result.empty ()) {
192+ if (result.empty ())
197193 result = configs;
198- }
199194 return result;
200195}
201196
@@ -248,27 +243,23 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
248243 for (uint32_t MThreads : MThreadsCandidates) {
249244 for (uint32_t NThreads : NThreadsCandidates) {
250245 for (uint32_t KThreads : KThreadsCandidates) {
251- if (!validateThreads ({MThreads, NThreads, KThreads}, sysDesc)) {
246+ if (!validateThreads ({MThreads, NThreads, KThreads}, sysDesc))
252247 continue ;
253- }
254248 for (uint32_t MBlock : MBlockCandidates) {
255249 for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
256250 if (MBlock % innerMostMBlock != 0 ||
257- shape[0 ] % innerMostMBlock != 0 ) {
251+ shape[0 ] % innerMostMBlock != 0 )
258252 continue ;
259- }
260253 for (uint32_t NBlock : NBlockCandidates) {
261254 for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
262255 if (NBlock % innerMostNBlock != 0 ||
263- shape[1 ] % innerMostNBlock != 0 ) {
256+ shape[1 ] % innerMostNBlock != 0 )
264257 continue ;
265- }
266258 for (uint32_t KBlock : KBlockCandidates) {
267259 for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
268260 if (KBlock % innerMostKBlock != 0 ||
269- shape[2 ] % innerMostKBlock != 0 ) {
261+ shape[2 ] % innerMostKBlock != 0 )
270262 continue ;
271- }
272263 MatmulConfig config{
273264 MThreads, NThreads, KThreads,
274265 MBlock, NBlock, KBlock,
@@ -293,14 +284,12 @@ bool validateConfig(const MatmulConfig &cfg) {
293284 if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
294285 cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
295286 cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
296- cfg.innerMostKBlock <= 0 ) {
287+ cfg.innerMostKBlock <= 0 )
297288 return false ;
298- }
299289 if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
300290 cfg.NBlock % cfg.innerMostNBlock != 0 ||
301- cfg.KBlock % cfg.innerMostKBlock != 0 ) {
291+ cfg.KBlock % cfg.innerMostKBlock != 0 )
302292 return false ;
303- }
304293 return true ;
305294}
306295
@@ -371,19 +360,16 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
371360 uint32_t M = 1U , N = 1U , K = 1U ;
372361 for (auto &&[s, dimType] :
373362 llvm::zip (linalgOp.getShape (linalgOp.getDpsInputOperand (0 )),
374- oprandDimType[0 ])) {
375- if (dimType == DimType::M) {
363+ oprandDimType[0 ]))
364+ if (dimType == DimType::M)
376365 M *= s;
377- }
378- }
379366 for (auto &&[s, dimType] :
380367 llvm::zip (linalgOp.getShape (linalgOp.getDpsInputOperand (1 )),
381368 oprandDimType[1 ])) {
382- if (dimType == DimType::N) {
369+ if (dimType == DimType::N)
383370 N *= s;
384- } else if (dimType == DimType::K) {
371+ else if (dimType == DimType::K)
385372 K *= s;
386- }
387373 }
388374
389375 // innermost Block, if the layout is blockied layout, the innermost block
@@ -395,30 +381,30 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
395381 SmallVector<uint32_t > givenInnermostBlock;
396382 if (MDimTypeIdx.size () > 1 ) {
397383 config.innerMostMBlock = 1 ;
398- for (size_t i = 1UL ; i < MDimTypeIdx. size (); i++) {
399- config. innerMostMBlock *=
400- linalgOp. getShape (linalgOp. getDpsInputOperand ( 0 ))[MDimTypeIdx[i]];
401- }
384+ for (auto &&[i, d] : llvm::enumerate (MDimTypeIdx))
385+ if (i != 0 )
386+ config. innerMostMBlock *=
387+ linalgOp. getShape (linalgOp. getDpsInputOperand ( 0 ))[d];
402388 givenInnermostBlock.push_back (config.innerMostMBlock );
403389 } else {
404390 givenInnermostBlock.push_back (0 );
405391 }
406392 if (NDimTypeIdx.size () > 1 ) {
407393 config.innerMostNBlock = 1 ;
408- for (size_t i = 1UL ; i < NDimTypeIdx. size (); i++) {
409- config. innerMostNBlock *=
410- linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[NDimTypeIdx[i]];
411- }
394+ for (auto &&[i, d] : llvm::enumerate (NDimTypeIdx))
395+ if (i != 0 )
396+ config. innerMostNBlock *=
397+ linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[d];
412398 givenInnermostBlock.push_back (config.innerMostNBlock );
413399 } else {
414400 givenInnermostBlock.push_back (0 );
415401 }
416402 if (KDimTypeIdx.size () > 1 ) {
417403 config.innerMostKBlock = 1 ;
418- for (size_t i = 1UL ; i < KDimTypeIdx. size (); i++) {
419- config. innerMostKBlock *=
420- linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[KDimTypeIdx[i]];
421- }
404+ for (auto &&[i, d] : llvm::enumerate (KDimTypeIdx))
405+ if (i != 0 )
406+ config. innerMostKBlock *=
407+ linalgOp. getShape (linalgOp. getDpsInputOperand ( 1 ))[d];
422408 givenInnermostBlock.push_back (config.innerMostKBlock );
423409 } else {
424410 givenInnermostBlock.push_back (0 );
@@ -444,13 +430,11 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
444430 SmallVector<uint32_t > shape = {M, N, K};
445431 std::vector<MatmulConfig> configCandidates =
446432 prepareConfigCandidates (root, sysDesc, shape, givenInnermostBlock);
447- for (auto &&[fn, name, threshold] : costModelList) {
433+ for (auto &&[fn, name, threshold] : costModelList)
448434 configCandidates = filterConfigByCostModel (
449435 configCandidates, linalgOp, shape, sysDesc, fn, 0.5 , threshold);
450- }
451- if (!configCandidates.empty ()) {
436+ if (!configCandidates.empty ())
452437 config = configCandidates[0 ];
453- }
454438 }
455439
456440 LLVM_DEBUG (llvm::dbgs ()
0 commit comments