Skip to content

Commit c716558

Browse files
[mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (#145068)
This commit makes the following changes: - Expose `map` and `mapOperands` in `ValueBoundsConstraintSet::Variable`, so that the class can be used by subclasses of `ValueBoundsConstraintSet`. Otherwise subclasses cannot access those members. - Add `ValueBoundsConstraintSet::strongCompare`. This method is similar to `ValueBoundsConstraintSet::compare` except that it returns false when the inverse comparison holds, and `llvm::failure()` if neither the relation nor its inverse relation could be proven. - Add `simplifyAffineMinOp`, `simplifyAffineMaxOp`, and `simplifyAffineMinMaxOps` to simplify those operations using `ValueBoundsConstraintSet`. - Adds the `SimplifyMinMaxAffineOpsOp` transform op that uses `simplifyAffineMinMaxOps`. - Add the `test.value_with_bounds` op to test unknown values with a min max range using `ValueBoundsOpInterface`. - Adds tests verifying the transform. Example: ```mlir func.func @overlapping_constraints() -> (index, index) { %0 = test.value_with_bounds {min = 0 : index, max = 192 : index} %1 = test.value_with_bounds {min = 128 : index, max = 384 : index} %2 = test.value_with_bounds {min = 256 : index, max = 512 : index} %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2] %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2] return %r0, %r1 : index, index } // Result of applying `simplifyAffineMinMaxOps` to `func.func` #map1 = affine_map<()[s0, s1] -> (s1, s0)> func.func @overlapping_constraints() -> (index, index) { %0 = test.value_with_bounds {max = 192 : index, min = 0 : index} %1 = test.value_with_bounds {max = 384 : index, min = 128 : index} %2 = test.value_with_bounds {max = 512 : index, min = 256 : index} %3 = affine.min #map1()[%0, %1] %4 = affine.max #map1()[%1, %2] return %3, %4 : index, index } ``` --------- Co-authored-by: Nicolas Vasilache <[email protected]>
1 parent 89c6144 commit c716558

File tree

11 files changed

+493
-9
lines changed

11 files changed

+493
-9
lines changed

mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,35 @@ def SimplifyBoundedAffineOpsOp
6363
}];
6464
}
6565

66+
def SimplifyMinMaxAffineOpsOp :
67+
Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
68+
DeclareOpInterfaceMethods<TransformOpInterface>,
69+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
70+
]> {
71+
let description = [{
72+
Simplify the targeted `affine.min` / `affine.max` ops using the
73+
`mlir::affine::simplifyAffineMinMaxOps` transform.
74+
75+
Example:
76+
```
77+
%0 = transform.structured.match ops{["affine.max"]} in %arg1
78+
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
79+
```
80+
81+
#### Return modes
82+
83+
This transform consumes the target handle and does not produce any results.
84+
This transforms definitely fails if any of the targeted operations is not an
85+
`affine.min` or `affine.max` operation, or if the canonicalization patterns
86+
failed to converge.
87+
This transform silently fails if none of the operations were simplified.
88+
Otherwise, it succeeds.
89+
}];
90+
let arguments = (ins TransformHandleTypeInterface:$target);
91+
let results = (outs);
92+
let assemblyFormat = [{
93+
$target attr-dict `:` type($target)
94+
}];
95+
}
96+
6697
#endif // Affine_TRANSFORM_OPS

mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ namespace affine {
3434
class AffineApplyOp;
3535
class AffineDelinearizeIndexOp;
3636
class AffineLinearizeIndexOp;
37+
class AffineMaxOp;
38+
class AffineMinOp;
3739

3840
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
3941
/// operations.
@@ -127,6 +129,37 @@ OpFoldResult materializeComputedBound(
127129
OpBuilder &b, Location loc, AffineMap boundMap,
128130
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
129131

132+
/// This transform tries to simplify the affine min operation `op`, by finding a
133+
/// common lower bound for a set of expressions in the affine map results. It
134+
/// returns whether the transform updated `op`'s affine map.
135+
///
136+
/// In concrete terms, given an operation like:
137+
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
138+
/// If `d0 < 128` and `128 < s1 < s0`, the transform will update `op` to:
139+
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
140+
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
141+
142+
/// This transform tries to simplify the affine max operation `op`, by finding a
143+
/// common upper bound for a set of expressions in the affine map results. It
144+
/// returns whether the transform updated `op`'s affine map.
145+
///
146+
/// In concrete terms, given an operation like:
147+
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
148+
/// If `d0 > 128` and `s0 > s1 > 128`, the transform will update `op` to:
149+
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s0)>(%i)[%s0, %s1]`.
150+
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
151+
152+
/// This transform applies `simplifyAffineMinOp` and `simplifyAffineMaxOp` to
153+
/// all the `affine.min` or `affine.max` operations in `ops`. After
154+
/// simplification, it invokes the `affine.min/max` canonicalization patterns on
155+
/// `ops`.
156+
///
157+
/// This transform returns failure if the greedy pattern rewriter failed to
158+
/// converge during canonicalization, otherwise it returns success. If provided,
159+
/// `modified` is set to `true` if the IR was modified in any way.
160+
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter,
161+
ArrayRef<Operation *> ops,
162+
bool *modified = nullptr);
130163
} // namespace affine
131164
} // namespace mlir
132165

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,17 @@ class ValueBoundsConstraintSet
135135

136136
/// Construct a variable for a map and its operands.
137137
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
138-
Variable(AffineMap map, ArrayRef<Value> mapOperands);
138+
Variable(AffineMap map, ValueRange mapOperands);
139139

140140
MLIRContext *getContext() const { return map.getContext(); }
141141

142+
/// Returns the affine map.
143+
AffineMap getMap() const { return map; }
144+
145+
/// Returns the map operands.
146+
ValueDimList &getOperands() { return mapOperands; }
147+
const ValueDimList &getOperands() const { return mapOperands; }
148+
142149
private:
143150
friend class ValueBoundsConstraintSet;
144151
AffineMap map;
@@ -254,6 +261,12 @@ class ValueBoundsConstraintSet
254261
/// prove the relation or until it ran out of IR.
255262
static bool compare(const Variable &lhs, ComparisonOperator cmp,
256263
const Variable &rhs);
264+
/// This function is similar to `ValueBoundsConstraintSet::compare`, except
265+
/// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
266+
/// relation nor its inverse relation could be proven.
267+
static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
268+
ComparisonOperator cmp,
269+
const Variable &rhs);
257270

258271
/// Compute whether the given variables are equal. Return "failure" if
259272
/// equality could not be determined.
@@ -327,6 +340,16 @@ class ValueBoundsConstraintSet
327340
/// constraints.
328341
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
329342

343+
/// Return "true" if, based on the current state of the constraint system,
344+
/// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
345+
/// can be proven. Otherwise, it returns `failure` if neither the relation nor
346+
/// its inverse relation could be proven.
347+
///
348+
/// This function does not analyze any IR and does not populate any additional
349+
/// constraints.
350+
llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
351+
int64_t rhsPos);
352+
330353
/// Given an affine map with a single result (and map operands), add a new
331354
/// column to the constraint set that represents the result of the map.
332355
/// Traverse additional IR starting from the map operands as needed (as long

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1313
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
1414
#include "mlir/Dialect/Affine/LoopUtils.h"
15+
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
1516
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1617
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1718
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -112,7 +113,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
112113
}
113114
if (boundedOps.contains(target)) {
114115
auto diag = emitDefiniteFailure()
115-
<< "target op result must not be constrainted";
116+
<< "target op result must not be constrained";
116117
diag.attachNote(target->getLoc()) << "target/constrained op";
117118
return diag;
118119
}
@@ -148,6 +149,42 @@ void SimplifyBoundedAffineOpsOp::getEffects(
148149
modifiesPayload(effects);
149150
}
150151

152+
//===----------------------------------------------------------------------===//
153+
// SimplifyMinMaxAffineOpsOp
154+
//===----------------------------------------------------------------------===//
155+
DiagnosedSilenceableFailure
156+
SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
157+
TransformResults &results,
158+
TransformState &state) {
159+
SmallVector<Operation *> targets;
160+
for (Operation *target : state.getPayloadOps(getTarget())) {
161+
if (!isa<AffineMinOp, AffineMaxOp>(target)) {
162+
auto diag = emitDefiniteFailure()
163+
<< "target must be affine.min or affine.max";
164+
diag.attachNote(target->getLoc()) << "target op";
165+
return diag;
166+
}
167+
targets.push_back(target);
168+
}
169+
bool modified = false;
170+
if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
171+
&modified))) {
172+
return emitDefiniteFailure()
173+
<< "affine.min/max simplification did not converge";
174+
}
175+
if (!modified) {
176+
return emitSilenceableError()
177+
<< "the transform failed to simplify any of the target operations";
178+
}
179+
return DiagnosedSilenceableFailure::success();
180+
}
181+
182+
void SimplifyMinMaxAffineOpsOp::getEffects(
183+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184+
consumesHandle(getTargetMutable(), effects);
185+
modifiesPayload(effects);
186+
}
187+
151188
//===----------------------------------------------------------------------===//
152189
// Transform op registration
153190
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
1717
ReifyValueBounds.cpp
1818
SuperVectorize.cpp
1919
SimplifyAffineStructures.cpp
20+
SimplifyAffineMinMax.cpp
2021

2122
ADDITIONAL_HEADER_DIRS
2223
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a transform to simplify mix/max affine operations.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
14+
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
#include "llvm/ADT/IntEqClasses.h"
19+
#include "llvm/Support/Debug.h"
20+
21+
#define DEBUG_TYPE "affine-min-max"
22+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
23+
24+
using namespace mlir;
25+
using namespace mlir::affine;
26+
27+
/// Simplifies an affine min/max operation by proving there's a lower or upper
28+
/// bound.
29+
template <typename AffineOp>
30+
static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
31+
using Variable = ValueBoundsConstraintSet::Variable;
32+
using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
33+
34+
AffineMap affineMap = affineOp.getMap();
35+
ValueRange operands = affineOp.getOperands();
36+
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
37+
38+
LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
39+
40+
// Create a `Variable` list with values corresponding to each of the results
41+
// in the affine affineMap.
42+
SmallVector<Variable> variables = llvm::map_to_vector(
43+
llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
44+
[&](unsigned i) {
45+
return Variable(affineMap.getSliceMap(i, 1), operands);
46+
});
47+
48+
// Get the comparison operation.
49+
ComparisonOperator cmpOp =
50+
isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
51+
52+
// Find disjoint sets bounded by a common value.
53+
llvm::IntEqClasses boundedClasses(variables.size());
54+
DenseMap<unsigned, Variable *> bounds;
55+
for (auto &&[i, v] : llvm::enumerate(variables)) {
56+
unsigned eqClass = boundedClasses.findLeader(i);
57+
58+
// If the class already has a bound continue.
59+
if (bounds.contains(eqClass))
60+
continue;
61+
62+
// Initialize the bound.
63+
Variable *bound = &v;
64+
65+
LLVM_DEBUG({
66+
DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
67+
<< "`\n";
68+
});
69+
70+
// Check against the other variables.
71+
for (size_t j = i + 1; j < variables.size(); ++j) {
72+
unsigned jEqClass = boundedClasses.findLeader(j);
73+
// Skip if the class is the same.
74+
if (jEqClass == eqClass)
75+
continue;
76+
77+
// Get the bound of the equivalence class or itself.
78+
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
79+
80+
LLVM_DEBUG({
81+
DBGS() << "- comparing with variable: #" << jEqClass
82+
<< ", with map: " << nv->getMap() << "\n";
83+
});
84+
85+
// Compare the variables.
86+
FailureOr<bool> cmpResult =
87+
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
88+
89+
// The variables cannot be compared.
90+
if (failed(cmpResult)) {
91+
LLVM_DEBUG({
92+
DBGS() << "-- classes: #" << i << ", #" << jEqClass
93+
<< " cannot be merged\n";
94+
});
95+
continue;
96+
}
97+
98+
// Join the equivalent classes and update the bound if necessary.
99+
LLVM_DEBUG({
100+
DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
101+
<< ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
102+
});
103+
if (*cmpResult) {
104+
boundedClasses.join(eqClass, jEqClass);
105+
} else {
106+
// In this case we have lhs > rhs if isMin == true, or lhs < rhs if
107+
// isMin == false.
108+
bound = nv;
109+
boundedClasses.join(eqClass, jEqClass);
110+
}
111+
}
112+
bounds[boundedClasses.findLeader(i)] = bound;
113+
}
114+
115+
// Return if there's no simplification.
116+
if (bounds.size() >= affineMap.getNumResults()) {
117+
LLVM_DEBUG(
118+
{ DBGS() << "- the affine operation couldn't get simplified\n"; });
119+
return false;
120+
}
121+
122+
// Construct the new affine affineMap.
123+
SmallVector<AffineExpr> results;
124+
results.reserve(bounds.size());
125+
for (auto [k, bound] : bounds)
126+
results.push_back(bound->getMap().getResult(0));
127+
128+
affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
129+
results, rewriter.getContext());
130+
131+
// Update the affine op.
132+
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
133+
LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
134+
return true;
135+
}
136+
137+
bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
138+
return simplifyAffineMinMaxOp(rewriter, op);
139+
}
140+
141+
bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
142+
return simplifyAffineMinMaxOp(rewriter, op);
143+
}
144+
145+
LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
146+
ArrayRef<Operation *> ops,
147+
bool *modified) {
148+
bool changed = false;
149+
for (Operation *op : ops) {
150+
if (auto minOp = dyn_cast<AffineMinOp>(op))
151+
changed = simplifyAffineMinOp(rewriter, minOp) || changed;
152+
else if (auto maxOp = cast<AffineMaxOp>(op))
153+
changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
154+
}
155+
RewritePatternSet patterns(rewriter.getContext());
156+
AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
157+
AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
158+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
159+
if (modified)
160+
*modified = changed;
161+
// Canonicalize to a fixpoint.
162+
if (failed(applyOpPatternsGreedily(
163+
ops, frozenPatterns,
164+
GreedyRewriteConfig()
165+
.setListener(
166+
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
167+
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
168+
&changed))) {
169+
return failure();
170+
}
171+
if (modified)
172+
*modified = changed;
173+
return success();
174+
}

0 commit comments

Comments
 (0)