Skip to content

[mlir][python] set the registry free #72477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 115 additions & 98 deletions mlir/python/mlir/_mlir_libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,89 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any, Sequence

import importlib
import itertools
import logging
import os
import sys
from typing import Sequence

from ._mlir import ir


_this_dir = os.path.dirname(__file__)

logger = logging.getLogger(__name__)

_path = __path__
_spec = __spec__
_name = __name__


class _M:
__path__ = _path
__spec__ = _spec
__name__ = _name

@staticmethod
def get_lib_dirs() -> Sequence[str]:
"""Gets the lib directory for linking to shared libraries.

def get_lib_dirs() -> Sequence[str]:
"""Gets the lib directory for linking to shared libraries.

On some platforms, the package may need to be built specially to export
development libraries.
"""
return [_this_dir]


def get_include_dirs() -> Sequence[str]:
"""Gets the include directory for compiling against exported C libraries.

Depending on how the package was build, development C libraries may or may
not be present.
"""
return [os.path.join(_this_dir, "include")]


# Perform Python level site initialization. This involves:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
# the pybind memory model (i.e. there is not a Python reference to the object
# in the scope of the base class __init__).
#
# For #1, we:
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# proceeding so long as a module with the name is found.
# b. If the module has a 'register_dialects' attribute, it will be called
# immediately with a DialectRegistry to populate.
# c. If the module has a 'context_init_hook', it will be added to a list
# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
# d. If the module has a 'disable_multithreading' attribute, it will be
# taken as a boolean. If it is True for any initializer, then the
# default behavior of enabling multithreading on the context
# will be suppressed. This complies with the original behavior of all
# contexts being created with multithreading enabled while allowing
# this behavior to be changed if needed (i.e. if a context_init_hook
# explicitly sets up multithreading).
#
# This facility allows downstreams to customize Context creation to their
# needs.
def _site_initialize():
import importlib
import itertools
import logging
from ._mlir import ir

logger = logging.getLogger(__name__)
registry = ir.DialectRegistry()
post_init_hooks = []
disable_multithreading = False

def process_initializer_module(module_name):
nonlocal disable_multithreading
On some platforms, the package may need to be built specially to export
development libraries.
"""
return [_this_dir]

@staticmethod
def get_include_dirs() -> Sequence[str]:
"""Gets the include directory for compiling against exported C libraries.

Depending on how the package was build, development C libraries may or may
not be present.
"""
return [os.path.join(_this_dir, "include")]

# Perform Python level site initialization. This involves:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
# the pybind memory model (i.e. there is not a Python reference to the object
# in the scope of the base class __init__).
#
# For #1, we:
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# proceeding so long as a module with the name is found.
# b. If the module has a 'register_dialects' attribute, it will be called
# immediately with a DialectRegistry to populate.
# c. If the module has a 'context_init_hook', it will be added to a list
# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
# d. If the module has a 'disable_multithreading' attribute, it will be
# taken as a boolean. If it is True for any initializer, then the
# default behavior of enabling multithreading on the context
# will be suppressed. This complies with the original behavior of all
# contexts being created with multithreading enabled while allowing
# this behavior to be changed if needed (i.e. if a context_init_hook
# explicitly sets up multithreading).
#
# This facility allows downstreams to customize Context creation to their
# needs.

__registry = ir.DialectRegistry()
__post_init_hooks = []
__disable_multithreading = False
from . import _mlir as _mlir

def __get_registry(self):
return self.__registry

def process_c_ext_module(self, module_name):
try:
m = importlib.import_module(f".{module_name}", __name__)
m = importlib.import_module(f"{module_name}", __name__)
except ModuleNotFoundError:
return False
except ImportError:
Expand All @@ -84,48 +98,51 @@ def process_initializer_module(module_name):
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(registry)
m.register_dialects(self.__get_registry())
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
self.__post_init_hooks.append(m.context_init_hook)
if hasattr(m, "disable_multithreading"):
if bool(m.disable_multithreading):
logger.debug("Disabling multi-threading for context")
disable_multithreading = True
self.__disable_multithreading = True
return True

# If _mlirRegisterEverything is built, then include it as an initializer
# module.
init_module = None
if process_initializer_module("_mlirRegisterEverything"):
init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)

# Load all _site_initialize_{i} modules, where 'i' is a number starting
# at 0.
for i in itertools.count():
module_name = f"_site_initialize_{i}"
if not process_initializer_module(module_name):
break

class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(registry)
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
self.enable_multithreading(True)
# TODO: There is some debate about whether we should eagerly load
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module
)
init_module.register_llvm_translations(self)
def __init__(self):
# If _mlirRegisterEverything is built, then include it as an initializer
# module.
init_module = None
if self.process_c_ext_module("._mlirRegisterEverything"):
init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)

# Load all _site_initialize_{i} modules, where 'i' is a number starting
# at 0.
for i in itertools.count():
module_name = f"._site_initialize_{i}"
if not self.process_c_ext_module(module_name):
break

that = self

class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(that._M__get_registry())
for hook in that._M__post_init_hooks:
hook(self)
if not that._M__disable_multithreading:
self.enable_multithreading(True)
# TODO: There is some debate about whether we should eagerly load
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module
)
init_module.register_llvm_translations(self)

ir.Context = Context
ir.Context = Context

class MLIRError(Exception):
"""
Expand Down Expand Up @@ -162,4 +179,4 @@ def __str__(self):
ir.MLIRError = MLIRError


_site_initialize()
sys.modules[__name__] = _M()
6 changes: 0 additions & 6 deletions mlir/python/mlir/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,3 @@
TestTensorValue,
TestIntegerRankedTensorType,
)


def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest

_mlirPythonTest.register_python_test_dialect(context, load)
17 changes: 4 additions & 13 deletions mlir/test/python/dialects/python_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir import _mlir_libs

_mlir_libs.process_c_ext_module("mlir._mlir_libs._mlirPythonTest")

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
Expand All @@ -17,7 +21,6 @@ def run(f):
@run
def testAttributes():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
#
# Check op construction with attributes.
#
Expand Down Expand Up @@ -138,7 +141,6 @@ def testAttributes():
@run
def attrBuilder():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
# CHECK: python_test.attributes_op
op = test.AttributesOp(
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
Expand Down Expand Up @@ -215,7 +217,6 @@ def attrBuilder():
@run
def inferReturnTypes():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
op = test.InferResultsOp()
Expand Down Expand Up @@ -260,7 +261,6 @@ def inferReturnTypes():
@run
def resultTypesDefinedByTraits():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
inferred = test.InferResultsOp()
Expand Down Expand Up @@ -295,7 +295,6 @@ def resultTypesDefinedByTraits():
@run
def testOptionalOperandOp():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

module = Module.create()
with InsertionPoint(module.body):
Expand All @@ -312,7 +311,6 @@ def testOptionalOperandOp():
@run
def testCustomAttribute():
with Context() as ctx:
test.register_python_test_dialect(ctx)
a = test.TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
Expand Down Expand Up @@ -350,7 +348,6 @@ def testCustomAttribute():
@run
def testCustomType():
with Context() as ctx:
test.register_python_test_dialect(ctx)
a = test.TestType.get()
# CHECK: !python_test.test_type
print(a)
Expand Down Expand Up @@ -397,8 +394,6 @@ def testCustomType():
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

i8 = IntegerType.get_signless(8)

class Tensor(test.TestTensorValue):
Expand Down Expand Up @@ -436,7 +431,6 @@ def __str__(self):
@run
def inferReturnTypeComponents():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
Expand Down Expand Up @@ -488,8 +482,6 @@ def inferReturnTypeComponents():
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)

a = test.TestType.get()
assert a.typeid is not None

Expand Down Expand Up @@ -542,7 +534,6 @@ def type_caster(pytype):
@run
def testInferTypeOpInterface():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
i64 = IntegerType.get_signless(64)
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/python/lib/PythonTestModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
},
py::arg("context"), py::arg("load") = true);

m.def(
"register_dialects",
[](MlirDialectRegistry registry) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
mlirDialectHandleInsertDialect(pythonTestDialect, registry);
},
py::arg("registry"));

mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute)
.def_classmethod(
Expand Down