@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
356
356
357
357
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack (RewriterBase &rewriter,
358
358
tensor::UnPackOp unPackOp) {
359
- // 1. Filter out NYI cases.
360
- if (!unPackOp.getOuterDimsPerm ().empty () &&
361
- !isIdentityPermutation (unPackOp.getOuterDimsPerm ())) {
362
- return rewriter.notifyMatchFailure (unPackOp,
363
- " non-identity outer dims perm NYI" );
364
- }
365
-
366
359
Location loc = unPackOp->getLoc ();
367
360
OpBuilder::InsertionGuard g (rewriter);
368
361
rewriter.setInsertionPoint (unPackOp);
@@ -391,45 +384,42 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
391
384
return LowerUnPackOpResult{/* emptyOp=*/ nullptr , /* transposeOp=*/ nullptr ,
392
385
/* reshapeOp=*/ nullptr , extractSliceOp};
393
386
}
394
- // 2. Compute the permutation vector to move the last `numPackedDims` into
395
- // the `innerPosDims` of a shape of rank `packedRank`.
396
- int64_t numPackedDims = unPackOp.getInnerDimsPos ().size ();
397
- auto lastDims = llvm::to_vector (
398
- llvm::seq<int64_t >(packedRank - numPackedDims, packedRank));
399
- PackingMetadata packingMetadata =
400
- computePackingMetadata (packedRank, unPackOp.getInnerDimsPos ());
401
- SmallVector<int64_t > lastDimsToInsertPositionsPerm = computePermutationVector (
402
- packedRank, lastDims, packingMetadata.insertPositions );
387
+
388
+ // 2. Compute the permutation vector to shuffle packed shape into the shape
389
+ // before any outer or inner permutations have been applied.
390
+ PackingMetadata packingMetadata;
391
+ SmallVector<int64_t > packedToStripMinedShapePerm =
392
+ tensor::getUnPackInverseSrcPerm (unPackOp, packingMetadata);
403
393
404
394
// 3. Compute the stripMinedShape: this is the packed shape without outer and
405
395
// inner permutations.
406
396
SmallVector<int64_t > stripMinedShape (packedTensorType.getShape ());
407
- applyPermutationToVector (stripMinedShape, lastDimsToInsertPositionsPerm );
397
+ applyPermutationToVector (stripMinedShape, packedToStripMinedShapePerm );
408
398
409
399
// 4. Transpose packedShape to stripMinedShape.
410
400
RankedTensorType stripMinedTensorType =
411
401
RankedTensorType::Builder (packedTensorType).setShape (stripMinedShape);
412
402
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
413
403
stripMinedTensorType, packingMetadata.reassociations );
414
404
415
- // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
405
+ // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
416
406
// permutation.
417
407
SmallVector<OpFoldResult, 4 > dims =
418
408
tensor::getMixedSizes (rewriter, loc, unPackOp.getSource ());
419
- applyPermutationToVector (dims, lastDimsToInsertPositionsPerm );
409
+ applyPermutationToVector (dims, packedToStripMinedShapePerm );
420
410
auto emptyOp = rewriter.create <tensor::EmptyOp>(
421
411
loc, dims, stripMinedTensorType.getElementType ());
422
412
auto transposeOp = rewriter.create <linalg::TransposeOp>(
423
- loc, unPackOp.getSource (), emptyOp, lastDimsToInsertPositionsPerm );
413
+ loc, unPackOp.getSource (), emptyOp, packedToStripMinedShapePerm );
424
414
425
415
LLVM_DEBUG (
426
416
DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
427
417
DBGS () << " insertPositions: " );
428
418
DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
429
419
DBGS () << " packedShape: " );
430
420
DBGSNL ();
431
- llvm::interleaveComma (lastDimsToInsertPositionsPerm ,
432
- DBGS () << " lastDimsToInsertPositionsPerm : " );
421
+ llvm::interleaveComma (packedToStripMinedShapePerm ,
422
+ DBGS () << " packedToStripMinedShapePerm : " );
433
423
DBGSNL (); llvm::interleaveComma (
434
424
packingMetadata.reassociations , DBGS () << " reassociations: " ,
435
425
[&](ReassociationIndices ri) {
0 commit comments