Skip to content

Commit 87f8596

Browse files
committed
[mlir][python] C++ API demo
1 parent 801a30a commit 87f8596

File tree

7 files changed

+119
-1
lines changed

7 files changed

+119
-1
lines changed

mlir/python/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -607,11 +607,13 @@ if(MLIR_INCLUDE_TESTS)
607607
ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
608608
SOURCES
609609
PythonTestModule.cpp
610+
PythonTestPass.cpp
610611
PRIVATE_LINK_LIBS
611612
LLVMSupport
612613
EMBED_CAPI_LINK_LIBS
613614
MLIRCAPIPythonTestDialect
614615
)
616+
set_source_files_properties(${MLIR_SOURCE_DIR}/test/python/lib/PythonTestPass.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
615617
endif()
616618

617619
################################################################################

mlir/python/mlir/dialects/python_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,9 @@ def register_python_test_dialect(context, load=True):
1515
from .._mlir_libs import _mlirPythonTest
1616

1717
_mlirPythonTest.register_python_test_dialect(context, load)
18+
19+
20+
def register_python_test_pass_demo_pass(func):
21+
from .._mlir_libs import _mlirPythonTest
22+
23+
_mlirPythonTest.register_python_test_pass_demo_pass(func)

mlir/test/python/dialects/python_test.py

+35
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import mlir.dialects.python_test as test
66
import mlir.dialects.tensor as tensor
77
import mlir.dialects.arith as arith
8+
from mlir.passmanager import PassManager
89

910

1011
def run(f):
@@ -551,3 +552,37 @@ def testInferTypeOpInterface():
551552
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
552553
# CHECK: f32
553554
print(two_operands.result.type)
555+
556+
557+
# CHECK-LABEL: testPythonPassDemo
558+
@run
559+
def testPythonPassDemo():
560+
def print_ops(op):
561+
print(op.name)
562+
563+
module = """
564+
module {
565+
func.func @main() {
566+
%memref = memref.alloca() : memref<1xi64>
567+
%c0 = arith.constant 0 : index
568+
%c1 = arith.constant 1 : i64
569+
memref.store %c1, %memref[%c0] : memref<1xi64>
570+
%u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
571+
return
572+
}
573+
}
574+
"""
575+
576+
# CHECK: memref.alloca
577+
# CHECK: arith.constant
578+
# CHECK: arith.constant
579+
# CHECK: memref.store
580+
# CHECK: memref.cast
581+
# CHECK: func.return
582+
# CHECK: func.func
583+
# CHECK: builtin.module
584+
with Context() as ctx, Location.unknown():
585+
test.register_python_test_dialect(ctx)
586+
test.register_python_test_pass_demo_pass(print_ops)
587+
mlir_module = Module.parse(module)
588+
PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)

mlir/test/python/lib/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
22
PythonTestCAPI.cpp
33
PythonTestDialect.cpp
44
PythonTestModule.cpp
5+
PythonTestPass.cpp
56
)
67

78
add_mlir_library(MLIRPythonTestDialect
@@ -29,4 +30,3 @@ add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect
2930
MLIRCAPIIR
3031
MLIRPythonTestDialect
3132
)
32-

mlir/test/python/lib/PythonTestModule.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "PythonTestCAPI.h"
10+
#include "PythonTestPass.h"
11+
1012
#include "mlir-c/BuiltinAttributes.h"
1113
#include "mlir-c/BuiltinTypes.h"
1214
#include "mlir-c/IR.h"
@@ -34,6 +36,10 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
3436
},
3537
py::arg("context"), py::arg("load") = true);
3638

39+
m.def("register_python_test_pass_demo_pass", [](py::function func) {
40+
registerPythonTestPassDemoPassWithFunc(func.ptr());
41+
});
42+
3743
mlir_attribute_subclass(m, "TestAttr",
3844
mlirAttributeIsAPythonTestTestAttribute)
3945
.def_classmethod(
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- PythonTestPassDemo.cpp -----------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PythonTestPass.h"
10+
11+
#include "mlir-c/Bindings/Python/Interop.h"
12+
#include "mlir/CAPI/IR.h"
13+
#include "mlir/IR/BuiltinDialect.h"
14+
#include "mlir/Pass/Pass.h"
15+
16+
using namespace mlir;
17+
18+
namespace {
19+
20+
struct PythonTestPassDemo
21+
: public PassWrapper<PythonTestPassDemo, OperationPass<ModuleOp>> {
22+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PythonTestPassDemo)
23+
24+
PythonTestPassDemo(PyObject *func) : func(func) {}
25+
StringRef getArgument() const final { return "python-pass-demo"; }
26+
27+
void runOnOperation() override {
28+
this->getOperation()->walk([this](Operation *op) {
29+
PyObject *mlirModule =
30+
PyImport_ImportModule(MAKE_MLIR_PYTHON_QUALNAME("ir"));
31+
PyObject *cAPIFactory = PyObject_GetAttrString(
32+
PyObject_GetAttrString(mlirModule, "Operation"),
33+
MLIR_PYTHON_CAPI_FACTORY_ATTR);
34+
PyObject *opApiObject = PyObject_CallFunction(
35+
cAPIFactory, "(O)", mlirPythonOperationToCapsule(wrap(op)));
36+
(void)PyObject_CallFunction(func, "(O)", opApiObject);
37+
Py_DECREF(opApiObject);
38+
});
39+
}
40+
41+
PyObject *func;
42+
};
43+
44+
std::unique_ptr<OperationPass<ModuleOp>>
45+
createPythonTestPassDemoPassWithFunc(PyObject *func) {
46+
return std::make_unique<PythonTestPassDemo>(func);
47+
}
48+
49+
} // namespace
50+
51+
void registerPythonTestPassDemoPassWithFunc(PyObject *func) {
52+
registerPass([func]() { return createPythonTestPassDemoPassWithFunc(func); });
53+
}

mlir/test/python/lib/PythonTestPass.h

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//===- PythonTestPassDemo.h ---------------------------------------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TEST_PYTHON_PASS_PYTHONTESTCAPI_H
10+
#define MLIR_TEST_PYTHON_PASS_PYTHONTESTCAPI_H
11+
12+
#include <Python.h>
13+
14+
void registerPythonTestPassDemoPassWithFunc(PyObject *func);
15+
16+
#endif

0 commit comments

Comments
 (0)