@@ -44,13 +44,15 @@ bool TensorLayout::operator==(const TensorLayout &layout) {
 
 llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
                               const OperatorLayout &opLayout) {
-  for (auto &&[idx, layoutCache] :
-       llvm::enumerate(opLayout.getSupportedInputLayouts())) {
-    ss << "input " << idx << "'s layout: " << layoutCache << "\n";
+  if (!opLayout.getSupportedInputLayouts().empty()) {
+    ss << "Input layouts: ";
+    llvm::interleave(opLayout.getSupportedInputLayouts(), ss, "; ");
+    ss << ". ";
   }
-  for (auto &&[idx, layoutCache] :
-       llvm::enumerate(opLayout.getSupportedOutputLayouts())) {
-    ss << "output " << idx << "'s layout: " << layoutCache << "\n";
+  if (!opLayout.getSupportedOutputLayouts().empty()) {
+    ss << "Output layouts: ";
+    llvm::interleave(opLayout.getSupportedOutputLayouts(), ss, "; ");
+    ss << ". ";
   }
   return ss;
 }
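For context, a minimal sketch of the llvm::interleave overload the new printer leans on (from llvm/ADT/STLExtras.h): it streams each element and emits the separator only between elements, so there is no trailing "; " to trim. The int vector below is just a stand-in for the TensorLayout ranges used above.

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

int main() {
  std::vector<int> layouts = {1, 2, 3};
  llvm::outs() << "Input layouts: ";
  // Prints "1; 2; 3" -- the separator never appears after the last element.
  llvm::interleave(layouts, llvm::outs(), "; ");
  llvm::outs() << ".\n";
}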
@@ -217,8 +219,6 @@ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Inferring layout of op: " << op->getName() << "\n");
       auto curInputs = linalgOp.getDpsInputOperands();
       auto curResults = linalgOp.getOperation()->getResults();
       // ---------------- Get Current Input Layouts -------------------
@@ -277,8 +277,11 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
                                      rewriter.getIndexAttr(iin)});
        OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout});
        layoutCache[linalgOp] = suggestedLayout;
+        LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName()
+                                << " is: " << suggestedLayout << "\n");
      } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) &&
-                !mlir::linalg::isaConvolutionOpInterface(linalgOp) &&
+                !isa<linalg::ConvolutionOpInterface>(
+                    linalgOp.getOperation()) &&
                 !supportedContractionNamedOpList(linalgOp)) {
        // infer layout for non-contraction/non-convolution linalg named ops
        // and linalg generic ops
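A note on the semantic shift in this hunk, with a sketch assuming the usual MLIR Linalg headers: linalg::isaConvolutionOpInterface(linalgOp) structurally matches convolution-shaped indexing maps (so a linalg.generic can qualify), whereas isa<linalg::ConvolutionOpInterface> succeeds only for ops that actually implement the interface, i.e. the named convolution ops. The helper name below is hypothetical.

#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Hypothetical helper: true only for named convolution ops implementing
// ConvolutionOpInterface, not for structurally convolution-shaped generics.
static bool isNamedConvolution(mlir::Operation *op) {
  return mlir::isa<mlir::linalg::ConvolutionOpInterface>(op);
}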
@@ -311,6 +314,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
        outputLayouts.push_back(outputLayout);
        OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
        layoutCache[linalgOp] = suggestedLayout;
+        LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName()
+                                << " is: " << suggestedLayout << "\n");
      }
    } else if (auto padOp = dyn_cast<tensor::PadOp>(op)) {
      auto inputOperand = padOp.getSource();
@@ -325,6 +330,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
           outputLayouts{curInputLayout};
       OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
       layoutCache[padOp] = suggestedLayout;
+      LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName()
+                              << " is: " << suggestedLayout << "\n");
     } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
       SmallVector<ReassociationIndices> reassocIndices =
           expandShapeOp.getReassociationIndices();
@@ -343,8 +350,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       ArrayRef<int64_t> innerDimsPos = curInputLayout.getInnerAxis();
       ArrayRef<int64_t> outerDimsPerm = curInputLayout.getOuterAxis();
       SmallVector<int64_t> projectedInnerDimsPos =
-          projectToInnerMostNonUnitDimsPos(curInputLayout.getInnerAxis(),
-                                           reassocIndices, staticOutputShape);
+          projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices,
+                                           staticOutputShape);
 
       if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape,
                                       innerTileSizes)) {
@@ -362,6 +369,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
           outputLayouts{outputLayout};
       OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
       layoutCache[expandShapeOp] = suggestedLayout;
+      LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName()
+                              << " is: " << suggestedLayout << "\n");
     }
     return WalkResult::advance();
   });
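The recurring additions above follow LLVM's standard debug-logging pattern; a minimal sketch is below. The "global-analysis" tag is hypothetical (the file's real DEBUG_TYPE is not visible in this diff); output appears only in asserts-enabled builds, gated by -debug or -debug-only=<DEBUG_TYPE>.

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "global-analysis" // hypothetical pass tag

// Mirrors the inserted LLVM_DEBUG calls: compiled out of release builds,
// enabled at runtime by -debug / -debug-only=global-analysis otherwise.
static void logInferredLayout(llvm::StringRef opName, llvm::StringRef layout) {
  LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << opName
                          << " is: " << layout << "\n");
}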