diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 6ce77b4cb93f6..f5dfd1edf5a3e 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -2,75 +2,89 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Sequence - +import importlib +import itertools +import logging import os +import sys +from typing import Sequence + +from ._mlir import ir + _this_dir = os.path.dirname(__file__) +logger = logging.getLogger(__name__) + +_path = __path__ +_spec = __spec__ +_name = __name__ + + +class _M: + __path__ = _path + __spec__ = _spec + __name__ = _name + + @staticmethod + def get_lib_dirs() -> Sequence[str]: + """Gets the lib directory for linking to shared libraries. -def get_lib_dirs() -> Sequence[str]: - """Gets the lib directory for linking to shared libraries. - - On some platforms, the package may need to be built specially to export - development libraries. - """ - return [_this_dir] - - -def get_include_dirs() -> Sequence[str]: - """Gets the include directory for compiling against exported C libraries. - - Depending on how the package was build, development C libraries may or may - not be present. - """ - return [os.path.join(_this_dir, "include")] - - -# Perform Python level site initialization. This involves: -# 1. Attempting to load initializer modules, specific to the distribution. -# 2. Defining the concrete mlir.ir.Context that does site specific -# initialization. -# -# Aside from just being far more convenient to do this at the Python level, -# it is actually quite hard/impossible to have such __init__ hooks, given -# the pybind memory model (i.e. there is not a Python reference to the object -# in the scope of the base class __init__). -# -# For #1, we: -# a. Probe for modules named '_mlirRegisterEverything' and -# '_site_initialize_{i}', where 'i' is a number starting at zero and -# proceeding so long as a module with the name is found. -# b. If the module has a 'register_dialects' attribute, it will be called -# immediately with a DialectRegistry to populate. -# c. If the module has a 'context_init_hook', it will be added to a list -# of callbacks that are invoked as the last step of Context -# initialization (and passed the Context under construction). -# d. If the module has a 'disable_multithreading' attribute, it will be -# taken as a boolean. If it is True for any initializer, then the -# default behavior of enabling multithreading on the context -# will be suppressed. This complies with the original behavior of all -# contexts being created with multithreading enabled while allowing -# this behavior to be changed if needed (i.e. if a context_init_hook -# explicitly sets up multithreading). -# -# This facility allows downstreams to customize Context creation to their -# needs. -def _site_initialize(): - import importlib - import itertools - import logging - from ._mlir import ir - - logger = logging.getLogger(__name__) - registry = ir.DialectRegistry() - post_init_hooks = [] - disable_multithreading = False - - def process_initializer_module(module_name): - nonlocal disable_multithreading + On some platforms, the package may need to be built specially to export + development libraries. + """ + return [_this_dir] + + @staticmethod + def get_include_dirs() -> Sequence[str]: + """Gets the include directory for compiling against exported C libraries. + + Depending on how the package was build, development C libraries may or may + not be present. + """ + return [os.path.join(_this_dir, "include")] + + # Perform Python level site initialization. This involves: + # 1. Attempting to load initializer modules, specific to the distribution. + # 2. Defining the concrete mlir.ir.Context that does site specific + # initialization. + # + # Aside from just being far more convenient to do this at the Python level, + # it is actually quite hard/impossible to have such __init__ hooks, given + # the pybind memory model (i.e. there is not a Python reference to the object + # in the scope of the base class __init__). + # + # For #1, we: + # a. Probe for modules named '_mlirRegisterEverything' and + # '_site_initialize_{i}', where 'i' is a number starting at zero and + # proceeding so long as a module with the name is found. + # b. If the module has a 'register_dialects' attribute, it will be called + # immediately with a DialectRegistry to populate. + # c. If the module has a 'context_init_hook', it will be added to a list + # of callbacks that are invoked as the last step of Context + # initialization (and passed the Context under construction). + # d. If the module has a 'disable_multithreading' attribute, it will be + # taken as a boolean. If it is True for any initializer, then the + # default behavior of enabling multithreading on the context + # will be suppressed. This complies with the original behavior of all + # contexts being created with multithreading enabled while allowing + # this behavior to be changed if needed (i.e. if a context_init_hook + # explicitly sets up multithreading). + # + # This facility allows downstreams to customize Context creation to their + # needs. + + __registry = ir.DialectRegistry() + __post_init_hooks = [] + __disable_multithreading = False + from . import _mlir as _mlir + + def __get_registry(self): + return self.__registry + + def process_c_ext_module(self, module_name): try: - m = importlib.import_module(f".{module_name}", __name__) + m = importlib.import_module(f"{module_name}", __name__) except ModuleNotFoundError: return False except ImportError: @@ -84,48 +98,51 @@ def process_initializer_module(module_name): logger.debug("Initializing MLIR with module: %s", module_name) if hasattr(m, "register_dialects"): logger.debug("Registering dialects from initializer %r", m) - m.register_dialects(registry) + m.register_dialects(self.__get_registry()) if hasattr(m, "context_init_hook"): logger.debug("Adding context init hook from %r", m) - post_init_hooks.append(m.context_init_hook) + self.__post_init_hooks.append(m.context_init_hook) if hasattr(m, "disable_multithreading"): if bool(m.disable_multithreading): logger.debug("Disabling multi-threading for context") - disable_multithreading = True + self.__disable_multithreading = True return True - # If _mlirRegisterEverything is built, then include it as an initializer - # module. - init_module = None - if process_initializer_module("_mlirRegisterEverything"): - init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) - - # Load all _site_initialize_{i} modules, where 'i' is a number starting - # at 0. - for i in itertools.count(): - module_name = f"_site_initialize_{i}" - if not process_initializer_module(module_name): - break - - class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.append_dialect_registry(registry) - for hook in post_init_hooks: - hook(self) - if not disable_multithreading: - self.enable_multithreading(True) - # TODO: There is some debate about whether we should eagerly load - # all dialects. It is being done here in order to preserve existing - # behavior. See: https://github.com/llvm/llvm-project/issues/56037 - self.load_all_available_dialects() - if init_module: - logger.debug( - "Registering translations from initializer %r", init_module - ) - init_module.register_llvm_translations(self) + def __init__(self): + # If _mlirRegisterEverything is built, then include it as an initializer + # module. + init_module = None + if self.process_c_ext_module("._mlirRegisterEverything"): + init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) + + # Load all _site_initialize_{i} modules, where 'i' is a number starting + # at 0. + for i in itertools.count(): + module_name = f"._site_initialize_{i}" + if not self.process_c_ext_module(module_name): + break + + that = self + + class Context(ir._BaseContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(that._M__get_registry()) + for hook in that._M__post_init_hooks: + hook(self) + if not that._M__disable_multithreading: + self.enable_multithreading(True) + # TODO: There is some debate about whether we should eagerly load + # all dialects. It is being done here in order to preserve existing + # behavior. See: https://github.com/llvm/llvm-project/issues/56037 + self.load_all_available_dialects() + if init_module: + logger.debug( + "Registering translations from initializer %r", init_module + ) + init_module.register_llvm_translations(self) - ir.Context = Context + ir.Context = Context class MLIRError(Exception): """ @@ -162,4 +179,4 @@ def __str__(self): ir.MLIRError = MLIRError -_site_initialize() +sys.modules[__name__] = _M() diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 6579e02d8549e..8b4f718d8a53b 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -9,9 +9,3 @@ TestTensorValue, TestIntegerRankedTensorType, ) - - -def register_python_test_dialect(context, load=True): - from .._mlir_libs import _mlirPythonTest - - _mlirPythonTest.register_python_test_dialect(context, load) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index f313a400b73c0..309de8037049c 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1,5 +1,9 @@ # RUN: %PYTHON %s | FileCheck %s +from mlir import _mlir_libs + +_mlir_libs.process_c_ext_module("mlir._mlir_libs._mlirPythonTest") + from mlir.ir import * import mlir.dialects.func as func import mlir.dialects.python_test as test @@ -17,7 +21,6 @@ def run(f): @run def testAttributes(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) # # Check op construction with attributes. # @@ -138,7 +141,6 @@ def testAttributes(): @run def attrBuilder(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) # CHECK: python_test.attributes_op op = test.AttributesOp( # CHECK-DAG: x_affinemap = affine_map<() -> (2)> @@ -215,7 +217,6 @@ def attrBuilder(): @run def inferReturnTypes(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): op = test.InferResultsOp() @@ -260,7 +261,6 @@ def inferReturnTypes(): @run def resultTypesDefinedByTraits(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): inferred = test.InferResultsOp() @@ -295,7 +295,6 @@ def resultTypesDefinedByTraits(): @run def testOptionalOperandOp(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): @@ -312,7 +311,6 @@ def testOptionalOperandOp(): @run def testCustomAttribute(): with Context() as ctx: - test.register_python_test_dialect(ctx) a = test.TestAttr.get() # CHECK: #python_test.test_attr print(a) @@ -350,7 +348,6 @@ def testCustomAttribute(): @run def testCustomType(): with Context() as ctx: - test.register_python_test_dialect(ctx) a = test.TestType.get() # CHECK: !python_test.test_type print(a) @@ -397,8 +394,6 @@ def testCustomType(): # CHECK-LABEL: TEST: testTensorValue def testTensorValue(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) - i8 = IntegerType.get_signless(8) class Tensor(test.TestTensorValue): @@ -436,7 +431,6 @@ def __str__(self): @run def inferReturnTypeComponents(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): @@ -488,8 +482,6 @@ def inferReturnTypeComponents(): @run def testCustomTypeTypeCaster(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) - a = test.TestType.get() assert a.typeid is not None @@ -542,7 +534,6 @@ def type_caster(pytype): @run def testInferTypeOpInterface(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): i64 = IntegerType.get_signless(64) diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp index aff414894cb82..9e7decefa7166 100644 --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) { }, py::arg("context"), py::arg("load") = true); + m.def( + "register_dialects", + [](MlirDialectRegistry registry) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleInsertDialect(pythonTestDialect, registry); + }, + py::arg("registry")); + mlir_attribute_subclass(m, "TestAttr", mlirAttributeIsAPythonTestTestAttribute) .def_classmethod(