Skip to content

Commit 73140da

Browse files
authored
[mlir] expose transform dialect symbol merge to python (#87690)
This functionality is available in C++, make it available in Python directly to operate on transform modules.
1 parent 971ec1f commit 73140da

File tree

5 files changed

+120
-2
lines changed

5 files changed

+120
-2
lines changed

mlir/include/mlir-c/Dialect/Transform/Interpreter.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void
6060
mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
6161

6262
//----------------------------------------------------------------------------//
63-
// Transform interpreter.
63+
// Transform interpreter and utilities.
6464
//----------------------------------------------------------------------------//
6565

6666
/// Applies the transformation script starting at the given transform root
@@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
7272
MlirOperation payload, MlirOperation transformRoot,
7373
MlirOperation transformModule, MlirTransformOptions transformOptions);
7474

75+
/// Merge the symbols from `other` into `target`, potentially renaming them to
76+
/// avoid conflicts. Private symbols may be renamed during the merge, public
77+
/// symbols must have at most one declaration. A name conflict in public symbols
78+
/// is reported as an error before returning a failure.
79+
///
80+
/// Note that this clones the `other` operation unlike the C++ counterpart that
81+
/// takes ownership.
82+
MLIR_CAPI_EXPORTED MlirLogicalResult
83+
mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other);
84+
7585
#ifdef __cplusplus
7686
}
7787
#endif

mlir/lib/Bindings/Python/TransformInterpreter.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
8282
py::arg("payload_root"), py::arg("transform_root"),
8383
py::arg("transform_module"),
8484
py::arg("transform_options") = PyMlirTransformOptions());
85+
86+
m.def(
87+
"copy_symbols_and_merge_into",
88+
[](MlirOperation target, MlirOperation other) {
89+
mlir::python::CollectDiagnosticsToStringScope scope(
90+
mlirOperationGetContext(target));
91+
92+
MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
93+
if (mlirLogicalResultIsFailure(result)) {
94+
throw py::value_error(
95+
"Failed to merge symbols.\nDiagnostic message " +
96+
scope.takeMessage());
97+
}
98+
},
99+
py::arg("target"), py::arg("other"));
85100
}
86101

87102
PYBIND11_MODULE(_mlirTransformInterpreter, m) {

mlir/lib/CAPI/Dialect/TransformInterpreter.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/CAPI/IR.h"
1616
#include "mlir/CAPI/Support.h"
1717
#include "mlir/CAPI/Wrap.h"
18+
#include "mlir/Dialect/Transform/IR/Utils.h"
1819
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1920
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
2021

@@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence(
7172
unwrap(payload), unwrap(transformRoot),
7273
cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
7374
}
75+
76+
MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
77+
MlirOperation other) {
78+
OwningOpRef<Operation *> otherOwning(unwrap(other)->clone());
79+
LogicalResult result = transform::detail::mergeSymbolsInto(
80+
unwrap(target), std::move(otherOwning));
81+
return wrap(result);
82+
}
7483
}

mlir/python/mlir/dialects/transform/interpreter/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from ....ir import Operation
66
from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
77

8-
98
TransformOptions = _cextTransformInterpreter.TransformOptions
109

1110

@@ -31,3 +30,12 @@ def apply_named_sequence(
3130
_cextTransformInterpreter.apply_named_sequence(*args)
3231
else:
3332
_cextTransformInterpreter(*args, transform_options)
33+
34+
35+
def copy_symbols_and_merge_into(target, other):
36+
"""Copies symbols from other into target, renaming private symbols to avoid
37+
duplicates. Raises an error if copying would lead to duplicate public
38+
symbols."""
39+
_cextTransformInterpreter.copy_symbols_and_merge_into(
40+
_unpack_operation(target), _unpack_operation(other)
41+
)

mlir/test/python/dialects/transform_interpreter.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,79 @@ def failed():
5454
assert (
5555
"must implement TransformOpInterface to be used as transform root" in str(e)
5656
)
57+
58+
59+
print_root_via_include_module = """
60+
module @print_root_via_include_module attributes {transform.with_named_sequence} {
61+
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
62+
transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
63+
transform.named_sequence @__transform_main(%root: !transform.any_op) {
64+
transform.include @callee2 failures(propagate)
65+
(%root) : (!transform.any_op) -> ()
66+
transform.yield
67+
}
68+
}"""
69+
70+
callee2_definition = """
71+
module attributes {transform.with_named_sequence} {
72+
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
73+
transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
74+
transform.include @callee1 failures(propagate)
75+
(%root) : (!transform.any_op) -> ()
76+
transform.yield
77+
}
78+
}
79+
"""
80+
81+
callee1_definition = """
82+
module attributes {transform.with_named_sequence} {
83+
transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
84+
transform.print %root { name = \"from interpreter\" }: !transform.any_op
85+
transform.yield
86+
}
87+
}
88+
"""
89+
90+
91+
@test_in_context
92+
def include():
93+
main = ir.Module.parse(print_root_via_include_module)
94+
callee1 = ir.Module.parse(callee1_definition)
95+
callee2 = ir.Module.parse(callee2_definition)
96+
interp.copy_symbols_and_merge_into(main, callee1)
97+
interp.copy_symbols_and_merge_into(main, callee2)
98+
99+
# CHECK: @print_root_via_include_module
100+
# CHECK: transform.named_sequence @__transform_main
101+
# CHECK: transform.include @callee2
102+
#
103+
# CHECK: transform.named_sequence @callee1
104+
# CHECK: transform.print
105+
#
106+
# CHECK: transform.named_sequence @callee2
107+
# CHECK: transform.include @callee1
108+
interp.apply_named_sequence(main, main.body.operations[0], main)
109+
110+
111+
@test_in_context
112+
def partial_include():
113+
main = ir.Module.parse(print_root_via_include_module)
114+
callee2 = ir.Module.parse(callee2_definition)
115+
interp.copy_symbols_and_merge_into(main, callee2)
116+
117+
try:
118+
interp.apply_named_sequence(main, main.body.operations[0], main)
119+
except ValueError as e:
120+
assert "Failed to apply" in str(e)
121+
122+
123+
@test_in_context
124+
def repeated_include():
125+
main = ir.Module.parse(print_root_via_include_module)
126+
callee2 = ir.Module.parse(callee2_definition)
127+
interp.copy_symbols_and_merge_into(main, callee2)
128+
129+
try:
130+
interp.copy_symbols_and_merge_into(main, callee2)
131+
except ValueError as e:
132+
assert "doubly defined symbol @callee2" in str(e)

0 commit comments

Comments
 (0)