Skip to content

Commit a753045

Browse files
[mlir][transform] Improve error message of tracking listener. (#66987)
This PR extends the error message of the tracking listener when replacement ops cannot be found. That may happen if the applied patterns replace an op by an op of a different kind or by block arguments. However, this only matters if there are alive handles to the replaced op. The new error message mentions that explicitly and reports the alive handles.
1 parent 1f64dc8 commit a753045

File tree

4 files changed

+81
-42
lines changed

4 files changed

+81
-42
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,8 @@ class TransformResults {
774774
/// corresponds to the given list of payload IR ops. Each result must be set
775775
/// by the transformation exactly once in case of transformation succeeding.
776776
/// The value must have a type implementing TransformHandleTypeInterface.
777-
template <typename Range> void set(OpResult value, Range &&ops) {
777+
template <typename Range>
778+
void set(OpResult value, Range &&ops) {
778779
int64_t position = value.getResultNumber();
779780
assert(position < static_cast<int64_t>(operations.size()) &&
780781
"setting results for a non-existent handle");
@@ -929,8 +930,9 @@ class TrackingListener : public RewriterBase::Listener,
929930
///
930931
/// Derived classes may override `findReplacementOp` to specify custom
931932
/// replacement rules.
932-
virtual FailureOr<Operation *> findReplacementOp(Operation *op,
933-
ValueRange newValues) const;
933+
virtual DiagnosedSilenceableFailure
934+
findReplacementOp(Operation *&result, Operation *op,
935+
ValueRange newValues) const;
934936

935937
/// Notify the listener that the pattern failed to match the given operation,
936938
/// and provide a callback to populate a diagnostic with the reason why the
@@ -942,8 +944,9 @@ class TrackingListener : public RewriterBase::Listener,
942944
/// This function is called when a tracked payload op is dropped because no
943945
/// replacement op was found. Derived classes can implement this function for
944946
/// custom error handling.
945-
virtual void notifyPayloadReplacementNotFound(Operation *op,
946-
ValueRange values) {}
947+
virtual void
948+
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
949+
DiagnosedSilenceableFailure &&diag) {}
947950

948951
/// Return the single op that defines all given values (if any).
949952
static Operation *getCommonDefiningOp(ValueRange values);
@@ -983,8 +986,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {
983986
bool failed() const;
984987

985988
protected:
986-
void notifyPayloadReplacementNotFound(Operation *op,
987-
ValueRange values) override;
989+
void
990+
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
991+
DiagnosedSilenceableFailure &&diag) override;
988992

989993
private:
990994
/// The error state of this listener. "Success" indicates that no error

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,45 +1289,59 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
12891289
return defOp;
12901290
}
12911291

1292-
FailureOr<Operation *>
1293-
transform::TrackingListener::findReplacementOp(Operation *op,
1294-
ValueRange newValues) const {
1292+
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
1293+
Operation *&result, Operation *op, ValueRange newValues) const {
12951294
assert(op->getNumResults() == newValues.size() &&
12961295
"invalid number of replacement values");
12971296
SmallVector<Value> values(newValues.begin(), newValues.end());
12981297

1298+
DiagnosedSilenceableFailure diag = emitSilenceableFailure(
1299+
getTransformOp(), "tracking listener failed to find replacement op "
1300+
"during application of this transform op");
1301+
12991302
do {
13001303
// If the replacement values belong to different ops, drop the mapping.
13011304
Operation *defOp = getCommonDefiningOp(values);
1302-
if (!defOp)
1303-
return failure();
1305+
if (!defOp) {
1306+
diag.attachNote() << "replacement values belong to different ops";
1307+
return diag;
1308+
}
13041309

13051310
// If the defining op has the same type, we take it as a replacement.
1306-
if (op->getName() == defOp->getName())
1307-
return defOp;
1311+
if (op->getName() == defOp->getName()) {
1312+
result = defOp;
1313+
return DiagnosedSilenceableFailure::success();
1314+
}
13081315

13091316
// Replacing an op with a constant-like equivalent is a common
13101317
// canonicalization.
1311-
if (defOp->hasTrait<OpTrait::ConstantLike>())
1312-
return defOp;
1318+
if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1319+
result = defOp;
1320+
return DiagnosedSilenceableFailure::success();
1321+
}
13131322

13141323
values.clear();
13151324

13161325
// Skip through ops that implement FindPayloadReplacementOpInterface.
13171326
if (auto findReplacementOpInterface =
13181327
dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
13191328
values.assign(findReplacementOpInterface.getNextOperands());
1329+
diag.attachNote(defOp->getLoc()) << "using operands provided by "
1330+
"'FindPayloadReplacementOpInterface'";
13201331
continue;
13211332
}
13221333

13231334
// Skip through ops that implement CastOpInterface.
13241335
if (isa<CastOpInterface>(defOp)) {
13251336
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1337+
diag.attachNote(defOp->getLoc())
1338+
<< "using output of 'CastOpInterface' op";
13261339
continue;
13271340
}
13281341
} while (!values.empty());
13291342

1330-
return failure();
1343+
diag.attachNote() << "ran out of suitable replacement values";
1344+
return diag;
13311345
}
13321346

13331347
LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -1396,32 +1410,39 @@ void transform::TrackingListener::notifyOperationReplaced(
13961410
};
13971411

13981412
// Helper function to check if the handle is alive.
1399-
auto hasAliveUser = [&]() {
1413+
auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
14001414
for (Value v : opHandles) {
1401-
for (Operation *user : v.getUsers())
1402-
if (user != transformOp && !happensBefore(user, transformOp))
1403-
return true;
1415+
for (OpOperand &use : v.getUses())
1416+
if (use.getOwner() != transformOp &&
1417+
!happensBefore(use.getOwner(), transformOp))
1418+
return &use;
14041419
}
1405-
return false;
1406-
};
1420+
return std::nullopt;
1421+
}();
14071422

1408-
if (!hasAliveUser() || handleWasConsumed()) {
1423+
if (!firstAliveUser.has_value() || handleWasConsumed()) {
14091424
// The op is tracked but the corresponding handles are dead or were
14101425
// consumed. Drop the op form the mapping.
14111426
(void)replacePayloadOp(op, nullptr);
14121427
return;
14131428
}
14141429

1415-
FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
1430+
Operation *replacement;
1431+
DiagnosedSilenceableFailure diag =
1432+
findReplacementOp(replacement, op, newValues);
14161433
// If the op is tracked but no replacement op was found, send a
14171434
// notification.
1418-
if (failed(replacement)) {
1419-
notifyPayloadReplacementNotFound(op, newValues);
1435+
if (!diag.succeeded()) {
1436+
diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
1437+
<< "replacement is required because alive handle(s) exist "
1438+
<< "(first use in this op as operand number "
1439+
<< (*firstAliveUser)->getOperandNumber() << ")";
1440+
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
14201441
(void)replacePayloadOp(op, nullptr);
14211442
return;
14221443
}
14231444

1424-
(void)replacePayloadOp(op, *replacement);
1445+
(void)replacePayloadOp(op, replacement);
14251446
}
14261447

14271448
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
@@ -1444,17 +1465,20 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
14441465
}
14451466

14461467
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
1447-
Operation *op, ValueRange values) {
1448-
if (status.succeeded()) {
1449-
status = emitSilenceableFailure(
1450-
getTransformOp(), "tracking listener failed to find replacement op");
1451-
}
1468+
Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
1469+
1470+
// Merge potentially existing diags and store the result in the listener.
1471+
SmallVector<Diagnostic> diags;
1472+
diag.takeDiagnostics(diags);
1473+
if (!status.succeeded())
1474+
status.takeDiagnostics(diags);
1475+
status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
14521476

1477+
// Report more details.
14531478
status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
14541479
for (auto &&[index, value] : llvm::enumerate(values))
14551480
status.attachNote(value.getLoc())
14561481
<< "[" << errorCounter << "] replacement value " << index;
1457-
14581482
++errorCounter;
14591483
}
14601484

mlir/test/Dialect/Transform/test-pattern-application.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ transform.sequence failures(propagate) {
3737
^bb1(%arg1: !transform.any_op):
3838
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
3939
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
40-
// expected-error @below {{tracking listener failed to find replacement op}}
40+
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
41+
// expected-note @below {{ran out of suitable replacement values}}
4142
transform.apply_patterns to %0 {
4243
transform.apply_patterns.transform.test_patterns
4344
} : !transform.any_op
4445
// %1 must be used in some way. If no replacement payload op could be found,
4546
// an error is thrown only if the handle is not dead.
47+
// expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
4648
transform.annotate %1 "annotated" : !transform.any_op
4749
}
4850

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ struct TestTensorTransforms
3232
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
3333

3434
void getDependentDialects(DialectRegistry &registry) const override {
35-
registry
36-
.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>();
35+
registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
36+
transform::TransformDialect>();
3737
}
3838

3939
StringRef getArgument() const final {
@@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener {
292292

293293
// Expose `findReplacementOp` as a public function, so that it can be tested.
294294
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
295-
FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
296-
if (failed(replacementOp))
295+
Operation *replacementOp;
296+
if (!findReplacementOp(replacementOp, op, newValues).succeeded())
297297
return nullptr;
298-
return *replacementOp;
298+
return replacementOp;
299299
}
300300
};
301301
} // namespace
@@ -352,8 +352,17 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
352352
transform::TransformState transformState =
353353
transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
354354
/*payloadRoot=*/nullptr);
355-
DummyTrackingListener listener(transformState,
356-
transform::TransformOpInterface());
355+
MLIRContext *context = rootOp->getContext();
356+
OpBuilder builder(context);
357+
auto transformOp = builder.create<transform::NamedSequenceOp>(
358+
rootOp->getLoc(),
359+
/*sym_name=*/"test_sequence",
360+
/*function_type=*/
361+
TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
362+
/*sym_visibility*/ StringAttr::get(context, "public"),
363+
/*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
364+
/*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
365+
DummyTrackingListener listener(transformState, transformOp);
357366
Operation *replacement = listener.getReplacementOp(replaced, replacements);
358367
if (!replacement) {
359368
replaced->emitError("listener could not find replacement op");

0 commit comments

Comments
 (0)