Skip to content

Commit 6c57b0d

Browse files
committed
[mlir] improve and test TransformState::Extension
Add the mechanism for TransformState extensions to update the mapping between Transform IR values and Payload IR operations held by the state. The mechanism is intentionally restrictive, similarly to how results of the transform op are handled. Introduce test ops that exercise a simple extension that maintains information across the application of multiple transform ops. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D124778
1 parent ad47114 commit 6c57b0d

File tree

6 files changed

+236
-17
lines changed

6 files changed

+236
-17
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ class TransformState {
7474
/// This is helpful for transformations that apply to a particular handle.
7575
ArrayRef<Operation *> getPayloadOps(Value value) const;
7676

77+
/// Returns the Transform IR handle for the given Payload IR op if it exists
78+
/// in the state, null otherwise.
79+
Value getHandleForPayloadOp(Operation *op) const;
80+
7781
/// Applies the transformation specified by the given transform op and updates
7882
/// the state accordingly.
7983
LogicalResult applyTransform(TransformOpInterface transform);
@@ -185,6 +189,10 @@ class TransformState {
185189
/// Provides read-only access to the parent TransformState object.
186190
const TransformState &getTransformState() const { return state; }
187191

192+
/// Replaces the given payload op with another op. If the replacement op is
193+
/// null, removes the association of the payload op with its handle.
194+
LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
195+
188196
private:
189197
/// Back-reference to the state that is being extended.
190198
TransformState &state;
@@ -276,9 +284,17 @@ class TransformState {
276284
/// The callback function is called once per associated operation and is
277285
/// expected to return the modified operation or nullptr. In the latter case,
278286
/// the corresponding operation is no longer associated with the transform IR
279-
/// value.
280-
void updatePayloadOps(Value value,
281-
function_ref<Operation *(Operation *)> callback);
287+
/// value. May fail if the operation produced by the update callback is
288+
/// already associated with a different Transform IR handle value.
289+
LogicalResult
290+
updatePayloadOps(Value value,
291+
function_ref<Operation *(Operation *)> callback);
292+
293+
/// Attempts to record the mapping between the given Payload IR operation and
294+
/// the given Transform IR handle. Fails and reports an error if the operation
295+
/// is already tracked by another handle.
296+
static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
297+
Value handle);
282298

283299
/// The mappings between transform IR values and payload IR ops, aggregated by
284300
/// the region in which the transform IR values are defined.

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,27 @@ transform::TransformState::getPayloadOps(Value value) const {
4141
return iter->getSecond();
4242
}
4343

44+
Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
45+
for (const Mappings &mapping : llvm::make_second_range(mappings)) {
46+
if (Value handle = mapping.reverse.lookup(op))
47+
return handle;
48+
}
49+
return Value();
50+
}
51+
52+
LogicalResult transform::TransformState::tryEmplaceReverseMapping(
53+
Mappings &map, Operation *operation, Value handle) {
54+
auto insertionResult = map.reverse.insert({operation, handle});
55+
if (!insertionResult.second) {
56+
InFlightDiagnostic diag = operation->emitError()
57+
<< "operation tracked by two handles";
58+
diag.attachNote(handle.getLoc()) << "handle";
59+
diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
60+
return diag;
61+
}
62+
return success();
63+
}
64+
4465
LogicalResult
4566
transform::TransformState::setPayloadOps(Value value,
4667
ArrayRef<Operation *> targets) {
@@ -63,14 +84,8 @@ transform::TransformState::setPayloadOps(Value value,
6384
// expressed using the dialect and may be constructed by valid API calls from
6485
// valid IR. Emit an error here.
6586
for (Operation *op : targets) {
66-
auto insertionResult = mappings.reverse.insert({op, value});
67-
if (!insertionResult.second) {
68-
InFlightDiagnostic diag = op->emitError()
69-
<< "operation tracked by two handles";
70-
diag.attachNote(value.getLoc()) << "handle";
71-
diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
72-
return diag;
73-
}
87+
if (failed(tryEmplaceReverseMapping(mappings, op, value)))
88+
return failure();
7489
}
7590

7691
return success();
@@ -83,19 +98,26 @@ void transform::TransformState::removePayloadOps(Value value) {
8398
mappings.direct.erase(value);
8499
}
85100

86-
void transform::TransformState::updatePayloadOps(
101+
LogicalResult transform::TransformState::updatePayloadOps(
87102
Value value, function_ref<Operation *(Operation *)> callback) {
88-
auto it = getMapping(value).direct.find(value);
89-
assert(it != getMapping(value).direct.end() && "unknown handle");
103+
Mappings &mappings = getMapping(value);
104+
auto it = mappings.direct.find(value);
105+
assert(it != mappings.direct.end() && "unknown handle");
90106
SmallVector<Operation *> &association = it->getSecond();
91107
SmallVector<Operation *> updated;
92108
updated.reserve(association.size());
93109

94-
for (Operation *op : association)
95-
if (Operation *updatedOp = callback(op))
110+
for (Operation *op : association) {
111+
mappings.reverse.erase(op);
112+
if (Operation *updatedOp = callback(op)) {
96113
updated.push_back(updatedOp);
114+
if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
115+
return failure();
116+
}
117+
}
97118

98119
std::swap(association, updated);
120+
return success();
99121
}
100122

101123
LogicalResult
@@ -132,8 +154,21 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
132154
return success();
133155
}
134156

157+
//===----------------------------------------------------------------------===//
158+
// TransformState::Extension
159+
//===----------------------------------------------------------------------===//
160+
135161
transform::TransformState::Extension::~Extension() = default;
136162

163+
LogicalResult
164+
transform::TransformState::Extension::replacePayloadOp(Operation *op,
165+
Operation *replacement) {
166+
return state.updatePayloadOps(state.getHandleForPayloadOp(op),
167+
[&](Operation *current) {
168+
return current == op ? replacement : current;
169+
});
170+
}
171+
137172
//===----------------------------------------------------------------------===//
138173
// TransformResults
139174
//===----------------------------------------------------------------------===//
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -split-input-file
2+
3+
// expected-note @below {{associated payload op}}
4+
module {
5+
transform.sequence {
6+
^bb0(%arg0: !pdl.operation):
7+
// expected-remark @below {{extension absent}}
8+
test_check_if_test_extension_present %arg0
9+
test_add_test_extension "A"
10+
// expected-remark @below {{extension present, A}}
11+
test_check_if_test_extension_present %arg0
12+
test_remove_test_extension
13+
// expected-remark @below {{extension absent}}
14+
test_check_if_test_extension_present %arg0
15+
}
16+
}
17+
18+
// -----
19+
20+
// expected-note @below {{associated payload op}}
21+
module {
22+
transform.sequence {
23+
^bb0(%arg0: !pdl.operation):
24+
test_add_test_extension "A"
25+
test_remove_test_extension
26+
test_add_test_extension "B"
27+
// expected-remark @below {{extension present, B}}
28+
test_check_if_test_extension_present %arg0
29+
}
30+
}
31+
32+
// -----
33+
34+
// expected-note @below {{associated payload op}}
35+
module {
36+
transform.sequence {
37+
^bb0(%arg0: !pdl.operation):
38+
test_add_test_extension "A"
39+
// expected-remark @below {{extension present, A}}
40+
test_check_if_test_extension_present %arg0
41+
// expected-note @below {{associated payload op}}
42+
test_remap_operand_to_self %arg0
43+
// expected-remark @below {{extension present, A}}
44+
test_check_if_test_extension_present %arg0
45+
}
46+
}

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "TestTransformDialectExtension.h"
15+
#include "TestTransformStateExtension.h"
1516
#include "mlir/Dialect/PDL/IR/PDL.h"
1617
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1718
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
18-
#include "mlir/IR/Builders.h"
1919
#include "mlir/IR/OpImplementation.h"
2020

2121
using namespace mlir;
@@ -142,6 +142,49 @@ LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
142142
return success();
143143
}
144144

145+
LogicalResult
146+
mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
147+
transform::TransformState &state) {
148+
state.addExtension<TestTransformStateExtension>(getMessageAttr());
149+
return success();
150+
}
151+
152+
LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply(
153+
transform::TransformResults &results, transform::TransformState &state) {
154+
auto *extension = state.getExtension<TestTransformStateExtension>();
155+
if (!extension) {
156+
emitRemark() << "extension absent";
157+
return success();
158+
}
159+
160+
InFlightDiagnostic diag = emitRemark()
161+
<< "extension present, " << extension->getMessage();
162+
for (Operation *payload : state.getPayloadOps(getOperand())) {
163+
diag.attachNote(payload->getLoc()) << "associated payload op";
164+
assert(state.getHandleForPayloadOp(payload) == getOperand() &&
165+
"inconsistent mapping between transform IR handles and payload IR "
166+
"operations");
167+
}
168+
169+
return success();
170+
}
171+
172+
LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply(
173+
transform::TransformResults &results, transform::TransformState &state) {
174+
auto *extension = state.getExtension<TestTransformStateExtension>();
175+
if (!extension)
176+
return emitError() << "TestTransformStateExtension missing";
177+
178+
return extension->updateMapping(state.getPayloadOps(getOperand()).front(),
179+
getOperation());
180+
}
181+
182+
LogicalResult mlir::test::TestRemoveTestExtensionOp::apply(
183+
transform::TransformResults &results, transform::TransformState &state) {
184+
state.removeExtension<TestTransformStateExtension>();
185+
return success();
186+
}
187+
145188
namespace {
146189
/// Test extension of the Transform dialect. Registers additional ops and
147190
/// declares PDL as dependent dialect since the additional ops are using PDL

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,41 @@ def TestPrintRemarkAtOperandOp
5656
let cppNamespace = "::mlir::test";
5757
}
5858

59+
def TestAddTestExtensionOp
60+
: Op<Transform_Dialect, "test_add_test_extension",
61+
[DeclareOpInterfaceMethods<TransformOpInterface>,
62+
NoSideEffect]> {
63+
let arguments = (ins StrAttr:$message);
64+
let assemblyFormat = "$message attr-dict";
65+
let cppNamespace = "::mlir::test";
66+
}
67+
68+
def TestCheckIfTestExtensionPresentOp
69+
: Op<Transform_Dialect, "test_check_if_test_extension_present",
70+
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
71+
let arguments = (ins
72+
Arg<PDL_Operation, "", [TransformMappingRead, PayloadIRRead]>:$operand);
73+
let assemblyFormat = "$operand attr-dict";
74+
let cppNamespace = "::mlir::test";
75+
}
76+
77+
def TestRemapOperandPayloadToSelfOp
78+
: Op<Transform_Dialect, "test_remap_operand_to_self",
79+
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
80+
let arguments = (ins
81+
Arg<PDL_Operation, "",
82+
[TransformMappingRead, TransformMappingWrite, PayloadIRRead]>:$operand);
83+
let assemblyFormat = "$operand attr-dict";
84+
let cppNamespace = "::mlir::test";
85+
}
86+
87+
def TestRemoveTestExtensionOp
88+
: Op<Transform_Dialect, "test_remove_test_extension",
89+
[DeclareOpInterfaceMethods<TransformOpInterface>,
90+
NoSideEffect]> {
91+
let assemblyFormat = "attr-dict";
92+
let cppNamespace = "::mlir::test";
93+
}
94+
95+
5996
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines an TransformState extension for the purpose of testing the
10+
// relevant APIs.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
15+
#define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
16+
17+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
18+
19+
using namespace mlir;
20+
21+
namespace mlir {
22+
namespace test {
23+
class TestTransformStateExtension
24+
: public transform::TransformState::Extension {
25+
public:
26+
TestTransformStateExtension(transform::TransformState &state,
27+
StringAttr message)
28+
: Extension(state), message(message) {}
29+
30+
StringRef getMessage() const { return message.getValue(); }
31+
32+
LogicalResult updateMapping(Operation *previous, Operation *updated) {
33+
return replacePayloadOp(previous, updated);
34+
}
35+
36+
private:
37+
StringAttr message;
38+
};
39+
} // namespace test
40+
} // namespace mlir
41+
42+
#endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H

0 commit comments

Comments
 (0)