@@ -1414,27 +1414,39 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
14141414// / create an empty destination tensor and create a TransferWriteOp from the
14151415// / input to the empty tensor. If the destination shape is not the same as the
14161416// / inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
1417- // / mask for the write.
1417+ // / mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
1418+ // / inBounds attribute of the transfer write op instead of masking.
14181419static Operation *createWriteOrMaskedWrite (OpBuilder &builder, Location loc,
14191420 Value input,
14201421 SmallVector<OpFoldResult> destSizes,
1421- ArrayRef<int64_t > inputVectorSizes) {
1422+ ArrayRef<int64_t > inputVectorSizes,
1423+ bool useInBoundsInsteadOfMasking) {
1424+
14221425 auto inputType = cast<VectorType>(input.getType ());
14231426 Value dest = builder.create <tensor::EmptyOp>(loc, destSizes,
14241427 inputType.getElementType ());
14251428 int64_t rank = cast<ShapedType>(dest.getType ()).getRank ();
14261429 auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1430+ auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1431+ SmallVector<bool > inBoundsVal (rank, true );
1432+ if (useInBoundsInsteadOfMasking) {
1433+ // Update the inBounds attribute.
1434+ for (unsigned i = 0 ; i < rank; i++)
1435+ inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1436+ !ShapedType::isDynamic (destShape[i]);
1437+ }
14271438 Operation *write = builder.create <vector::TransferWriteOp>(
14281439 loc,
14291440 /* vector=*/ input,
14301441 /* source=*/ dest,
14311442 /* indices=*/ SmallVector<Value>(rank, zero),
1432- /* inBounds=*/ SmallVector<bool >(rank, true ));
1433- auto destShape = cast<ShapedType>(dest.getType ()).getShape ();
1443+ /* inBounds=*/ inBoundsVal);
14341444 assert (llvm::none_of (
14351445 destShape.drop_front (inputVectorSizes.size ()),
14361446 [](int64_t size) { return size == ShapedType::kDynamic ; }) &&
14371447 " Only dims aligned with inputVectorSizes may be dynamic" );
1448+ if (useInBoundsInsteadOfMasking)
1449+ return write;
14381450 bool needMaskForWrite = !llvm::equal (
14391451 inputVectorSizes, destShape.take_front (inputVectorSizes.size ()));
14401452 if (needMaskForWrite) {
@@ -1535,9 +1547,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15351547 loc, shapeCastOp.getResult (), destPermutation);
15361548
15371549 // Create TransferWriteOp.
1538- Operation *write =
1539- createWriteOrMaskedWrite ( rewriter, loc, transposeOp.getResult (),
1540- reifiedReturnShapes[ 0 ], inputVectorSizes );
1550+ Operation *write = createWriteOrMaskedWrite (
1551+ rewriter, loc, transposeOp.getResult (), reifiedReturnShapes[ 0 ] ,
1552+ inputVectorSizes, /* useInBoundsInsteadOfMasking= */ false );
15411553 newResults.push_back (write->getResult (0 ));
15421554 return success ();
15431555}
@@ -1547,7 +1559,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15471559// / vector::TransposeOp - Transpose the Source tensor
15481560// / ShapeCastOp - Reshape the data based on the target.
15491561// / vector::TransferWriteOp. - Write the result vector back to the destination
1550- // / tensor
1562+ // / tensor.
1563+ // / If the vector sizes are not provided:
1564+ // / * the vector sizes are determined by the input operand and attributes,
1565+ // / * the inBounds attribute is updated instead of masking.
15511566static LogicalResult
15521567vectorizeAsTensorUnpackOp (RewriterBase &rewriter, tensor::UnPackOp unpackOp,
15531568 ArrayRef<int64_t > inputVectorSizes,
@@ -1560,40 +1575,61 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
15601575
15611576 ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
15621577 ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
1563-
1564- SmallVector<int64_t > readMaskShape (inputVectorSizes.begin (),
1565- inputVectorSizes.end ());
1566- ArrayRef<int64_t > outerDimsPerm = unpackOp.getOuterDimsPerm ();
15671578 ArrayRef<int64_t > sourceShape = unpackTensorType.getShape ();
1579+ bool useInBoundsInsteadOfMasking = false ;
1580+ ArrayRef<int64_t > outerDimsPerm = unpackOp.getOuterDimsPerm ();
1581+
1582+ auto destSize = unpackOp.getDestRank ();
15681583
1569- // ReadMask is the size of tensor used to read and apply mask. It is
1584+ // vectorSizes is the shape of the vector that will be used to do final
1585+ // write on the destination tensor. It is set like this: Let's say the
1586+ // sourceShape has rank 'M' and the vectorSize (VS) array has size 'N', N <= M.
1587+ // Thus:
1588+ // - vectorSizes = sourceShape.take_front(N)
1589+ // - if outer_dims_perms is present: do that permutation on vectorSizes.
1590+ // - Multiply all the locations pointed by innerDimPos by the innerTileSize
1591+ // attribute value.
1592+ SmallVector<int64_t > vectorSizes (inputVectorSizes);
1593+ if (vectorSizes.empty ()) {
1594+ llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1595+ if (!outerDimsPerm.empty ())
1596+ applyPermutationToVector (vectorSizes, outerDimsPerm);
1597+ for (auto [i, pos] : llvm::enumerate (innerDimPos))
1598+ vectorSizes[pos] *= innerTiles[i];
1599+
1600+ useInBoundsInsteadOfMasking = true ;
1601+ }
1602+
1603+ SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1604+
1605+ // readVectorSizes is the size of tensor used to read and apply mask. It is
15701606 // set like this: Let's say the vectorSize (VS) array is size 'N' and
15711607 // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
15721608 // size M-N
15731609 // Thus:
1574- // - initially: ReadMaskShape = vectorInputSizes
1610+ // - initially: readVectorSizes = vectorSizes
15751611 // - Divide all the readVectorSizes locations pointed by innerDimPos
15761612 // by the innerTileSize attribute value.
1577- // - if outer_dims_perms is present: do that permutation on readMaskShape .
1613+ // - if outer_dims_perms is present: do that permutation on readVectorSizes .
15781614 // - Append the remaining shape from SS
15791615 // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>
15801616 // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
15811617 // 128] and outer_dims_perm is [1, 0] then read shape is:
1582- // ReadMaskShape (initial): [512, 128]
1618+ // readVectorSizes (initial): [512, 128]
15831619 // Final Value(after innerDim Adjustment): [512/32, 128/16]
15841620 // = [16, 8]
15851621 // After applying outer_dims_perm: [8, 16]
15861622 // After appending the rest of the sourceShape: [8, 16, 32, 16]
15871623
15881624 for (auto [index, size] : enumerate(innerTiles)) {
1589- readMaskShape [innerDimPos[index]] =
1590- llvm::divideCeil (readMaskShape [innerDimPos[index]], size);
1625+ readVectorSizes [innerDimPos[index]] =
1626+ llvm::divideCeil (readVectorSizes [innerDimPos[index]], size);
15911627 }
15921628 if (!outerDimsPerm.empty ()) {
1593- applyPermutationToVector (readMaskShape , outerDimsPerm);
1629+ applyPermutationToVector (readVectorSizes , outerDimsPerm);
15941630 }
1595- readMaskShape .append (sourceShape.begin () + inputVectorSizes .size (),
1596- sourceShape.end ());
1631+ readVectorSizes .append (sourceShape.begin () + vectorSizes .size (),
1632+ sourceShape.end ());
15971633
15981634 ReifiedRankedShapedTypeDims reifiedRetShapes;
15991635 LogicalResult status =
@@ -1611,8 +1647,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16111647 // Read result, mask if necessary. If transferReadOp shape is not equal
16121648 // to shape of source, then a mask is necessary.
16131649 Value readResult = vector::createReadOrMaskedRead (
1614- rewriter, loc, unpackOp.getSource (),
1615- ArrayRef<int64_t >(readMaskShape.begin (), readMaskShape.end ()), padValue,
1650+ rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
16161651 /* useInBoundsInsteadOfMasking=*/ false );
16171652
16181653 PackingMetadata packMetadata;
@@ -1636,15 +1671,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16361671 vector::ShapeCastOp shapeCastOp = rewriter.create <vector::ShapeCastOp>(
16371672 loc, vecCollapsedType, transposeOp->getResult (0 ));
16381673
1639- // WriteMaskShape had to match the shapecast shape for dynamic sizes,
1674+ // writeVectorSizes has to match the shapecast shape for dynamic sizes,
16401675 // otherwise the validator complains that the mask size is invalid.
1641- SmallVector<int64_t > writeMaskShape (
1676+ SmallVector<int64_t > writeVectorSizes (
16421677 unpackOp.getDestType ().hasStaticShape ()
1643- ? inputVectorSizes
1678+ ? vectorSizes
16441679 : shapeCastOp.getResultVectorType ().getShape ());
1645- Operation *write =
1646- createWriteOrMaskedWrite ( rewriter, loc, shapeCastOp.getResult (),
1647- reifiedRetShapes[ 0 ], writeMaskShape );
1680+ Operation *write = createWriteOrMaskedWrite (
1681+ rewriter, loc, shapeCastOp.getResult (), reifiedRetShapes[ 0 ] ,
1682+ writeVectorSizes, useInBoundsInsteadOfMasking );
16481683 newResults.push_back (write->getResult (0 ));
16491684 return success ();
16501685}
@@ -1673,7 +1708,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
16731708 rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
16741709 /* useInBoundsInsteadOfMasking=*/ false );
16751710 Operation *write = createWriteOrMaskedWrite (
1676- rewriter, loc, maskedRead, reifiedReturnShapes[0 ], inputVectorSizes);
1711+ rewriter, loc, maskedRead, reifiedReturnShapes[0 ], inputVectorSizes,
1712+ /* useInBoundsInsteadOfMasking=*/ false );
16771713 newResults.push_back (write->getResult (0 ));
16781714 return success ();
16791715}
@@ -1755,8 +1791,11 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
17551791 LDBG (" Inner-tiles must be constant: " << unpackOp << " \n " );
17561792 return failure ();
17571793 }
1758- llvm::ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1759- if (!inputVectorSizes.empty () &&
1794+ ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
1795+ bool satisfyEmptyCond = inputVectorSizes.empty () &&
1796+ unpackOp.getDestType ().hasStaticShape () &&
1797+ unpackOp.getSourceType ().hasStaticShape ();
1798+ if (!satisfyEmptyCond &&
17601799 failed (vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
17611800 return failure ();
17621801
0 commit comments