From 87f8596229eb97195176b676764597568f269dd4 Mon Sep 17 00:00:00 2001 From: max Date: Thu, 2 Nov 2023 21:41:54 -0500 Subject: [PATCH] [mlir][python] C++ API demo --- mlir/python/CMakeLists.txt | 2 + mlir/python/mlir/dialects/python_test.py | 6 +++ mlir/test/python/dialects/python_test.py | 35 +++++++++++++++ mlir/test/python/lib/CMakeLists.txt | 2 +- mlir/test/python/lib/PythonTestModule.cpp | 6 +++ mlir/test/python/lib/PythonTestPass.cpp | 53 +++++++++++++++++++++++ mlir/test/python/lib/PythonTestPass.h | 16 +++++++ 7 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 mlir/test/python/lib/PythonTestPass.cpp create mode 100644 mlir/test/python/lib/PythonTestPass.h diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 88e6e13602d29..e45f55e038106 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -607,11 +607,13 @@ if(MLIR_INCLUDE_TESTS) ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" SOURCES PythonTestModule.cpp + PythonTestPass.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPIPythonTestDialect ) + set_source_files_properties(${MLIR_SOURCE_DIR}/test/python/lib/PythonTestPass.cpp PROPERTIES COMPILE_FLAGS -fno-rtti) endif() ################################################################################ diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 6579e02d8549e..401ac260139e0 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -15,3 +15,9 @@ def register_python_test_dialect(context, load=True): from .._mlir_libs import _mlirPythonTest _mlirPythonTest.register_python_test_dialect(context, load) + + +def register_python_test_pass_demo_pass(func): + from .._mlir_libs import _mlirPythonTest + + _mlirPythonTest.register_python_test_pass_demo_pass(func) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 472db7e5124db..39e3a2dc8ee45 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -5,6 +5,7 @@ import mlir.dialects.python_test as test import mlir.dialects.tensor as tensor import mlir.dialects.arith as arith +from mlir.passmanager import PassManager def run(f): @@ -551,3 +552,37 @@ def testInferTypeOpInterface(): two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero) # CHECK: f32 print(two_operands.result.type) + + +# CHECK-LABEL: testPythonPassDemo +@run +def testPythonPassDemo(): + def print_ops(op): + print(op.name) + + module = """ + module { + func.func @main() { + %memref = memref.alloca() : memref<1xi64> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i64 + memref.store %c1, %memref[%c0] : memref<1xi64> + %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64> + return + } + } + """ + + # CHECK: memref.alloca + # CHECK: arith.constant + # CHECK: arith.constant + # CHECK: memref.store + # CHECK: memref.cast + # CHECK: func.return + # CHECK: func.func + # CHECK: builtin.module + with Context() as ctx, Location.unknown(): + test.register_python_test_dialect(ctx) + test.register_python_test_pass_demo_pass(print_ops) + mlir_module = Module.parse(module) + PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation) diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt index d7cbbfbc21477..8354a08e7b713 100644 --- a/mlir/test/python/lib/CMakeLists.txt +++ b/mlir/test/python/lib/CMakeLists.txt @@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES PythonTestCAPI.cpp PythonTestDialect.cpp PythonTestModule.cpp + PythonTestPass.cpp ) add_mlir_library(MLIRPythonTestDialect @@ -29,4 +30,3 @@ add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect MLIRCAPIIR MLIRPythonTestDialect ) - diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp index f533082a0a147..5be16e37abc33 100644 --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "PythonTestCAPI.h" +#include "PythonTestPass.h" + #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/IR.h" @@ -34,6 +36,10 @@ PYBIND11_MODULE(_mlirPythonTest, m) { }, py::arg("context"), py::arg("load") = true); + m.def("register_python_test_pass_demo_pass", [](py::function func) { + registerPythonTestPassDemoPassWithFunc(func.ptr()); + }); + mlir_attribute_subclass(m, "TestAttr", mlirAttributeIsAPythonTestTestAttribute) .def_classmethod( diff --git a/mlir/test/python/lib/PythonTestPass.cpp b/mlir/test/python/lib/PythonTestPass.cpp new file mode 100644 index 0000000000000..ce2b9f34c3c5c --- /dev/null +++ b/mlir/test/python/lib/PythonTestPass.cpp @@ -0,0 +1,53 @@ +//===- PythonTestPassDemo.cpp -----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PythonTestPass.h" + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir/CAPI/IR.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +struct PythonTestPassDemo + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PythonTestPassDemo) + + PythonTestPassDemo(PyObject *func) : func(func) {} + StringRef getArgument() const final { return "python-pass-demo"; } + + void runOnOperation() override { + this->getOperation()->walk([this](Operation *op) { + PyObject *mlirModule = + PyImport_ImportModule(MAKE_MLIR_PYTHON_QUALNAME("ir")); + PyObject *cAPIFactory = PyObject_GetAttrString( + PyObject_GetAttrString(mlirModule, "Operation"), + MLIR_PYTHON_CAPI_FACTORY_ATTR); + PyObject *opApiObject = PyObject_CallFunction( + cAPIFactory, "(O)", mlirPythonOperationToCapsule(wrap(op))); + (void)PyObject_CallFunction(func, "(O)", opApiObject); + Py_DECREF(opApiObject); + }); + } + + PyObject *func; +}; + +std::unique_ptr> +createPythonTestPassDemoPassWithFunc(PyObject *func) { + return std::make_unique(func); +} + +} // namespace + +void registerPythonTestPassDemoPassWithFunc(PyObject *func) { + registerPass([func]() { return createPythonTestPassDemoPassWithFunc(func); }); +} diff --git a/mlir/test/python/lib/PythonTestPass.h b/mlir/test/python/lib/PythonTestPass.h new file mode 100644 index 0000000000000..4df4f965857ed --- /dev/null +++ b/mlir/test/python/lib/PythonTestPass.h @@ -0,0 +1,16 @@ +//===- PythonTestPassDemo.h ---------------------------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_PYTHON_PASS_PYTHONTESTCAPI_H +#define MLIR_TEST_PYTHON_PASS_PYTHONTESTCAPI_H + +#include + +void registerPythonTestPassDemoPassWithFunc(PyObject *func); + +#endif