Skip to content

Commit 259d1b0

Browse files
committed
Share logic with CreateMaskFolder
The main thing shared here is the `getConstantVscaleMultiplier()` matcher. I could not think of a good way to share all of the logic, as it is somewhat different.
1 parent 692ce6e commit 259d1b0

File tree

3 files changed

+54
-79
lines changed

3 files changed

+54
-79
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
166166
SmallVector<arith::ConstantIndexOp>
167167
getAsConstantIndexOps(ArrayRef<Value> values);
168168

169+
/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
170+
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
171+
/// `std::nullopt`.
172+
std::optional<int64_t> getConstantVscaleMultiplier(Value value);
173+
169174
//===----------------------------------------------------------------------===//
170175
// Vector Masking Utilities
171176
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 49 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5841,6 +5841,21 @@ LogicalResult CreateMaskOp::verify() {
58415841
return success();
58425842
}
58435843

5844+
std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
5845+
if (value.getDefiningOp<vector::VectorScaleOp>())
5846+
return 1;
5847+
auto mul = value.getDefiningOp<arith::MulIOp>();
5848+
if (!mul)
5849+
return {};
5850+
auto lhs = mul.getLhs();
5851+
auto rhs = mul.getRhs();
5852+
if (lhs.getDefiningOp<vector::VectorScaleOp>())
5853+
return getConstantIntValue(rhs);
5854+
if (rhs.getDefiningOp<vector::VectorScaleOp>())
5855+
return getConstantIntValue(lhs);
5856+
return {};
5857+
}
5858+
58445859
namespace {
58455860

58465861
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
@@ -5872,73 +5887,46 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
58725887

58735888
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
58745889
PatternRewriter &rewriter) const override {
5875-
VectorType retTy = createMaskOp.getResult().getType();
5876-
bool isScalable = retTy.isScalable();
5877-
5878-
// Check every mask operand
5879-
for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
5880-
if (auto cst = getConstantIntValue(operand)) {
5881-
// Most basic case - this operand is a constant value. Note that for
5882-
// scalable dimensions, CreateMaskOp can be folded only if the
5883-
// corresponding operand is negative or zero.
5884-
if (retTy.getScalableDims()[opIdx] && *cst > 0)
5885-
return failure();
5886-
5887-
continue;
5888-
}
5889-
5890-
// Non-constant operands are not allowed for non-scalable vectors.
5891-
if (!isScalable)
5892-
return failure();
5893-
5894-
// For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
5895-
// true" mask, so can also be treated as constant.
5896-
auto mul = operand.getDefiningOp<arith::MulIOp>();
5897-
if (!mul)
5898-
return failure();
5899-
auto mulLHS = mul.getRhs();
5900-
auto mulRHS = mul.getLhs();
5901-
bool isOneOpVscale =
5902-
(isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
5903-
isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
5904-
5905-
auto isConstantValMatchingDim =
5906-
[=, dim = retTy.getShape()[opIdx]](Value operand) {
5907-
auto constantVal = getConstantIntValue(operand);
5908-
return (constantVal.has_value() && constantVal.value() == dim);
5909-
};
5910-
5911-
bool isOneOpConstantMatchingDim =
5912-
isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5913-
5914-
if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5915-
return failure();
5890+
VectorType maskType = createMaskOp.getVectorType();
5891+
ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
5892+
ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
5893+
5894+
// Special case: Rank zero shape.
5895+
constexpr std::array<int64_t, 1> rankZeroShape{1};
5896+
constexpr std::array<bool, 1> rankZeroScalableDims{false};
5897+
if (maskType.getRank() == 0) {
5898+
maskTypeDimSizes = rankZeroShape;
5899+
maskTypeDimScalableFlags = rankZeroScalableDims;
59165900
}
59175901

5918-
// Gather constant mask dimension sizes.
5919-
SmallVector<int64_t, 4> maskDimSizes;
5920-
maskDimSizes.reserve(createMaskOp->getNumOperands());
5921-
for (auto [operand, maxDimSize] : llvm::zip_equal(
5922-
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5923-
std::optional dimSize = getConstantIntValue(operand);
5924-
if (!dimSize) {
5925-
// Although not a constant, it is safe to assume that `operand` is
5926-
// "vscale * maxDimSize".
5927-
maskDimSizes.push_back(maxDimSize);
5928-
continue;
5929-
}
5930-
int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
5931-
// If one of dim sizes is zero, set all dims to zero.
5932-
if (dimSize <= 0) {
5933-
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5934-
break;
5902+
SmallVector<int64_t, 4> constantDims;
5903+
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
5904+
if (auto intSize = getConstantIntValue(dimSize)) {
5905+
// Non scalable dims can have any value. Scalable dims can only be zero.
5906+
if (intSize >= 0 && maskTypeDimScalableFlags[i])
5907+
return failure();
5908+
constantDims.push_back(*intSize);
5909+
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
5910+
// Scalable dims must be all-true.
5911+
if (vscaleMultiplier < maskTypeDimSizes[i])
5912+
return failure();
5913+
constantDims.push_back(*vscaleMultiplier);
5914+
} else {
5915+
return failure();
59355916
}
5936-
maskDimSizes.push_back(dimSizeVal);
59375917
}
59385918

5919+
// Clamp values to constant_mask bounds.
5920+
for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
5921+
value = std::clamp<int64_t>(value, 0, maskDimSize);
5922+
5923+
// If one of dim sizes is zero, set all dims to zero.
5924+
if (llvm::is_contained(constantDims, 0))
5925+
constantDims.assign(constantDims.size(), 0);
5926+
59395927
// Replace 'createMaskOp' with ConstantMaskOp.
5940-
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
5941-
maskDimSizes);
5928+
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
5929+
constantDims);
59425930
return success();
59435931
}
59445932
};

mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,6 @@ using namespace mlir;
1717
using namespace mlir::vector;
1818
namespace {
1919

20-
/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
21-
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
22-
/// `std::nullopt`.
23-
std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
24-
if (value.getDefiningOp<vector::VectorScaleOp>())
25-
return 1;
26-
auto mul = value.getDefiningOp<arith::MulIOp>();
27-
if (!mul)
28-
return {};
29-
auto lhs = mul.getLhs();
30-
auto rhs = mul.getRhs();
31-
if (lhs.getDefiningOp<vector::VectorScaleOp>())
32-
return getConstantIntValue(rhs);
33-
if (rhs.getDefiningOp<vector::VectorScaleOp>())
34-
return getConstantIntValue(lhs);
35-
return {};
36-
}
37-
3820
/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
3921
/// All-true masks can then be eliminated by simple folds.
4022
LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,

0 commit comments

Comments
 (0)