Skip to content

Commit 8b6c773

Browse files
committed
[mlir][python] set the registry free
1 parent 3f743fd commit 8b6c773

File tree

1 file changed

+105
-101
lines changed

1 file changed

+105
-101
lines changed

mlir/python/mlir/_mlir_libs/__init__.py

Lines changed: 105 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -56,110 +56,114 @@ def get_include_dirs() -> Sequence[str]:
5656
#
5757
# This facility allows downstreams to customize Context creation to their
5858
# needs.
59-
def _site_initialize():
60-
import importlib
61-
import itertools
62-
import logging
63-
from ._mlir import ir
64-
65-
logger = logging.getLogger(__name__)
66-
registry = ir.DialectRegistry()
67-
post_init_hooks = []
68-
disable_multithreading = False
69-
70-
def process_initializer_module(module_name):
71-
nonlocal disable_multithreading
72-
try:
73-
m = importlib.import_module(f".{module_name}", __name__)
74-
except ModuleNotFoundError:
75-
return False
76-
except ImportError:
77-
message = (
78-
f"Error importing mlir initializer {module_name}. This may "
79-
"happen in unclean incremental builds but is likely a real bug if "
80-
"encountered otherwise and the MLIR Python API may not function."
59+
import importlib
60+
import itertools
61+
import logging
62+
from ._mlir import ir
63+
64+
logger = logging.getLogger(__name__)
65+
registry = ir.DialectRegistry()
66+
post_init_hooks = []
67+
disable_multithreading = False
68+
69+
70+
def get_registry():
71+
return registry
72+
73+
74+
def process_initializer_module(module_name):
75+
global disable_multithreading
76+
try:
77+
m = importlib.import_module(f".{module_name}", __name__)
78+
except ModuleNotFoundError:
79+
return False
80+
except ImportError:
81+
message = (
82+
f"Error importing mlir initializer {module_name}. This may "
83+
"happen in unclean incremental builds but is likely a real bug if "
84+
"encountered otherwise and the MLIR Python API may not function."
85+
)
86+
logger.warning(message, exc_info=True)
87+
88+
logger.debug("Initializing MLIR with module: %s", module_name)
89+
if hasattr(m, "register_dialects"):
90+
logger.debug("Registering dialects from initializer %r", m)
91+
m.register_dialects(registry)
92+
if hasattr(m, "context_init_hook"):
93+
logger.debug("Adding context init hook from %r", m)
94+
post_init_hooks.append(m.context_init_hook)
95+
if hasattr(m, "disable_multithreading"):
96+
if bool(m.disable_multithreading):
97+
logger.debug("Disabling multi-threading for context")
98+
disable_multithreading = True
99+
return True
100+
101+
102+
# If _mlirRegisterEverything is built, then include it as an initializer
103+
# module.
104+
init_module = None
105+
if process_initializer_module("_mlirRegisterEverything"):
106+
init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
107+
108+
# Load all _site_initialize_{i} modules, where 'i' is a number starting
109+
# at 0.
110+
for i in itertools.count():
111+
module_name = f"_site_initialize_{i}"
112+
if not process_initializer_module(module_name):
113+
break
114+
115+
116+
class Context(ir._BaseContext):
117+
def __init__(self, *args, **kwargs):
118+
super().__init__(*args, **kwargs)
119+
self.append_dialect_registry(get_registry())
120+
for hook in post_init_hooks:
121+
hook(self)
122+
if not disable_multithreading:
123+
self.enable_multithreading(True)
124+
# TODO: There is some debate about whether we should eagerly load
125+
# all dialects. It is being done here in order to preserve existing
126+
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
127+
self.load_all_available_dialects()
128+
if init_module:
129+
logger.debug("Registering translations from initializer %r", init_module)
130+
init_module.register_llvm_translations(self)
131+
132+
133+
ir.Context = Context
134+
135+
136+
class MLIRError(Exception):
137+
"""
138+
An exception with diagnostic information. Has the following fields:
139+
message: str
140+
error_diagnostics: List[ir.DiagnosticInfo]
141+
"""
142+
143+
def __init__(self, message, error_diagnostics):
144+
self.message = message
145+
self.error_diagnostics = error_diagnostics
146+
super().__init__(message, error_diagnostics)
147+
148+
def __str__(self):
149+
s = self.message
150+
if self.error_diagnostics:
151+
s += ":"
152+
for diag in self.error_diagnostics:
153+
s += (
154+
"\nerror: "
155+
+ str(diag.location)[4:-1]
156+
+ ": "
157+
+ diag.message.replace("\n", "\n ")
81158
)
82-
logger.warning(message, exc_info=True)
83-
84-
logger.debug("Initializing MLIR with module: %s", module_name)
85-
if hasattr(m, "register_dialects"):
86-
logger.debug("Registering dialects from initializer %r", m)
87-
m.register_dialects(registry)
88-
if hasattr(m, "context_init_hook"):
89-
logger.debug("Adding context init hook from %r", m)
90-
post_init_hooks.append(m.context_init_hook)
91-
if hasattr(m, "disable_multithreading"):
92-
if bool(m.disable_multithreading):
93-
logger.debug("Disabling multi-threading for context")
94-
disable_multithreading = True
95-
return True
96-
97-
# If _mlirRegisterEverything is built, then include it as an initializer
98-
# module.
99-
init_module = None
100-
if process_initializer_module("_mlirRegisterEverything"):
101-
init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
102-
103-
# Load all _site_initialize_{i} modules, where 'i' is a number starting
104-
# at 0.
105-
for i in itertools.count():
106-
module_name = f"_site_initialize_{i}"
107-
if not process_initializer_module(module_name):
108-
break
109-
110-
class Context(ir._BaseContext):
111-
def __init__(self, *args, **kwargs):
112-
super().__init__(*args, **kwargs)
113-
self.append_dialect_registry(registry)
114-
for hook in post_init_hooks:
115-
hook(self)
116-
if not disable_multithreading:
117-
self.enable_multithreading(True)
118-
# TODO: There is some debate about whether we should eagerly load
119-
# all dialects. It is being done here in order to preserve existing
120-
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
121-
self.load_all_available_dialects()
122-
if init_module:
123-
logger.debug(
124-
"Registering translations from initializer %r", init_module
125-
)
126-
init_module.register_llvm_translations(self)
127-
128-
ir.Context = Context
129-
130-
class MLIRError(Exception):
131-
"""
132-
An exception with diagnostic information. Has the following fields:
133-
message: str
134-
error_diagnostics: List[ir.DiagnosticInfo]
135-
"""
136-
137-
def __init__(self, message, error_diagnostics):
138-
self.message = message
139-
self.error_diagnostics = error_diagnostics
140-
super().__init__(message, error_diagnostics)
141-
142-
def __str__(self):
143-
s = self.message
144-
if self.error_diagnostics:
145-
s += ":"
146-
for diag in self.error_diagnostics:
159+
for note in diag.notes:
147160
s += (
148-
"\nerror: "
149-
+ str(diag.location)[4:-1]
161+
"\n note: "
162+
+ str(note.location)[4:-1]
150163
+ ": "
151-
+ diag.message.replace("\n", "\n ")
164+
+ note.message.replace("\n", "\n ")
152165
)
153-
for note in diag.notes:
154-
s += (
155-
"\n note: "
156-
+ str(note.location)[4:-1]
157-
+ ": "
158-
+ note.message.replace("\n", "\n ")
159-
)
160-
return s
161-
162-
ir.MLIRError = MLIRError
166+
return s
163167

164168

165-
_site_initialize()
169+
ir.MLIRError = MLIRError

0 commit comments

Comments
 (0)