Skip to content

[mlir] expose transform dialect symbol merge to python #87690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 17, 2024

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Apr 4, 2024

This functionality is available in C++, make it available in Python directly to operate on transform modules.

@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Apr 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 4, 2024

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

This functionality is available in C++, make it available in Python directly to operate on transform modules.


Full diff: https://github.com/llvm/llvm-project/pull/87690.diff

5 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/Transform/Interpreter.h (+11-1)
  • (modified) mlir/lib/Bindings/Python/TransformInterpreter.cpp (+15)
  • (modified) mlir/lib/CAPI/Dialect/TransformInterpreter.cpp (+9)
  • (modified) mlir/python/mlir/dialects/transform/interpreter/init.py (+7-1)
  • (modified) mlir/test/python/dialects/transform_interpreter.py (+76)
diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
index 00095d5040a0e5..fa320324234e8d 100644
--- a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void
 mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
 
 //----------------------------------------------------------------------------//
-// Transform interpreter.
+// Transform interpreter and utilities.
 //----------------------------------------------------------------------------//
 
 /// Applies the transformation script starting at the given transform root
@@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
     MlirOperation payload, MlirOperation transformRoot,
     MlirOperation transformModule, MlirTransformOptions transformOptions);
 
+/// Merge the symbols from `other` into `target`, potentially renaming them to
+/// avoid conflicts. Private symbols may be renamed during the merge, public
+/// symbols must have at most one declaration. A name conflict in public symbols
+/// is reported as an error before returning a failure.
+///
+/// Note that this clones the `other` operation unlike the C++ counterpart that
+/// takes ownership.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 6517f8c39dfadd..6448ae433b5c3f 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
       py::arg("payload_root"), py::arg("transform_root"),
       py::arg("transform_module"),
       py::arg("transform_options") = PyMlirTransformOptions());
+
+  m.def(
+      "merge_symbols_into",
+      [](MlirOperation target, MlirOperation other) {
+        mlir::python::CollectDiagnosticsToStringScope scope(
+            mlirOperationGetContext(target));
+
+        MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
+        if (mlirLogicalResultIsSuccess(result))
+          return;
+
+        throw py::value_error("Failed to merge symbols.\nDiagnostic message " +
+                              scope.takeMessage());
+      },
+      py::arg("target"), py::arg("other"));
 }
 
 PYBIND11_MODULE(_mlirTransformInterpreter, m) {
diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
index eb6951dc5584d6..145455e1c1b3d2 100644
--- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -15,6 +15,7 @@
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 
@@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence(
       unwrap(payload), unwrap(transformRoot),
       cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
 }
+
+MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
+                                                MlirOperation other) {
+  OwningOpRef<Operation *> otherOwning(unwrap(other)->clone());
+  LogicalResult result = transform::detail::mergeSymbolsInto(
+      unwrap(target), std::move(otherOwning));
+  return wrap(result);
+}
 }
diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
index 6145b99224eb54..4de827257174ab 100644
--- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -5,7 +5,6 @@
 from ....ir import Operation
 from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
 
-
 TransformOptions = _cextTransformInterpreter.TransformOptions
 
 
@@ -31,3 +30,10 @@ def apply_named_sequence(
         _cextTransformInterpreter.apply_named_sequence(*args)
     else:
         _cextTransformInterpreter(*args, transform_options)
+
+
+def merge_symbols_into(target, other):
+    """Copies symbols from other into target, renaming private symbols to avoid duplicates. Raises an error if copying would lead to duplicate public symbols."""
+    _cextTransformInterpreter.merge_symbols_into(
+        _unpack_operation(target), _unpack_operation(other)
+    )
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
index 740c49f76a26c4..d3ada7f32d8d59 100644
--- a/mlir/test/python/dialects/transform_interpreter.py
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -54,3 +54,79 @@ def failed():
         assert (
             "must implement TransformOpInterface to be used as transform root" in str(e)
         )
+
+
+print_root_via_include_module = """
+module @print_root_via_include_module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    transform.include @callee2 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}"""
+
+callee2_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
+    transform.include @callee1 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+"""
+
+callee1_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
+    transform.print %root { name = \"from interpreter\" }: !transform.any_op
+    transform.yield
+  }
+}
+"""
+
+
+@test_in_context
+def include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee1 = ir.Module.parse(callee1_definition)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee1)
+    interp.merge_symbols_into(main, callee2)
+
+    # CHECK: @print_root_via_include_module
+    # CHECK: transform.named_sequence @__transform_main
+    # CHECK: transform.include @callee2
+    #
+    # CHECK: transform.named_sequence @callee1
+    # CHECK: transform.print
+    #
+    # CHECK: transform.named_sequence @callee2
+    # CHECK: transform.include @callee1
+    interp.apply_named_sequence(main, main.body.operations[0], main)
+
+
+@test_in_context
+def partial_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee2)
+
+    try:
+        interp.apply_named_sequence(main, main.body.operations[0], main)
+    except ValueError as e:
+        assert "Failed to apply" in str(e)
+
+
+@test_in_context
+def repeated_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee2)
+
+    try:
+        interp.merge_symbols_into(main, callee2)
+    except ValueError as e:
+        assert "doubly defined symbol @callee2" in str(e)

Copy link
Contributor

@stellaraccident stellaraccident left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by: why the copy vs move of operations?

For reference, one of several of these I've written in Python over the years.

@ftynse
Copy link
Member Author

ftynse commented Apr 16, 2024

I don't see how one can implement a move in a reasonable way in Python. It is possible to mark the moved-from operation as invalid at the bindings level and raise an error on the second access, but it isn't something the language naturally has and is therefore surprising. OTOH, the anticipated use case is "load library A, merge symbols from A into script X, merge symbols from B into script Y, ..." so we don't want to consume the library.

@makslevental
Copy link
Contributor

I don't see how one can implement a move in a reasonable way in Python. It is possible to mark the moved-from operation as invalid at the bindings level and raise an error on the second access, but it isn't something the language naturally has and is therefore surprising. OTOH, the anticipated use case is "load library A, merge symbols from A into script X, merge symbols from B into script Y, ..." so we don't want to consume the library.

+1 from me on the lack of move semantics in Python and also the use case. A nit could be "well then it should be named copy_symbols_into" but then you lose the clear/loud indication that things will be renamed.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

This functionality is available in C++, make it available in Python
directly to operate on transform modules.
@ftynse
Copy link
Member Author

ftynse commented Apr 17, 2024

A nit could be "well then it should be named copy_symbols_into" but then you lose the clear/loud indication that things will be renamed.

Renamed to copy_symbols_and_merge_into.

@ftynse ftynse merged commit 73140da into llvm:main Apr 17, 2024
3 of 4 checks passed
@ftynse ftynse deleted the td-merge branch April 17, 2024 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants