Skip to content

Commit ac858bb

Browse files
committed
address reviewer comments
1 parent f23ef73 commit ac858bb

File tree

6 files changed

+101
-57
lines changed

6 files changed

+101
-57
lines changed

mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,38 +65,33 @@ def SimplifyBoundedAffineOpsOp
6565

6666
def SimplifyMinMaxAffineOpsOp :
6767
Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
68-
TransformOpInterface,
69-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
70-
TransformEachOpTrait
68+
DeclareOpInterfaceMethods<TransformOpInterface>,
69+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
7170
]> {
7271
let description = [{
73-
Simplify all the affine.min / affine.max ops being targeted or nested in the
74-
target operation, using the `mlir::affine::simplifyAffineMinMaxOps`
75-
transform.
72+
Simplify the targeted `affine.min` / `affine.max` ops using the
73+
`mlir::affine::simplifyAffineMinMaxOps` transform.
7674

7775
Example:
7876
```
79-
%0 = transform.structured.match ops{["gpu.launch", "affine.max"]} in %arg1
77+
%0 = transform.structured.match ops{["affine.max"]} in %arg1
8078
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
8179
```
8280

8381
#### Return modes
8482

8583
This transform consumes the target handle and does not produce any results.
86-
This transforms never produces errors.
84+
This transforms definitely fails if any of the targeted operations is not an
85+
`affine.min` or `affine.max` operation, or if the canonicalization patterns
86+
failed to converge.
87+
This transform silently fails if none of the operations were simplified.
88+
Otherwise, it returns success.
8789
}];
8890
let arguments = (ins TransformHandleTypeInterface:$target);
8991
let results = (outs);
9092
let assemblyFormat = [{
9193
$target attr-dict `:` type($target)
9294
}];
93-
let extraClassDeclaration = [{
94-
::mlir::DiagnosedSilenceableFailure applyToOne(
95-
::mlir::transform::TransformRewriter &rewriter,
96-
::mlir::Operation *target,
97-
::mlir::transform::ApplyToEachResultList &results,
98-
::mlir::transform::TransformState &state);
99-
}];
10095
}
10196

10297
#endif // Affine_TRANSFORM_OPS

mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,23 +129,37 @@ OpFoldResult materializeComputedBound(
129129
OpBuilder &b, Location loc, AffineMap boundMap,
130130
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
131131

132-
/// Tries to simplify all affine min or max operations under `topOp`. The
133-
/// transform works by finding disjoint sets of affine result expressions
134-
/// bounded by a common affine expression on the min/max operation. It populates
135-
/// `modifiedOps` with all the operations modified by the transform.
132+
/// This transform tries to simplify the affine min operation `op`, by finding a
133+
/// common lower bound for a set of expressions in the affine map results. It
134+
/// returns whether the transform updated `op`'s affine map.
136135
///
137136
/// In concrete terms, given an operation like:
138137
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
139-
/// If `d0 < 128` and `128 < s1 < s0`, the transform will update the op to:
138+
/// If `d0 < 128` and `128 < s1 < s0`, the transform will update `op` to:
140139
/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
141-
void simplifyAffineMinMaxOps(RewriterBase &rewriter, Operation *topOp,
142-
SmallVectorImpl<Operation *> &modifiedOps);
143-
/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
144-
/// the operation was modified.
145140
bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
146-
/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
147-
/// the operation was modified.
141+
142+
/// This transform tries to simplify the affine max operation `op`, by finding a
143+
/// common upper bound for a set of expressions in the affine map results. It
144+
/// returns whether the transform updated `op`'s affine map.
145+
///
146+
/// In concrete terms, given an operation like:
147+
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
148+
/// If `d0 > 128` and `s0 > s1 > 128`, the transform will update `op` to:
149+
/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s0)>(%i)[%s0, %s1]`.
148150
bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
151+
152+
/// This transform applies `simplifyAffineMinOp` and `simplifyAffineMaxOp` to
153+
/// all the `affine.min` or `affine.max` operations in `ops`. After
154+
/// simplification, it invokes the `affine.min/max` canonicalization patterns on
155+
/// `ops`.
156+
///
157+
/// This transform returns failure if the greedy pattern rewriter failed to
158+
/// converge during canonicalization, otherwise it returns success. If provided,
159+
/// `modified` is set to `true` if the IR was modified in any way.
160+
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter,
161+
ArrayRef<Operation *> ops,
162+
bool *modified = nullptr);
149163
} // namespace affine
150164
} // namespace mlir
151165

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
113113
}
114114
if (boundedOps.contains(target)) {
115115
auto diag = emitDefiniteFailure()
116-
<< "target op result must not be constrainted";
116+
<< "target op result must not be constrained";
117117
diag.attachNote(target->getLoc()) << "target/constrained op";
118118
return diag;
119119
}
@@ -152,12 +152,30 @@ void SimplifyBoundedAffineOpsOp::getEffects(
152152
//===----------------------------------------------------------------------===//
153153
// SimplifyMinMaxAffineOpsOp
154154
//===----------------------------------------------------------------------===//
155-
156-
DiagnosedSilenceableFailure SimplifyMinMaxAffineOpsOp::applyToOne(
157-
TransformRewriter &rewriter, Operation *target,
158-
ApplyToEachResultList &results, TransformState &state) {
159-
SmallVector<Operation *> modifiedOps;
160-
simplifyAffineMinMaxOps(rewriter, target, modifiedOps);
155+
DiagnosedSilenceableFailure
156+
SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
157+
TransformResults &results,
158+
TransformState &state) {
159+
SmallVector<Operation *> targets;
160+
for (Operation *target : state.getPayloadOps(getTarget())) {
161+
if (!isa<AffineMinOp, AffineMaxOp>(target)) {
162+
auto diag = emitDefiniteFailure()
163+
<< "target must be affine.min or affine.max";
164+
diag.attachNote(target->getLoc()) << "target op";
165+
return diag;
166+
}
167+
targets.push_back(target);
168+
}
169+
bool modified = false;
170+
if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
171+
&modified))) {
172+
return emitDefiniteFailure()
173+
<< "affine.min/max simplification did not converge";
174+
}
175+
if (!modified) {
176+
return emitSilenceableError()
177+
<< "the transform failed to simplify any of the target operations";
178+
}
161179
return DiagnosedSilenceableFailure::success();
162180
}
163181

mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
1515
#include "mlir/IR/PatternMatch.h"
1616
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1718
#include "llvm/ADT/IntEqClasses.h"
1819
#include "llvm/Support/Debug.h"
1920

@@ -69,6 +70,10 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
6970
// Check against the other variables.
7071
for (size_t j = i + 1; j < variables.size(); ++j) {
7172
unsigned jEqClass = boundedClasses.findLeader(j);
73+
// Skip if the class is the same.
74+
if (jEqClass == eqClass)
75+
continue;
76+
7277
// Get the bound of the equivalence class or itself.
7378
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
7479

@@ -93,7 +98,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
9398
// Join the equivalent classes and update the bound if necessary.
9499
LLVM_DEBUG({
95100
DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
96-
<< ", is lhs <= rhs: " << *cmpResult << "`\n";
101+
<< ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
97102
});
98103
if (*cmpResult) {
99104
boundedClasses.join(eqClass, jEqClass);
@@ -137,17 +142,33 @@ bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
137142
return simplifyAffineMinMaxOp(rewriter, op);
138143
}
139144

140-
void mlir::affine::simplifyAffineMinMaxOps(
141-
RewriterBase &rewriter, Operation *topOp,
142-
SmallVectorImpl<Operation *> &modifiedOps) {
143-
assert(topOp && "null-op");
144-
topOp->walk([&](Operation *op) {
145-
if (auto affineOp = dyn_cast<AffineMinOp>(op)) {
146-
if (simplifyAffineMinMaxOp(rewriter, affineOp))
147-
modifiedOps.push_back(op);
148-
} else if (auto affineOp = dyn_cast<AffineMaxOp>(op)) {
149-
if (simplifyAffineMinMaxOp(rewriter, affineOp))
150-
modifiedOps.push_back(op);
151-
}
152-
});
145+
LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
146+
ArrayRef<Operation *> ops,
147+
bool *modified) {
148+
bool changed = false;
149+
for (Operation *op : ops) {
150+
if (auto minOp = dyn_cast<AffineMinOp>(op))
151+
changed = simplifyAffineMinOp(rewriter, minOp) || changed;
152+
else if (auto maxOp = cast<AffineMaxOp>(op))
153+
changed = simplifyAffineMaxOp(rewriter, maxOp) || changed;
154+
}
155+
RewritePatternSet patterns(rewriter.getContext());
156+
AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
157+
AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
158+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
159+
if (modified)
160+
*modified = changed;
161+
// Apply the simplification pattern to a fixpoint.
162+
if (failed(applyOpPatternsGreedily(
163+
ops, frozenPatterns,
164+
GreedyRewriteConfig()
165+
.setListener(
166+
static_cast<RewriterBase::Listener *>(rewriter.getListener()))
167+
.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
168+
&changed))) {
169+
return failure();
170+
}
171+
if (modified)
172+
*modified = changed;
173+
return success();
153174
}

mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,8 @@ func.func @overlapping_constraints() -> (index, index) {
6161

6262
module attributes {transform.with_named_sequence} {
6363
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
64-
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
64+
%0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
6565
transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
66-
%1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
67-
transform.apply_patterns to %1 {
68-
transform.apply_patterns.canonicalization
69-
} {apply_cse} : !transform.any_op
7066
transform.yield
7167
}
7268
}

mlir/test/Dialect/Linalg/transform-op-pad.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,8 @@ module attributes {transform.with_named_sequence} {
464464
// CHECK: %[[LHS:.*]] = tensor.pad
465465
// CHECK: %[[RHS:.*]] = tensor.pad
466466
// CHECK: scf.for
467-
// CHECK: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
468-
// CHECK: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
467+
// CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
468+
// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
469469
func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
470470
%0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
471471
return %0 : tensor<?x?xf32>
@@ -480,9 +480,9 @@ module attributes {transform.with_named_sequence} {
480480
transform.apply_patterns to %2 {
481481
transform.apply_patterns.canonicalization
482482
} {apply_cse} : !transform.any_op
483-
transform.affine.simplify_min_max_affine_ops %2 : !transform.any_op
484-
%3 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
485-
transform.apply_patterns to %3 {
483+
%3 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
484+
transform.affine.simplify_min_max_affine_ops %3 : !transform.any_op
485+
transform.apply_patterns to %2 {
486486
transform.apply_patterns.canonicalization
487487
} {apply_cse} : !transform.any_op
488488
transform.yield

0 commit comments

Comments
 (0)