Skip to content

[mlir][python] C++ API demo #71133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Nov 3, 2023

This PR demonstrates how to call C++ APIs directly without going through the C API and without RTTI.

In fact it demonstrates something more exciting: tunneling python callbacks all the way into MLIR passes:

def print_ops(op):
    print(op.name)

test.register_python_test_pass_demo_pass(print_ops)

module = """
module {
  func.func @main() {
    %memref = memref.alloca() : memref<1xi64>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : i64
    memref.store %c1, %memref[%c0] : memref<1xi64>
    %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
    return
  }
}
"""

mlir_module = Module.parse(module)
PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)

This will print

memref.alloca
arith.constant
arith.constant
memref.store
memref.cast
func.return
func.func
builtin.module

by calling print_ops from inside the pass (where there's a this->getOperation()->walk([this](Operation *op) { ... });).

Note, there's some boilerplate Python C API munging for going between MLIR-C API types (like MlirOperation) and Python API objects (like ir.Operation) that can be polished up for public consumption (maybe to live in mlir-c/Bindings/Python/Interop.h) but before typing all those words I wanted to gauge interest.

Also, I probably need some more Py_DECREFs but I'll need to run this through ASAN to be sure.

@makslevental makslevental marked this pull request as ready for review November 3, 2023 02:46
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Nov 3, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 3, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

This PR demonstrates how to call C++ APIs directly without going through the C API and without RTTI.

In fact it demonstrates something more exciting: tunneling python callbacks all the way into MLIR passes:

def print_ops(op):
    print(op.name)

test.register_python_test_pass_demo_pass(print_ops)

module = """
module {
  func.func @<!-- -->main() {
    %memref = memref.alloca() : memref&lt;1xi64&gt;
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : i64
    memref.store %c1, %memref[%c0] : memref&lt;1xi64&gt;
    %u_memref = memref.cast %memref : memref&lt;1xi64&gt; to memref&lt;*xi64&gt;
    return
  }
}
"""

mlir_module = Module.parse(module)
PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)

This will print

memref.alloca
arith.constant
arith.constant
memref.store
memref.cast
func.return
func.func
builtin.module

by calling print_ops from inside the pass (where there's a this-&gt;getOperation()-&gt;walk([this](Operation *op) { ... });).


Full diff: https://github.com/llvm/llvm-project/pull/71133.diff

7 Files Affected:

  • (modified) mlir/python/CMakeLists.txt (+2)
  • (modified) mlir/python/mlir/dialects/python_test.py (+6)
  • (modified) mlir/test/python/dialects/python_test.py (+35)
  • (modified) mlir/test/python/lib/CMakeLists.txt (+1-1)
  • (modified) mlir/test/python/lib/PythonTestModule.cpp (+6)
  • (added) mlir/test/python/lib/PythonTestPass.cpp (+53)
  • (added) mlir/test/python/lib/PythonTestPass.h (+16)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 88e6e13602d291a..e45f55e0381063d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -607,11 +607,13 @@ if(MLIR_INCLUDE_TESTS)
     ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
     SOURCES
       PythonTestModule.cpp
+      PythonTestPass.cpp
     PRIVATE_LINK_LIBS
       LLVMSupport
     EMBED_CAPI_LINK_LIBS
       MLIRCAPIPythonTestDialect
   )
+  set_source_files_properties(${MLIR_SOURCE_DIR}/test/python/lib/PythonTestPass.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
 endif()
 
 ################################################################################
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 6579e02d8549efa..401ac260139e0a1 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -15,3 +15,9 @@ def register_python_test_dialect(context, load=True):
     from .._mlir_libs import _mlirPythonTest
 
     _mlirPythonTest.register_python_test_dialect(context, load)
+
+
+def register_python_test_pass_demo_pass(func):
+    from .._mlir_libs import _mlirPythonTest
+
+    _mlirPythonTest.register_python_test_pass_demo_pass(func)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 472db7e5124dbed..39e3a2dc8ee45ca 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -5,6 +5,7 @@
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
 import mlir.dialects.arith as arith
+from mlir.passmanager import PassManager
 
 
 def run(f):
@@ -551,3 +552,37 @@ def testInferTypeOpInterface():
             two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
             # CHECK: f32
             print(two_operands.result.type)
+
+
+# CHECK-LABEL: testPythonPassDemo
+@run
+def testPythonPassDemo():
+    def print_ops(op):
+        print(op.name)
+
+    module = """
+    module {
+      func.func @main() {
+        %memref = memref.alloca() : memref<1xi64>
+        %c0 = arith.constant 0 : index
+        %c1 = arith.constant 1 : i64
+        memref.store %c1, %memref[%c0] : memref<1xi64>
+        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
+        return
+      }
+    }
+    """
+
+    # CHECK: memref.alloca
+    # CHECK: arith.constant
+    # CHECK: arith.constant
+    # CHECK: memref.store
+    # CHECK: memref.cast
+    # CHECK: func.return
+    # CHECK: func.func
+    # CHECK: builtin.module
+    with Context() as ctx, Location.unknown():
+        test.register_python_test_dialect(ctx)
+        test.register_python_test_pass_demo_pass(print_ops)
+        mlir_module = Module.parse(module)
+        PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)
diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt
index d7cbbfbc214772b..8354a08e7b7139f 100644
--- a/mlir/test/python/lib/CMakeLists.txt
+++ b/mlir/test/python/lib/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   PythonTestCAPI.cpp
   PythonTestDialect.cpp
   PythonTestModule.cpp
+  PythonTestPass.cpp
 )
 
 add_mlir_library(MLIRPythonTestDialect
@@ -29,4 +30,3 @@ add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect
   MLIRCAPIIR
   MLIRPythonTestDialect
 )
-
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f533082a0a147c0..5be16e37abc338c 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "PythonTestCAPI.h"
+#include "PythonTestPass.h"
+
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/IR.h"
@@ -34,6 +36,10 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
       },
       py::arg("context"), py::arg("load") = true);
 
+  m.def("register_python_test_pass_demo_pass", [](py::function func) {
+    registerPythonTestPassDemoPassWithFunc(func.ptr());
+  });
+
   mlir_attribute_subclass(m, "TestAttr",
                           mlirAttributeIsAPythonTestTestAttribute)
       .def_classmethod(
diff --git a/mlir/test/python/lib/PythonTestPass.cpp b/mlir/test/python/lib/PythonTestPass.cpp
new file mode 100644
index 000000000000000..ce2b9f34c3c5c3d
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestPass.cpp
@@ -0,0 +1,53 @@
+//===- PythonTestPassDemo.cpp -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestPass.h"
+
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct PythonTestPassDemo
+    : public PassWrapper<PythonTestPassDemo, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PythonTestPassDemo)
+
+  PythonTestPassDemo(PyObject *func) : func(func) {}
+  StringRef getArgument() const final { return "python-pass-demo"; }
+
+  void runOnOperation() override {
+    this->getOperation()->walk([this](Operation *op) {
+      PyObject *mlirModule =
+          PyImport_ImportModule(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+      PyObject *cAPIFactory = PyObject_GetAttrString(
+          PyObject_GetAttrString(mlirModule, "Operation"),
+          MLIR_PYTHON_CAPI_FACTORY_ATTR);
+      PyObject *opApiObject = PyObject_CallFunction(
+          cAPIFactory, "(O)", mlirPythonOperationToCapsule(wrap(op)));
+      (void)PyObject_CallFunction(func, "(O)", opApiObject);
+      Py_DECREF(opApiObject);
+    });
+  }
+
+  PyObject *func;
+};
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createPythonTestPassDemoPassWithFunc(PyObject *func) {
+  return std::make_unique<PythonTestPassDemo>(func);
+}
+
+} // namespace
+
+void registerPythonTestPassDemoPassWithFunc(PyObject *func) {
+  registerPass([func]() { return createPythonTestPassDemoPassWithFunc(func); });
+}
diff --git a/mlir/test/python/lib/PythonTestPass.h b/mlir/test/python/lib/PythonTestPass.h
new file mode 100644
index 000000000000000..4df4f965857eda6
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestPass.h
@@ -0,0 +1,16 @@
+//===- PythonTestPassDemo.h ---------------------------------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_PYTHON_PASS_PYTHONTESTCAPI_H
+#define MLIR_TEST_PYTHON_PASS_PYTHONTESTCAPI_H
+
+#include <Python.h>
+
+void registerPythonTestPassDemoPassWithFunc(PyObject *func);
+
+#endif

@makslevental makslevental requested a review from rkayaith November 3, 2023 02:53
@martin-luecke
Copy link
Contributor

Interesting demonstration! I think this is useful to have.

test.register_python_test_dialect(ctx)
test.register_python_test_pass_demo_pass(print_ops)
mlir_module = Module.parse(module)
PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is intended as a demo, can you make it into its own file, with extensive documentation: "example style"

PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIPythonTestDialect
)
set_source_files_properties(${MLIR_SOURCE_DIR}/test/python/lib/PythonTestPass.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default, anything linked into a C extension module that's being built using pybind cmake helpers. The problem is PythonPass, which inherits from mlir::Pass, will then "depend" on the RTTI for mlir::Pass1. So you can make this example work by compiling just PythonPass.cpp without RTTI2.

Footnotes

  1. See this godbolt where the typeinfo for B dereferences the typeinfo for A even though that typeinfo doesn't appear anywhere.

  2. Or you can make this example work by building MLIR with ENABLE_RTTI=ON but that's a non-starter for us (right...?).

# CHECK: builtin.module
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
test.register_python_test_pass_demo_pass(print_ops)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like a dangerous pattern to me: what happens when print_ops goes out of scope? (I mean it can't here, but you're registering a PyObject so it could...)

Also can you write in the same file a second run of the pipeline with a different print_ops which would print with a prefix for example? I suspect the registration being global that won't work...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like a dangerous pattern to me: what happens when print_ops goes out of scope? (I mean it can't here, but you're registering a PyObject so it could...)

The bandaid is to put print_ops into threading.local() right?

Also can you write in the same file a second run of the pipeline with a different print_ops which would print with a prefix for example? I suspect the registration being global that won't work...

Yea sure but you can't register a single pass (in asserts mode...) multiple times anyway so we shouldn't expect that to work "afortiori".

@jpienaar
Copy link
Member

jpienaar commented Nov 6, 2023

The place where I think this could be a nice example is in the "to graphviz" pass: make it easy to customize shapes from Python :)

@ftynse
Copy link
Member

ftynse commented Nov 6, 2023

This PR demonstrates how to call C++ APIs directly without going through the C API and without RTTI.

FTR, the bindings are based on C API rather than C++ directly for API stability reasons, not because we couldn't make it work with C++.

@@ -607,11 +607,13 @@ if(MLIR_INCLUDE_TESTS)
ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
SOURCES
PythonTestModule.cpp
PythonTestPass.cpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pulls in mlir c++ libs (built-in and pass) but it's not depending on them here. It appears they just happen to be resolved via MLIRCAPIPythonTestDialect.

I don't think this is going to work, and there is not an easy fix. This is relying in the capi lib incidentally exporting symbols for its backing c++ API. That is dependent on how the project is built: it happens to be true in this dev tree but will not work with hidden visibility, like real packages use. The other option you may be tempted to do is add the things you need to PRIVATE_LINK_LIBS, but that is also fraught. Not only will it duplicate the backing library code in the extension, TypeID linkage will no longer be single, strongly rooted. This will create the dreaded vague linkage issues on Linux and has no path to work on Windows.

I don't have a good suggestion: this is why we did not get adventurous on some of this stuff.

@makslevental
Copy link
Contributor Author

makslevental commented Nov 6, 2023

@joker-eph @stellaraccident

Regarding footguns

This pulls in mlir c++ libs (built-in and pass) but it's not depending on them here. It appears they just happen to be resolved via MLIRCAPIPythonTestDialect.

Hence the name of the branch being hold_my_beer 😄. I fully admit this isn't for the faint of heart but the functionality implied is really very "powerful" isn't it? Build a pass to do some rewrite as a function of say calling networkx or an ILP solver but you don't need to link anything at runtime or compile, just munge some Python. Anyway it doesn't matter to me if it doesn't get merged. The PR basically already serves the purpose of being a minimal POC.

but will not work with hidden visibility, like real packages use.

This one I thought of and I will need to solve in the other places I intend to use this - my thought is to statically link everything (the aggregate and the extension modules) into one shared object (like Triton does) and then private symbols aren't a problem.

Not only will it duplicate the backing library code in the extension, TypeID linkage will no longer be single, strongly rooted. This will create the dreaded vague linkage issues on Linux and has no path to work on Windows.

Yes I've tried this and of course you get that problem and it shows up as "duplicate registrations" for things dialects and passes and etc. Again my thought is keeping everything in shared object file solves these problems but maybe I'm wrong.

@makslevental
Copy link
Contributor Author

makslevental commented Nov 6, 2023

@ftynse

FTR, the bindings are based on C API rather than C++ directly for API stability reasons, not because we couldn't make it work with C++.

Upstream (i.e., MLIR) API stability aside I'm curious: did you guys consider Custom automatic downcasters, which seems designed with us exactly in mind:

Sometimes, you might want to provide this automatic downcasting behavior when creating bindings for a class hierarchy that does not use standard C++ polymorphism, such as LLVM [3].

where the footnote is a link to https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html. This suggests we could eschew RTTI in the bindings entirely and get closer affinity with the canonical (C++) implementations...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants