@@ -55,13 +55,16 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
5555 return ss;
5656}
5757
58- // inferring the relationship of two indexing map
59- // j -> i, means j is represented as the same symbol as i
60- // we don't allow duplicate in symbols
61- // e.g. if 2 j corresponding to 1 i, then return failure
58+ // Infer the relation between two indexing maps.
59+ // Returns a map from target dim -> base dim, meaning the target dim is
60+ // represented by the same loop index as the base dim. Duplicates are not
61+ // allowed; e.g. two target dims mapping to one base dim yields failure.
6261static FailureOr<DenseMap<int64_t , int64_t >>
6362inferIndexingMapRelation (AffineMap indexingMapBase,
6463 AffineMap indexingMapTarget) {
64+ // Symbolic identifiers are not allowed in either indexing map.
65+ if (indexingMapBase.getNumSymbols () != 0 ||
66+ indexingMapTarget.getNumSymbols () != 0 )
67+ return failure ();
6568 DenseMap<int64_t , int64_t > res;
6669 ArrayRef<AffineExpr> resultsBase = indexingMapBase.getResults ();
6770 ArrayRef<AffineExpr> resultsTarget = indexingMapTarget.getResults ();
@@ -70,6 +73,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase,
7073 auto base = dyn_cast<AffineDimExpr>(resultsBase[i]);
7174 auto target = dyn_cast<AffineDimExpr>(resultsTarget[j]);
7275 if (base && target && base.getPosition () == target.getPosition ()) {
76+ // Reject duplicates: target dim j is already mapped to a base dim.
7377 if (res.find (j) != res.end ())
7478 return failure ();
7579 res[j] = i;
@@ -91,7 +95,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase,
9195 return res;
9296}
9397
94- // given j --> i and max rank of i , return i --> j
98+ // Given target --> base and the max rank of base, return base --> target.
9599static DenseMap<int64_t , int64_t >
96100getReversedIndexMap (const DenseMap<int64_t , int64_t > &indexMap,
97101 size_t maxRank) {
@@ -109,7 +113,7 @@ getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
109113 return res;
110114}
111115
112- static FailureOr< TensorLayout>
116+ static TensorLayout
113117inferTargetLayout (TensorLayout layoutBase,
114118 const DenseMap<int64_t , int64_t > &indexMap) {
115119 SmallVector<int64_t > baseOuterAxis = layoutBase.getOuterAxis ();
@@ -177,6 +181,39 @@ getPackingAxis(int64_t numRank, bool transposed) {
177181 return std::make_pair (outerAxisPerm, innerAxisPos);
178182}
179183
184+ // copied from mlir
185+ static SmallVector<int64_t >
186+ projectToInnerMostNonUnitDimsPos (ArrayRef<int64_t > dimsPos,
187+ ArrayRef<ReassociationIndices> reassocIndices,
188+ ArrayRef<int64_t > targetShape) {
189+ SmallVector<int64_t > projectedDimsPos;
190+ for (auto pos : dimsPos) {
191+ // In the case all dims are unit, this will return the inner-most one.
192+ int64_t projectedPos = reassocIndices[pos].back ();
193+ for (auto i : llvm::reverse (reassocIndices[pos])) {
194+ int64_t dim = targetShape[i];
195+ if (dim > 1 || ShapedType::isDynamic (dim)) {
196+ projectedPos = i;
197+ break ;
198+ }
199+ }
200+ projectedDimsPos.push_back (projectedPos);
201+ }
202+ return projectedDimsPos;
203+ }
204+
205+ // / Check if all dims in dimsPos are divisible by the corresponding tile sizes.
206+ static bool isDimsDivisibleByTileSizes (ArrayRef<int64_t > dimsPos,
207+ ArrayRef<int64_t > shape,
208+ ArrayRef<int64_t > tileSizes) {
209+ for (auto [pos, tileSize] : llvm::zip_equal (dimsPos, tileSizes)) {
210+ int64_t dim = shape[pos];
211+ if (ShapedType::isDynamic (dim) || (dim % tileSize) != 0 )
212+ return false ;
213+ }
214+ return true ;
215+ }
216+
180217GlobalAnalysis::GlobalAnalysis (Operation *root) {
181218 root->walk ([&](Operation *op) {
182219 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
@@ -198,9 +235,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
198235 }
199236 // ------ Get Current Op's Suggested Layout & Do Propagation ------
200237 IRRewriter rewriter (linalgOp);
201- // TODO: extend to packed/vnni matmul ops
202238 if (supportedContractionNamedOpList (linalgOp)) {
203- // get input and output rank
239+ // infer layout for linalg contraction named ops
204240 auto ARank = cast<ShapedType>(linalgOp.getDpsInputs ()[0 ].getType ())
205241 .getShape ()
206242 .size ();
@@ -242,29 +278,36 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
242278 OperatorLayout suggestedLayout ({ALayout, BLayout}, {CLayout});
243279 layoutCache[linalgOp] = suggestedLayout;
244280 } else if (!mlir::linalg::isaContractionOpInterface (linalgOp) &&
281+ !mlir::linalg::isaConvolutionOpInterface (linalgOp) &&
245282 !supportedContractionNamedOpList (linalgOp)) {
283+ // infer layout for non-contraction/non-convolution linalg named ops
284+ // and linalg generic ops
246285 SmallVector<TensorLayout> inputLayouts, outputLayouts;
247286 size_t targetIdx = getTargetInputIdx (curInputLayouts);
248- // TODO(yifei): wisely choose the input format basis
249- // Let's only refer to input[0] for now
250287 for (size_t i = 0 ; i < curInputs.size (); ++i) {
251288 // getMatchingIndexingMap
252289 if (i != targetIdx) {
253- auto res = inferIndexingMapRelation (
290+ auto indexRelation = inferIndexingMapRelation (
254291 linalgOp.getMatchingIndexingMap (curInputs[targetIdx]),
255292 linalgOp.getMatchingIndexingMap (curInputs[i]));
293+ if (failed (indexRelation)) {
294+ return WalkResult::skip ();
295+ }
256296 TensorLayout inputLayout =
257- * inferTargetLayout (curInputLayouts[targetIdx], *res );
297+ inferTargetLayout (curInputLayouts[targetIdx], *indexRelation );
258298 inputLayouts.push_back (inputLayout);
259299 } else {
260300 inputLayouts.push_back (curInputLayouts[targetIdx]);
261301 }
262302 }
263- auto res_out = inferIndexingMapRelation (
303+ auto indexRelation = inferIndexingMapRelation (
264304 linalgOp.getMatchingIndexingMap (curInputs[targetIdx]),
265305 linalgOp.getIndexingMapMatchingResult (curResults[0 ]));
306+ if (failed (indexRelation)) {
307+ return WalkResult::skip ();
308+ }
266309 TensorLayout outputLayout =
267- * inferTargetLayout (curInputLayouts[targetIdx], *res_out );
310+ inferTargetLayout (curInputLayouts[targetIdx], *indexRelation );
268311 outputLayouts.push_back (outputLayout);
269312 OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
270313 layoutCache[linalgOp] = suggestedLayout;
@@ -283,52 +326,44 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
283326 OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
284327 layoutCache[padOp] = suggestedLayout;
285328 } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
286- auto reassociation = expandShapeOp.getReassociation ();
329+ SmallVector<ReassociationIndices> reassocIndices =
330+ expandShapeOp.getReassociationIndices ();
287331 auto staticOutputShape = expandShapeOp.getStaticOutputShape ();
288332 auto parent = expandShapeOp.getSrc ().getDefiningOp ();
289333 auto inputShape = expandShapeOp.getSrcType ().getShape ();
290334 TensorLayout curInputLayout =
291335 layoutCache.find (parent) != layoutCache.end ()
292336 ? layoutCache[parent].getOutputLayout (0 )
293337 : TensorLayout::createPlainLayout (inputShape.size ());
294- DenseMap<int64_t , int64_t > outputInputIdxMapping, inputOutputIndexMapping;
295- int64_t accumulationOffset = 0 ;
296- for (int64_t i = 0 ; i < static_cast <int64_t >(reassociation.size ()); ++i) {
297- auto subReassociation = llvm::cast<ArrayAttr>(reassociation[i]);
298- for (int64_t j = 0 ; j < static_cast <int64_t >(subReassociation.size ());
299- ++j) {
300- if (staticOutputShape[accumulationOffset + j] == inputShape[i]) {
301- outputInputIdxMapping[accumulationOffset + j] = i;
302- inputOutputIndexMapping[i] = accumulationOffset + j;
303- }
304- }
305- accumulationOffset += subReassociation.size ();
338+ SmallVector<int64_t > innerTileSizes;
339+ auto intTileSizes = getConstantIntValues (curInputLayout.getTileSizes ());
340+ if (intTileSizes) {
341+ innerTileSizes = *intTileSizes;
306342 }
307- auto inputOuterAxis = curInputLayout.getOuterAxis ();
308- auto inputInnerAxis = curInputLayout.getInnerAxis ();
309- int64_t diffDifference = staticOutputShape.size () - inputShape.size ();
310- int64_t startIdx = 0 ;
311- SmallVector<int64_t > outputOuterAxis, outputInnerAxis;
312- for (int64_t i = 0 ; i < static_cast <int64_t >(staticOutputShape.size ());
313- ++i) {
314- if (outputInputIdxMapping.find (i) != outputInputIdxMapping.end ()) {
315- outputOuterAxis.push_back (inputOuterAxis[outputInputIdxMapping[i]] +
316- diffDifference);
317- } else {
318- outputOuterAxis.push_back (startIdx++);
319- }
343+ ArrayRef<int64_t > innerDimsPos = curInputLayout.getInnerAxis ();
344+ ArrayRef<int64_t > outerDimsPerm = curInputLayout.getOuterAxis ();
345+ SmallVector<int64_t > projectedInnerDimsPos =
346+ projectToInnerMostNonUnitDimsPos (curInputLayout.getInnerAxis (),
347+ reassocIndices, staticOutputShape);
348+
349+ if (!isDimsDivisibleByTileSizes (projectedInnerDimsPos, staticOutputShape,
350+ innerTileSizes)) {
351+ return WalkResult::skip ();
320352 }
321- for (int64_t i = 0 ; i < static_cast <int64_t >(inputInnerAxis.size ());
322- ++i) {
323- outputInnerAxis.push_back (inputOutputIndexMapping[inputInnerAxis[i]]);
353+ SmallVector<int64_t > newOuterDimsPerm;
354+ for (auto outerPos : outerDimsPerm) {
355+ newOuterDimsPerm.insert (newOuterDimsPerm.end (),
356+ reassocIndices[outerPos].begin (),
357+ reassocIndices[outerPos].end ());
324358 }
325- TensorLayout outputLayout (outputOuterAxis, outputInnerAxis ,
359+ TensorLayout outputLayout (newOuterDimsPerm, projectedInnerDimsPos ,
326360 curInputLayout.getTileSizes ());
327361 SmallVector<TensorLayout> inputLayouts{curInputLayout},
328362 outputLayouts{outputLayout};
329363 OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
330364 layoutCache[expandShapeOp] = suggestedLayout;
331365 }
366+ return WalkResult::advance ();
332367 });
333368}
334369
0 commit comments