@@ -38,6 +38,12 @@ using namespace mlir::tensor;
3838static SmallVector<int64_t > getPackedAxes (ArrayRef<int64_t > dimensions,
3939 TensorLayout targetLayout) {
4040 SmallVector<int64_t > result (dimensions);
41+ // permuting on outer axis
42+ auto outerPerm = targetLayout.getOuterAxis ();
43+ for (size_t i = 0 ; i < dimensions.size (); ++i) {
44+ result[i] = outerPerm[dimensions[i]];
45+ }
46+ // inserting inner axis
4147 auto innerPos = targetLayout.getInnerAxis ();
4248 for (size_t i = 0 ; i < dimensions.size (); ++i) {
4349 if (std::find (innerPos.begin (), innerPos.end (), dimensions[i]) !=
@@ -153,8 +159,10 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
153159 loc, inits.getTypes (), inputs, inits, packedAxes);
154160 packedLinalgOp->getRegion (0 ).takeBody (linalgOp->getRegion (0 ));
155161 } else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
156- packedLinalgOp = rewriter.create <linalg::BroadcastOp>(
157- loc, inputs[0 ], inits[0 ], broadcastOp->getDimensions ());
162+ SmallVector<int64_t > packedAxes =
163+ getPackedAxes (broadcastOp->getDimensions (), initLayouts[0 ]);
164+ packedLinalgOp = rewriter.create <linalg::BroadcastOp>(loc, inputs[0 ],
165+ inits[0 ], packedAxes);
158166 } else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
159167 SmallVector<int64_t > packedPermAxes = getPackedPermAxes (
160168 transposeOp->getPermutation (), inputLayouts[0 ], initLayouts[0 ]);
@@ -237,39 +245,41 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
237245 return WalkResult::skip ();
238246 }
239247 } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
240- Location loc = expandShapeOp->getLoc ();
241- auto inputLayout = opLayout->getSupportedInputLayouts ()[0 ];
242- auto outputLayout = opLayout->getSupportedOutputLayouts ()[0 ];
243- Value dest = tensor::PackOp::createDestinationTensor (
244- rewriter, loc, expandShapeOp.getSrc (), inputLayout.getTileSizes (),
245- inputLayout.getInnerAxis (), inputLayout.getOuterAxis ());
246- Value packedSource = rewriter.create <tensor::PackOp>(
247- loc, expandShapeOp.getSrc (), dest, inputLayout.getInnerAxis (),
248- inputLayout.getTileSizes (), std::nullopt ,
249- inputLayout.getOuterAxis ());
250- auto resultType = RankedTensorType::get (
251- expandShapeOp.getStaticOutputShape (),
252- expandShapeOp.getSrcType ().getElementType ());
253- RankedTensorType resultPackType = tensor::PackOp::inferPackedType (
254- resultType, vector::getAsIntegers (outputLayout.getTileSizes ()),
255- outputLayout.getInnerAxis (), outputLayout.getOuterAxis ());
256- auto reassocExpand = getReassociationIndicesForReshape (
257- cast<ShapedType>(dest.getType ()), resultPackType);
258- auto packedExpandShape = rewriter.create <tensor::ExpandShapeOp>(
259- loc, expandShapeOp.getSrcType ().getElementType (), packedSource,
260- *reassocExpand);
261- Value result = rewriter.create <tensor::UnPackOp>(
262- packedExpandShape->getLoc (), packedExpandShape, packedExpandShape,
263- outputLayout.getInnerAxis (), outputLayout.getTileSizes (),
264- outputLayout.getOuterAxis ());
265- rewriter.replaceOp (expandShapeOp, result);
248+ // Location loc = expandShapeOp->getLoc();
249+ // auto inputLayout = opLayout->getSupportedInputLayouts()[0];
250+ // auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
251+ // Value dest = tensor::PackOp::createDestinationTensor(
252+ // rewriter, loc, expandShapeOp.getSrc(),
253+ // inputLayout.getTileSizes(), inputLayout.getInnerAxis(),
254+ // inputLayout.getOuterAxis());
255+ // Value packedSource = rewriter.create<tensor::PackOp>(
256+ // loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
257+ // inputLayout.getTileSizes(), std::nullopt,
258+ // inputLayout.getOuterAxis());
259+ // auto resultType = RankedTensorType::get(
260+ // expandShapeOp.getStaticOutputShape(),
261+ // expandShapeOp.getSrcType().getElementType());
262+ // RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
263+ // resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
264+ // outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
265+ // auto reassocExpand = getReassociationIndicesForReshape(
266+ // cast<ShapedType>(dest.getType()), resultPackType);
267+ // auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
268+ // loc, expandShapeOp.getSrcType().getElementType(), packedSource,
269+ // *reassocExpand);
270+ // Value result = rewriter.create<tensor::UnPackOp>(
271+ // packedExpandShape->getLoc(), packedExpandShape,
272+ // packedExpandShape, outputLayout.getInnerAxis(),
273+ // outputLayout.getTileSizes(), outputLayout.getOuterAxis());
274+ // rewriter.replaceOp(expandShapeOp, result);
266275 }
267276 }
268277 }
269278 return WalkResult::advance ();
270279 });
271280 if (walk.wasSkipped ())
272281 return failure ();
282+ graph->dump ();
273283 return success ();
274284}
275285
0 commit comments