Skip to content

Commit 9b06e25

Browse files
authored
[mlir][vector] Add mask elimination transform (#99314)
This adds a new transform `eliminateVectorMasks()` which aims to remove scalable `vector.create_mask` ops that will be all-true at runtime. It first attempts to do this by simply pattern-matching the mask operands (similar to some canonicalizations). If that does not lead to an answer (is all-true? yes/no), value-bounds analysis is used to find the lower bound of the unknown operands. If the lower bound is >= the corresponding mask vector type dim size, then that dimension of the mask is all-true. Note that the pattern matching prevents expensive value-bounds analysis in cases where the mask won't be all-true. For example:
```mlir
%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>
```
From looking at `%c2` we can tell this is not going to be an all-true mask, so we don't need to run the value-bounds analysis for `%dynamicValue` (and can exit the transform early). Note: Eliminating create_masks here means replacing them with all-true constants (which will then lead to the masks folding away).
1 parent badfb4b commit 9b06e25

File tree

8 files changed

+401
-61
lines changed

8 files changed

+401
-61
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ namespace detail {
5656
struct BitmaskEnumStorage;
5757
} // namespace detail
5858

59+
/// Predefined constant_mask kinds.
enum class ConstantMaskKind {
  /// Mask with no lanes enabled.
  AllFalse = 0,
  /// Mask with every lane enabled.
  AllTrue
};
61+
5962
/// Default callback to build a region with a 'vector.yield' terminator with no
6063
/// arguments.
6164
void buildTerminatedBody(OpBuilder &builder, Location loc);
@@ -168,6 +171,11 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
168171
SmallVector<arith::ConstantIndexOp>
169172
getAsConstantIndexOps(ArrayRef<Value> values);
170173

174+
/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
175+
/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
176+
/// `std::nullopt`.
177+
std::optional<int64_t> getConstantVscaleMultiplier(Value value);
178+
171179
//===----------------------------------------------------------------------===//
172180
// Vector Masking Utilities
173181
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2364,6 +2364,11 @@ def Vector_ConstantMaskOp :
23642364
```
23652365
}];
23662366

2367+
let builders = [
2368+
// Build an all-true or all-false constant mask of the given vector type.
2369+
OpBuilder<(ins "VectorType":$type, "ConstantMaskKind":$kind)>
2370+
];
2371+
23672372
let extraClassDeclaration = [{
23682373
/// Return the result type of this op.
23692374
VectorType getVectorType() {

mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1313
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
1415

1516
namespace mlir {
1617
class MLIRContext;
@@ -115,6 +116,22 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
115116
MaskingOpInterface maskingOp,
116117
RewriterBase &rewriter);
117118

119+
/// Structure to hold the range of `vector.vscale`.
struct VscaleRange {
  /// Smallest value of `vector.vscale` in the range.
  unsigned vscaleMin;
  /// Largest value of `vector.vscale` in the range.
  unsigned vscaleMax;
};
124+
125+
/// Attempts to eliminate redundant vector masks by replacing them with all-true
126+
/// constants at the top of the function (which results in the masks folding
127+
/// away). Note: Currently, this only runs for vector.create_mask ops and
128+
/// requires `vscaleRange`. If `vscaleRange` is not provided this transform does
129+
/// nothing. This is because these redundant masks are much more likely for
130+
/// scalable code which requires memref/tensor dynamic sizes, whereas fixed-size
131+
/// code has static sizes, so simpler folds remove the masks.
132+
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
133+
std::optional<VscaleRange> vscaleRange = {});
134+
118135
} // namespace vector
119136
} // namespace mlir
120137

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5776,6 +5776,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
57765776
// ConstantMaskOp
57775777
//===----------------------------------------------------------------------===//
57785778

5779+
/// Builds a constant mask of type `type` that is either all-true or all-false,
/// as selected by `kind`.
///
/// - `AllTrue`:  every mask dim size equals the corresponding dim of `type`.
/// - `AllFalse`: every mask dim size is zero.
void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
                           VectorType type, ConstantMaskKind kind) {
  assert(kind == ConstantMaskKind::AllTrue ||
         kind == ConstantMaskKind::AllFalse);
  const bool allTrue = kind == ConstantMaskKind::AllTrue;

  SmallVector<int64_t> maskDimSizes;
  if (type.getRank() == 0) {
    // A rank-0 type has an empty shape, which would otherwise yield an empty
    // list of mask dim sizes. CreateMaskFolder (in this file) emits exactly
    // one dim size (0 or 1) for rank-0 masks, so do the same here.
    // NOTE(review): confirm against the 0-D case in ConstantMaskOp::verify().
    maskDimSizes.push_back(allTrue ? 1 : 0);
  } else if (allTrue) {
    ArrayRef<int64_t> shape = type.getShape();
    maskDimSizes.assign(shape.begin(), shape.end());
  } else {
    maskDimSizes.assign(type.getRank(), 0);
  }

  build(builder, result, type, ArrayRef<int64_t>(maskDimSizes));
}
5788+
57795789
LogicalResult ConstantMaskOp::verify() {
57805790
auto resultType = llvm::cast<VectorType>(getResult().getType());
57815791
// Check the corner case of 0-D vectors first.
@@ -5858,6 +5868,21 @@ LogicalResult CreateMaskOp::verify() {
58585868
return success();
58595869
}
58605870

5871+
/// Returns the constant `C` when `value` is `C x vector.vscale`:
///  - `value` is produced directly by `vector.vscale`  -> 1;
///  - `value` is an `arith.muli` with `vector.vscale` as one operand -> the
///    constant integer value of the other operand, if it has one.
/// Returns `std::nullopt` in every other case.
std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  // A bare `vector.vscale` is `1 x vscale`.
  if (value.getDefiningOp<vector::VectorScaleOp>())
    return 1;

  auto product = value.getDefiningOp<arith::MulIOp>();
  if (!product)
    return std::nullopt;

  Value first = product.getLhs();
  Value second = product.getRhs();

  // The left operand takes precedence: if it is vscale, the answer is decided
  // solely by whether the right operand is a constant.
  if (first.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(second);
  if (second.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(first);

  return std::nullopt;
}
5885+
58615886
namespace {
58625887

58635888
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
@@ -5889,73 +5914,51 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
58895914

58905915
/// Rewrites `createMaskOp` as a ConstantMaskOp when every operand is either a
/// constant integer or a constant multiple of `vector.vscale`.
///
/// Returns failure (leaving the op untouched) when:
///  - an operand is neither of the above;
///  - a scalable mask dim has a positive constant operand (only the all-false
///    case is supported for constants in scalable dims);
///  - a vscale-multiple operand is smaller than the mask dim size (i.e. the
///    dim is not all-true).
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                              PatternRewriter &rewriter) const override {
  VectorType maskType = createMaskOp.getVectorType();
  ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
  ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

  // Special case: Rank zero shape. The type has no dims, so treat it as a
  // single fixed-size dim of size 1 to keep the per-operand indexing below in
  // bounds.
  constexpr std::array<int64_t, 1> rankZeroShape{1};
  constexpr std::array<bool, 1> rankZeroScalableDims{false};
  if (maskType.getRank() == 0) {
    maskTypeDimSizes = rankZeroShape;
    maskTypeDimScalableFlags = rankZeroScalableDims;
  }

  // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
  // collect the `constantDims` (for the ConstantMaskOp).
  SmallVector<int64_t, 4> constantDims;
  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      // Constant value.
      // If the mask dim is non-scalable this can be any value.
      // If the mask dim is scalable only zero or negative (all-false) is
      // supported: a positive constant covers an unknown fraction of a
      // scalable dim.
      if (maskTypeDimScalableFlags[i] && *intSize > 0)
        return failure();
      constantDims.push_back(*intSize);
    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
      // Constant vscale multiple (e.g. 4 x vscale).
      // Must be all-true to fold to a ConstantMask.
      if (*vscaleMultiplier < maskTypeDimSizes[i])
        return failure();
      constantDims.push_back(*vscaleMultiplier);
    } else {
      return failure();
    }
  }

  // Clamp values to constant_mask bounds. `value` is a reference into
  // `constantDims` (zip yields references), so this updates it in place.
  for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
    value = std::clamp<int64_t>(value, 0, maskDimSize);

  // If one of dim sizes is zero, set all dims to zero.
  if (llvm::is_contained(constantDims, 0))
    constantDims.assign(constantDims.size(), 0);

  // Replace 'createMaskOp' with ConstantMaskOp.
  rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                              constantDims);
  return success();
}
59615964
};

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
2222
VectorTransferSplitRewritePatterns.cpp
2323
VectorTransforms.cpp
2424
VectorUnroll.cpp
25+
VectorMaskElimination.cpp
2526

2627
ADDITIONAL_HEADER_DIRS
2728
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/IR/Arith.h"
10+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
11+
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
12+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
13+
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::vector;
18+
namespace {
19+
20+
/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
21+
/// All-true masks can then be eliminated by simple folds.
22+
LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
23+
vector::CreateMaskOp createMaskOp,
24+
VscaleRange vscaleRange) {
25+
auto maskType = createMaskOp.getVectorType();
26+
auto maskTypeDimScalableFlags = maskType.getScalableDims();
27+
auto maskTypeDimSizes = maskType.getShape();
28+
29+
struct UnknownMaskDim {
30+
size_t position;
31+
Value dimSize;
32+
};
33+
34+
// Loop over the CreateMaskOp operands and collect unknown dims (i.e. dims
35+
// that are not obviously constant). If any constant dimension is not all-true
36+
// bail out early (as this transform only trying to resolve all-true masks).
37+
// This avoids doing value-bounds anaylis in cases like:
38+
// `%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>`
39+
// ...where it is known the mask is not all-true by looking at `%c2`.
40+
SmallVector<UnknownMaskDim> unknownDims;
41+
for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
42+
if (auto intSize = getConstantIntValue(dimSize)) {
43+
// Mask not all-true for this dim.
44+
if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
45+
return failure();
46+
} else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
47+
// Mask not all-true for this dim.
48+
if (vscaleMultiplier < maskTypeDimSizes[i])
49+
return failure();
50+
} else {
51+
// Unknown (without further analysis).
52+
unknownDims.push_back(UnknownMaskDim{i, dimSize});
53+
}
54+
}
55+
56+
for (auto [i, dimSize] : unknownDims) {
57+
// Compute the lower bound for the unknown dimension (i.e. the smallest
58+
// value it could be).
59+
FailureOr<ConstantOrScalableBound> dimLowerBound =
60+
vector::ScalableValueBoundsConstraintSet::computeScalableBound(
61+
dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
62+
presburger::BoundType::LB);
63+
if (failed(dimLowerBound))
64+
return failure();
65+
auto dimLowerBoundSize = dimLowerBound->getSize();
66+
if (failed(dimLowerBoundSize))
67+
return failure();
68+
if (dimLowerBoundSize->scalable) {
69+
// 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
70+
// this dim is not all-true.
71+
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
72+
return failure();
73+
} else {
74+
// 2. The lower bound, LB, is a constant.
75+
// - If the mask dim size is scalable then this dim is not all-true.
76+
if (maskTypeDimScalableFlags[i])
77+
return failure();
78+
// - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
79+
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
80+
return failure();
81+
}
82+
}
83+
84+
// Replace createMaskOp with an all-true constant. This should result in the
85+
// mask being removed in most cases (as xfer ops + vector.mask have folds to
86+
// remove all-true masks).
87+
auto allTrue = rewriter.create<vector::ConstantMaskOp>(
88+
createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
89+
rewriter.replaceAllUsesWith(createMaskOp, allTrue);
90+
return success();
91+
}
92+
93+
} // namespace
94+
95+
namespace mlir::vector {
96+
97+
/// Tries to replace every `vector.create_mask` in `function` that can be
/// proven all-true with an all-true constant mask created at the start of the
/// function's first block. Does nothing when `vscaleRange` is not provided.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
                          std::optional<VscaleRange> vscaleRange) {
  // TODO: Support fixed-size case. This is less likely to be useful as for
  // fixed-size code dimensions are all static so masks tend to fold away.
  if (!vscaleRange)
    return;

  // Build worklist so we can safely insert new ops in
  // `resolveAllTrueCreateMaskOp()`.
  SmallVector<vector::CreateMaskOp> worklist;
  function.walk([&](vector::CreateMaskOp createMaskOp) {
    worklist.push_back(createMaskOp);
  });

  // Nothing to resolve. Returning here also avoids calling `function.front()`
  // on a function that has no blocks (e.g. one without a body).
  if (worklist.empty())
    return;

  // Restore the caller's insertion point on exit.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(&function.front());
  for (vector::CreateMaskOp mask : worklist)
    (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
}
117+
118+
} // namespace mlir::vector

0 commit comments

Comments
 (0)