@@ -5776,6 +5776,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
5776
5776
// ConstantMaskOp
5777
5777
// ===----------------------------------------------------------------------===//
5778
5778
5779
+ void ConstantMaskOp::build (OpBuilder &builder, OperationState &result,
5780
+ VectorType type, ConstantMaskKind kind) {
5781
+ assert (kind == ConstantMaskKind::AllTrue ||
5782
+ kind == ConstantMaskKind::AllFalse);
5783
+ build (builder, result, type,
5784
+ kind == ConstantMaskKind::AllTrue
5785
+ ? type.getShape ()
5786
+ : SmallVector<int64_t >(type.getRank (), 0 ));
5787
+ }
5788
+
5779
5789
LogicalResult ConstantMaskOp::verify () {
5780
5790
auto resultType = llvm::cast<VectorType>(getResult ().getType ());
5781
5791
// Check the corner case of 0-D vectors first.
@@ -5858,6 +5868,21 @@ LogicalResult CreateMaskOp::verify() {
5858
5868
return success ();
5859
5869
}
5860
5870
5871
+ std::optional<int64_t > vector::getConstantVscaleMultiplier (Value value) {
5872
+ if (value.getDefiningOp <vector::VectorScaleOp>())
5873
+ return 1 ;
5874
+ auto mul = value.getDefiningOp <arith::MulIOp>();
5875
+ if (!mul)
5876
+ return {};
5877
+ auto lhs = mul.getLhs ();
5878
+ auto rhs = mul.getRhs ();
5879
+ if (lhs.getDefiningOp <vector::VectorScaleOp>())
5880
+ return getConstantIntValue (rhs);
5881
+ if (rhs.getDefiningOp <vector::VectorScaleOp>())
5882
+ return getConstantIntValue (lhs);
5883
+ return {};
5884
+ }
5885
+
5861
5886
namespace {
5862
5887
5863
5888
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
@@ -5889,73 +5914,51 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5889
5914
5890
5915
/// Rewrites `createMaskOp` into a ConstantMaskOp when every operand is known
/// at compile time.
///
/// An operand is accepted when it is either:
///   * a constant integer (any value for a fixed-size dim; only values <= 0,
///     i.e. all-false, for a scalable dim), or
///   * a constant multiple of vscale that is at least the mask dim size
///     (i.e. all-true for that dim).
/// Any other operand makes the pattern fail. Accepted sizes are clamped to
/// `[0, dimSize]`, and a single zero-sized dim makes the whole mask all-false.
///
/// Returns success() after replacing the op, failure() if it cannot be folded.
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                              PatternRewriter &rewriter) const override {
  VectorType maskType = createMaskOp.getVectorType();
  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

  // Special case: Rank zero shape. The shape and scalable-flag arrays are
  // replaced with one-element stand-ins so the per-operand loop below can
  // index them.
  constexpr std::array<int64_t, 1> rankZeroShape{1};
  constexpr std::array<bool, 1> rankZeroScalableDims{false};
  if (maskType.getRank() == 0) {
    maskTypeDimSizes = rankZeroShape;
    maskTypeDimScalableFlags = rankZeroScalableDims;
  }

  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
  // collect the `constantDims` (for the ConstantMaskOp).
  SmallVector<int64_t, 4> constantDims;
  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      // Constant value.
      // If the mask dim is non-scalable this can be any value.
      // If the mask dim is scalable only an all-false size (zero, or a
      // negative value that is clamped to zero below) is supported: a
      // positive constant covers an unknown fraction of a scalable dim.
      if (maskTypeDimScalableFlags[i] && *intSize > 0)
        return failure();
      constantDims.push_back(*intSize);
    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
      // Constant vscale multiple (e.g. 4 x vscale).
      // Must be all-true to fold to a ConstantMask.
      if (*vscaleMultiplier < maskTypeDimSizes[i])
        return failure();
      constantDims.push_back(*vscaleMultiplier);
    } else {
      return failure();
    }
  }

  // Clamp values to constant_mask bounds.
  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
    value = std::clamp<int64_t>(value, 0, maskDimSize);

  // If one of dim sizes is zero, set all dims to zero.
  if (llvm::is_contained(constantDims, 0))
    constantDims.assign(constantDims.size(), 0);

  // Replace 'createMaskOp' with ConstantMaskOp.
  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                              constantDims);
  return success();
}
5961
5964
};
0 commit comments