@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
356356
357357FailureOr<LowerUnPackOpResult> linalg::lowerUnPack (RewriterBase &rewriter,
358358 tensor::UnPackOp unPackOp) {
359- // 1. Filter out NYI cases.
360- if (!unPackOp.getOuterDimsPerm ().empty () &&
361- !isIdentityPermutation (unPackOp.getOuterDimsPerm ())) {
362- return rewriter.notifyMatchFailure (unPackOp,
363- " non-identity outer dims perm NYI" );
364- }
365-
366359 Location loc = unPackOp->getLoc ();
367360 OpBuilder::InsertionGuard g (rewriter);
368361 rewriter.setInsertionPoint (unPackOp);
@@ -391,45 +384,42 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
391384 return LowerUnPackOpResult{/* emptyOp=*/ nullptr , /* transposeOp=*/ nullptr ,
392385 /* reshapeOp=*/ nullptr , extractSliceOp};
393386 }
394- // 2. Compute the permutation vector to move the last `numPackedDims` into
395- // the `innerPosDims` of a shape of rank `packedRank`.
396- int64_t numPackedDims = unPackOp.getInnerDimsPos ().size ();
397- auto lastDims = llvm::to_vector (
398- llvm::seq<int64_t >(packedRank - numPackedDims, packedRank));
399- PackingMetadata packingMetadata =
400- computePackingMetadata (packedRank, unPackOp.getInnerDimsPos ());
401- SmallVector<int64_t > lastDimsToInsertPositionsPerm = computePermutationVector (
402- packedRank, lastDims, packingMetadata.insertPositions );
387+
388+ // 2. Compute the permutation vector to shuffle packed shape into the shape
389+ // before any outer or inner permutations have been applied.
390+ PackingMetadata packingMetadata;
391+ SmallVector<int64_t > packedToStripMinedShapePerm =
392+ tensor::getUnPackInverseSrcPerm (unPackOp, packingMetadata);
403393
404394 // 3. Compute the stripMinedShape: this is the packed shape without outer and
405395 // inner permutations.
406396 SmallVector<int64_t > stripMinedShape (packedTensorType.getShape ());
407- applyPermutationToVector (stripMinedShape, lastDimsToInsertPositionsPerm );
397+ applyPermutationToVector (stripMinedShape, packedToStripMinedShapePerm );
408398
409399 // 4. Transpose packedShape to stripMinedShape.
410400 RankedTensorType stripMinedTensorType =
411401 RankedTensorType::Builder (packedTensorType).setShape (stripMinedShape);
412402 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
413403 stripMinedTensorType, packingMetadata.reassociations );
414404
415- // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
405+ // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
416406 // permutation.
417407 SmallVector<OpFoldResult, 4 > dims =
418408 tensor::getMixedSizes (rewriter, loc, unPackOp.getSource ());
419- applyPermutationToVector (dims, lastDimsToInsertPositionsPerm );
409+ applyPermutationToVector (dims, packedToStripMinedShapePerm );
420410 auto emptyOp = rewriter.create <tensor::EmptyOp>(
421411 loc, dims, stripMinedTensorType.getElementType ());
422412 auto transposeOp = rewriter.create <linalg::TransposeOp>(
423- loc, unPackOp.getSource (), emptyOp, lastDimsToInsertPositionsPerm );
413+ loc, unPackOp.getSource (), emptyOp, packedToStripMinedShapePerm );
424414
425415 LLVM_DEBUG (
426416 DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
427417 DBGS () << " insertPositions: " );
428418 DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
429419 DBGS () << " packedShape: " );
430420 DBGSNL ();
431- llvm::interleaveComma (lastDimsToInsertPositionsPerm ,
432- DBGS () << " lastDimsToInsertPositionsPerm : " );
421+ llvm::interleaveComma (packedToStripMinedShapePerm ,
422+ DBGS () << " packedToStripMinedShapePerm : " );
433423 DBGSNL (); llvm::interleaveComma (
434424 packingMetadata.reassociations , DBGS () << " reassociations: " ,
435425 [&](ReassociationIndices ri) {
0 commit comments