@@ -1558,6 +1558,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
15581558
15591559 RankedTensorType unpackTensorType = unpackOp.getSourceType ();
15601560
1561+ // If the input vector sizes are not provided, then the vector sizes are
1562+ // determined by the result tensor shape. In case the vector sizes aren't
1563+ // provided, we update the inBounds attribute instead of masking.
1564+ bool useInBoundsInsteadOfMasking = true ;
1565+ if (inputVectorSizes.empty ()) {
1566+ ArrayRef<int64_t > resultTensorShape = unpackOp.getDestType ().getShape ();
1567+ inputVectorSizes = resultTensorShape.take_front (unpackOp.getSourceRank ());
1568+ useInBoundsInsteadOfMasking = false ;
1569+ }
1570+
15611571 ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
15621572 ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
15631573
@@ -1612,7 +1622,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16121622 // to shape of source, then a mask is necessary.
16131623 Value readResult = vector::createReadOrMaskedRead (
16141624 rewriter, loc, unpackOp.getSource (),
1615- ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue);
1625+ ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue,
1626+ doMasking);
16161627
16171628 PackingMetadata packMetadata;
16181629 SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -1753,8 +1764,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
17531764 LDBG (" Inner-tiles must be constant: " << unpackOp << " \n " );
17541765 return failure ();
17551766 }
1756- llvm::ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1757- if (!inputVectorSizes.empty () &&
1767+ ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1768+ bool satisfyEmptyCond = true ;
1769+ if (inputVectorSizes.empty ()) {
1770+ if (!unpackOp.getDestType ().hasStaticShape () ||
1771+ !unpackOp.getSourceType ().hasStaticShape ())
1772+ satisfyEmptyCond = false ;
1773+ }
1774+ if (!satisfyEmptyCond &&
17581775 failed (vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
17591776 return failure ();
17601777
0 commit comments