@@ -5841,6 +5841,21 @@ LogicalResult CreateMaskOp::verify() {
5841
5841
return success ();
5842
5842
}
5843
5843
5844
+ std::optional<int64_t > vector::getConstantVscaleMultiplier (Value value) {
5845
+ if (value.getDefiningOp <vector::VectorScaleOp>())
5846
+ return 1 ;
5847
+ auto mul = value.getDefiningOp <arith::MulIOp>();
5848
+ if (!mul)
5849
+ return {};
5850
+ auto lhs = mul.getLhs ();
5851
+ auto rhs = mul.getRhs ();
5852
+ if (lhs.getDefiningOp <vector::VectorScaleOp>())
5853
+ return getConstantIntValue (rhs);
5854
+ if (rhs.getDefiningOp <vector::VectorScaleOp>())
5855
+ return getConstantIntValue (lhs);
5856
+ return {};
5857
+ }
5858
+
5844
5859
namespace {
5845
5860
5846
5861
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
@@ -5872,73 +5887,46 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5872
5887
5873
5888
/// Rewrites `createMaskOp` into a ConstantMaskOp when every operand is
/// statically known, i.e. each operand is either a constant integer or a
/// constant multiple of vscale (see getConstantVscaleMultiplier).
///
/// Returns failure() (leaving the op untouched) when:
///  - an operand is neither a constant nor a constant vscale multiple;
///  - a scalable dim is given a strictly positive constant size;
///  - a constant vscale multiple is smaller than the corresponding dim size.
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                              PatternRewriter &rewriter) const override {
  VectorType maskType = createMaskOp.getVectorType();
  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

  // Special case: Rank zero shape. Substitute a single fixed dim of size 1 so
  // the per-operand indexing below stays in bounds.
  constexpr std::array<int64_t, 1> rankZeroShape{1};
  constexpr std::array<bool, 1> rankZeroScalableDims{false};
  if (maskType.getRank() == 0) {
    maskTypeDimSizes = rankZeroShape;
    maskTypeDimScalableFlags = rankZeroScalableDims;
  }

  SmallVector<int64_t, 4> constantDims;
  constantDims.reserve(createMaskOp->getNumOperands());
  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      // Non-scalable dims can have any constant value. For scalable dims only
      // an all-false result (size <= 0, clamped to 0 below) can be folded, so
      // reject strictly positive sizes. Note: this must be `> 0`, not `>= 0`,
      // otherwise a literal zero on a scalable dim would never fold.
      if (maskTypeDimScalableFlags[i] && *intSize > 0)
        return failure();
      constantDims.push_back(*intSize);
    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
      // A vscale multiple is only foldable when it covers the whole dim
      // (all-true); anything smaller is rejected.
      if (*vscaleMultiplier < maskTypeDimSizes[i])
        return failure();
      constantDims.push_back(*vscaleMultiplier);
    } else {
      return failure();
    }
  }

  // Clamp values to constant_mask bounds.
  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
    value = std::clamp<int64_t>(value, 0, maskDimSize);

  // If one of dim sizes is zero, set all dims to zero.
  if (llvm::is_contained(constantDims, 0))
    constantDims.assign(constantDims.size(), 0);

  // Replace 'createMaskOp' with ConstantMaskOp.
  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                              constantDims);
  return success();
}
5944
5932
};
0 commit comments