1212namespace mlir {
1313namespace gc {
1414
15- std::ostream &operator <<(std::ostream &ss, const TensorLayout &layout) {
16- SmallVector<int64_t > outerAxis = layout.getOuterAxis ();
17- SmallVector<int64_t > innerAxis = layout.getInnerAxis ();
18- SmallVector<OpFoldResult> tileSizes = layout.getTileSizes ();
15+ #define DEBUG_TYPE " global-analysis"
16+
17+ llvm::raw_ostream &operator <<(llvm::raw_ostream &ss,
18+ const TensorLayout &layoutCache) {
19+ SmallVector<int64_t > outerAxis = layoutCache.getOuterAxis ();
20+ SmallVector<int64_t > innerAxis = layoutCache.getInnerAxis ();
21+ SmallVector<OpFoldResult> tileSizes = layoutCache.getTileSizes ();
1922 ss << " [" ;
2023 for (size_t i = 0 ; i < outerAxis.size (); ++i) {
2124 if (i != 0 ) {
@@ -43,21 +46,21 @@ std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout) {
4346 return ss;
4447}
4548
46- bool TensorLayout::operator ==(const TensorLayout &layout ) {
47- return (this ->OuterAxis == layout .getOuterAxis ()) &&
48- (this ->InnerAxis == layout .getInnerAxis ()) &&
49- (this ->TileSizes == layout .getTileSizes ());
49+ bool TensorLayout::operator ==(const TensorLayout &layoutCache ) {
50+ return (this ->OuterAxis == layoutCache .getOuterAxis ()) &&
51+ (this ->InnerAxis == layoutCache .getInnerAxis ()) &&
52+ (this ->TileSizes == layoutCache .getTileSizes ());
5053}
5154
52- std::ostream &operator <<(std::ostream &ss, const OperatorLayout &opLayout) {
53- ss << " operator has " << opLayout.getSupportedInputLayouts ().size ()
54- << " inputs; " << opLayout.getSupportedOutputLayouts ().size ()
55- << " outputs." << std::endl;
56- for (const auto &layout : opLayout.getSupportedInputLayouts ()) {
57- ss << " input layout: " << layout << std::endl;
55+ llvm::raw_ostream &operator <<(llvm::raw_ostream &ss,
56+ const OperatorLayout &opLayout) {
57+ for (auto &&[idx, layoutCache] :
58+ llvm::enumerate (opLayout.getSupportedInputLayouts ())) {
59+ ss << " input " << idx << " 's layoutCache: " << layoutCache << " \n " ;
5860 }
59- for (const auto &layout : opLayout.getSupportedOutputLayouts ()) {
60- ss << " output layout: " << layout << std::endl;
61+ for (auto &&[idx, layoutCache] :
62+ llvm::enumerate (opLayout.getSupportedOutputLayouts ())) {
63+ ss << " output " << idx << " 's layoutCache: " << layoutCache << " \n " ;
6164 }
6265 return ss;
6366}
@@ -119,7 +122,6 @@ getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
119122static FailureOr<TensorLayout>
120123inferTargetLayout (TensorLayout layoutBase,
121124 const DenseMap<int64_t , int64_t > &indexMap) {
122- int64_t dimDifference = indexMap.size () - layoutBase.getTensorRank ();
123125 SmallVector<int64_t > baseOuterAxis = layoutBase.getOuterAxis ();
124126 SmallVector<int64_t > baseInnerAxis = layoutBase.getInnerAxis ();
125127 SmallVector<OpFoldResult> baseTileSizes = layoutBase.getTileSizes ();
@@ -153,38 +155,24 @@ inferTargetLayout(TensorLayout layoutBase,
153155
154156GlobalAnalysis::GlobalAnalysis (Operation *root) {
155157 root->walk ([&](Operation *op) {
158+ // get input layouts
159+ LLVM_DEBUG (llvm::dbgs ()
160+ << " Inferring layoutCache of op: " << op->getName () << " \n " );
156161 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
157- // get input layouts
158- std::cout << std::endl;
159- std::cout << " ----------------------------------" << std::endl;
160- linalgOp.getOperation ()->getName ().print (llvm::errs ());
161- std::cout << std::endl;
162- std::cout << " ----------------------------------" << std::endl;
163- std::cout << std::endl;
164- SmallVector<AffineMap> indexing_maps = linalgOp.getIndexingMapsArray ();
165162 auto curInputs = linalgOp.getDpsInputOperands ();
166163 auto curResults = linalgOp.getOperation ()->getResults ();
167-
168164 // ---------------- Get Current Input Layouts -------------------
169- // get current input layouts
170- std::cout << " ----- printing ground-truth input layouts -----"
171- << std::endl;
172165 SmallVector<TensorLayout> curInputLayouts;
173166 for (auto input : curInputs) {
174167 auto parent = input->get ().getDefiningOp ();
175- if (layout .find (parent) != layout .end ()) {
168+ if (layoutCache .find (parent) != layoutCache .end ()) {
176169 // TODO(yifei): it is not always 0 here
177- curInputLayouts.push_back (layout [parent].getOutputLayout (0 ));
170+ curInputLayouts.push_back (layoutCache [parent].getOutputLayout (0 ));
178171 } else {
179172 curInputLayouts.push_back (TensorLayout::createPlainLayout (
180173 linalgOp.getMatchingIndexingMap (input).getNumResults ()));
181174 }
182175 }
183- // debug info
184- for (auto layout : curInputLayouts) {
185- std::cout << " layout: " << layout << std::endl;
186- }
187-
188176 // ------ Get Current Op's Suggested Layout & Do Propagation ------
189177 IRRewriter rewriter (linalgOp);
190178 if (mlir::linalg::isaContractionOpInterface (linalgOp)) {
@@ -193,38 +181,33 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
193181 // curInputLayouts);
194182
195183 // hardcode one for now
196- // A side layout , [0, 1, 0, 1]; {32, 32}
184+ // A side layoutCache , [0, 1, 0, 1]; {32, 32}
197185 TensorLayout A_layout (
198186 {0 , 1 }, {0 , 1 },
199187 SmallVector<OpFoldResult>{rewriter.getIndexAttr (32 ),
200188 rewriter.getIndexAttr (32 )});
201- // B side layout , [1, 0, 0, 1]; {32, 32}
189+ // B side layoutCache , [1, 0, 0, 1]; {32, 32}
202190 TensorLayout B_layout (
203191 {1 , 0 }, {0 , 1 },
204192 SmallVector<OpFoldResult>{rewriter.getIndexAttr (32 ),
205193 rewriter.getIndexAttr (32 )});
206- // C side layout , [0, 1, 0, 1]; {32, 32}
194+ // C side layoutCache , [0, 1, 0, 1]; {32, 32}
207195 TensorLayout C_layout (
208196 {0 , 1 }, {0 , 1 },
209197 SmallVector<OpFoldResult>{rewriter.getIndexAttr (32 ),
210198 rewriter.getIndexAttr (32 )});
211199 OperatorLayout suggestedLayout ({A_layout, B_layout}, {C_layout});
212- layout [linalgOp] = suggestedLayout;
200+ layoutCache [linalgOp] = suggestedLayout;
213201 } else {
214202 SmallVector<TensorLayout> inputLayouts, outputLayouts;
215203 inputLayouts.push_back (curInputLayouts[0 ]);
216204 // TODO(yifei): wisely choose the input format basis
217205 // Let's only refer to input[0] for now
218206 for (size_t i = 1 ; i < curInputs.size (); ++i) {
219- std::cout << " inferring indexing map relation" << std::endl;
220207 // getMatchingIndexingMap
221208 auto res = inferIndexingMapRelation (
222209 linalgOp.getMatchingIndexingMap (curInputs[0 ]),
223210 linalgOp.getMatchingIndexingMap (curInputs[i]));
224- for (auto tp : *res) {
225- std::cout << " target index: " << tp.first
226- << " maps to base index: " << tp.second << std::endl;
227- }
228211 TensorLayout inputLayout =
229212 *inferTargetLayout (curInputLayouts[0 ], *res);
230213 inputLayouts.push_back (inputLayout);
@@ -235,14 +218,66 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
235218 TensorLayout outputLayout =
236219 *inferTargetLayout (curInputLayouts[0 ], *res_out);
237220 outputLayouts.push_back (outputLayout);
238- for (auto tp : *res_out) {
239- std::cout << " target index: " << tp.first
240- << " maps to base index: " << tp.second << std::endl;
241- }
242221 OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
243- layout[linalgOp] = suggestedLayout;
222+ layoutCache[linalgOp] = suggestedLayout;
223+ }
224+ } else if (auto padOp = dyn_cast<tensor::PadOp>(op)) {
225+ auto inputOperand = padOp.getSource ();
226+ auto inputRank =
227+ cast<ShapedType>(inputOperand.getType ()).getShape ().size ();
228+ auto parent = inputOperand.getDefiningOp ();
229+ TensorLayout curInputLayout =
230+ layoutCache.find (parent) != layoutCache.end ()
231+ ? layoutCache[parent].getOutputLayout (0 )
232+ : TensorLayout::createPlainLayout (inputRank);
233+ SmallVector<TensorLayout> inputLayouts{curInputLayout},
234+ outputLayouts{curInputLayout};
235+ OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
236+ layoutCache[padOp] = suggestedLayout;
237+ } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
238+ auto reassociation = expandShapeOp.getReassociation ();
239+ auto staticOutputShape = expandShapeOp.getStaticOutputShape ();
240+ auto parent = expandShapeOp.getSrc ().getDefiningOp ();
241+ auto inputShape = expandShapeOp.getSrcType ().getShape ();
242+ TensorLayout curInputLayout =
243+ layoutCache.find (parent) != layoutCache.end ()
244+ ? layoutCache[parent].getOutputLayout (0 )
245+ : TensorLayout::createPlainLayout (inputShape.size ());
246+ DenseMap<int64_t , int64_t > outputInputIdxMapping, inputOutputIndexMapping;
247+ int64_t accumulationOffset = 0 ;
248+ for (int64_t i = 0 ; i < static_cast <int64_t >(reassociation.size ()); ++i) {
249+ auto subReassociation = llvm::cast<ArrayAttr>(reassociation[i]);
250+ for (int64_t j = 0 ; j < static_cast <int64_t >(subReassociation.size ());
251+ ++j) {
252+ if (staticOutputShape[accumulationOffset + j] == inputShape[i]) {
253+ outputInputIdxMapping[accumulationOffset + j] = i;
254+ inputOutputIndexMapping[i] = accumulationOffset + j;
255+ }
256+ }
257+ accumulationOffset += subReassociation.size ();
258+ }
259+ auto inputOuterAxis = curInputLayout.getOuterAxis ();
260+ auto inputInnerAxis = curInputLayout.getInnerAxis ();
261+ int64_t startIdx = 0 ;
262+ SmallVector<int64_t > outputOuterAxis, outputInnerAxis;
263+ for (int64_t i = 0 ; i < static_cast <int64_t >(staticOutputShape.size ());
264+ ++i) {
265+ if (outputInputIdxMapping.find (i) != outputInputIdxMapping.end ()) {
266+ outputOuterAxis.push_back (inputOuterAxis[outputInputIdxMapping[i]]);
267+ } else {
268+ outputOuterAxis.push_back (startIdx++);
269+ }
270+ }
271+ for (int64_t i = 0 ; i < static_cast <int64_t >(inputInnerAxis.size ());
272+ ++i) {
273+ outputInnerAxis.push_back (inputOutputIndexMapping[inputInnerAxis[i]]);
244274 }
245- } else if (isa<tensor::PadOp>(op) || isa<tensor::ExpandShapeOp>(op)) {
275+ TensorLayout outputLayout (outputOuterAxis, outputInnerAxis,
276+ curInputLayout.getTileSizes ());
277+ SmallVector<TensorLayout> inputLayouts{curInputLayout},
278+ outputLayouts{outputLayout};
279+ OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
280+ layoutCache[expandShapeOp] = suggestedLayout;
246281 }
247282 });
248283}
0 commit comments