@@ -56,110 +56,114 @@ def get_include_dirs() -> Sequence[str]:
56
56
#
57
57
# This facility allows downstreams to customize Context creation to their
58
58
# needs.
59
- def _site_initialize ():
60
- import importlib
61
- import itertools
62
- import logging
63
- from ._mlir import ir
64
-
65
- logger = logging .getLogger (__name__ )
66
- registry = ir .DialectRegistry ()
67
- post_init_hooks = []
68
- disable_multithreading = False
69
-
70
- def process_initializer_module (module_name ):
71
- nonlocal disable_multithreading
72
- try :
73
- m = importlib .import_module (f".{ module_name } " , __name__ )
74
- except ModuleNotFoundError :
75
- return False
76
- except ImportError :
77
- message = (
78
- f"Error importing mlir initializer { module_name } . This may "
79
- "happen in unclean incremental builds but is likely a real bug if "
80
- "encountered otherwise and the MLIR Python API may not function."
59
+ import importlib
60
+ import itertools
61
+ import logging
62
+ from ._mlir import ir
63
+
64
+ logger = logging .getLogger (__name__ )
65
+ registry = ir .DialectRegistry ()
66
+ post_init_hooks = []
67
+ disable_multithreading = False
68
+
69
+
70
+ def get_registry ():
71
+ return registry
72
+
73
+
74
+ def process_initializer_module (module_name ):
75
+ global disable_multithreading
76
+ try :
77
+ m = importlib .import_module (f".{ module_name } " , __name__ )
78
+ except ModuleNotFoundError :
79
+ return False
80
+ except ImportError :
81
+ message = (
82
+ f"Error importing mlir initializer { module_name } . This may "
83
+ "happen in unclean incremental builds but is likely a real bug if "
84
+ "encountered otherwise and the MLIR Python API may not function."
85
+ )
86
+ logger .warning (message , exc_info = True )
87
+
88
+ logger .debug ("Initializing MLIR with module: %s" , module_name )
89
+ if hasattr (m , "register_dialects" ):
90
+ logger .debug ("Registering dialects from initializer %r" , m )
91
+ m .register_dialects (registry )
92
+ if hasattr (m , "context_init_hook" ):
93
+ logger .debug ("Adding context init hook from %r" , m )
94
+ post_init_hooks .append (m .context_init_hook )
95
+ if hasattr (m , "disable_multithreading" ):
96
+ if bool (m .disable_multithreading ):
97
+ logger .debug ("Disabling multi-threading for context" )
98
+ disable_multithreading = True
99
+ return True
100
+
101
+
102
+ # If _mlirRegisterEverything is built, then include it as an initializer
103
+ # module.
104
+ init_module = None
105
+ if process_initializer_module ("_mlirRegisterEverything" ):
106
+ init_module = importlib .import_module (f"._mlirRegisterEverything" , __name__ )
107
+
108
+ # Load all _site_initialize_{i} modules, where 'i' is a number starting
109
+ # at 0.
110
+ for i in itertools .count ():
111
+ module_name = f"_site_initialize_{ i } "
112
+ if not process_initializer_module (module_name ):
113
+ break
114
+
115
+
116
+ class Context (ir ._BaseContext ):
117
+ def __init__ (self , * args , ** kwargs ):
118
+ super ().__init__ (* args , ** kwargs )
119
+ self .append_dialect_registry (get_registry ())
120
+ for hook in post_init_hooks :
121
+ hook (self )
122
+ if not disable_multithreading :
123
+ self .enable_multithreading (True )
124
+ # TODO: There is some debate about whether we should eagerly load
125
+ # all dialects. It is being done here in order to preserve existing
126
+ # behavior. See: https://github.com/llvm/llvm-project/issues/56037
127
+ self .load_all_available_dialects ()
128
+ if init_module :
129
+ logger .debug ("Registering translations from initializer %r" , init_module )
130
+ init_module .register_llvm_translations (self )
131
+
132
+
133
+ ir .Context = Context
134
+
135
+
136
+ class MLIRError (Exception ):
137
+ """
138
+ An exception with diagnostic information. Has the following fields:
139
+ message: str
140
+ error_diagnostics: List[ir.DiagnosticInfo]
141
+ """
142
+
143
+ def __init__ (self , message , error_diagnostics ):
144
+ self .message = message
145
+ self .error_diagnostics = error_diagnostics
146
+ super ().__init__ (message , error_diagnostics )
147
+
148
+ def __str__ (self ):
149
+ s = self .message
150
+ if self .error_diagnostics :
151
+ s += ":"
152
+ for diag in self .error_diagnostics :
153
+ s += (
154
+ "\n error: "
155
+ + str (diag .location )[4 :- 1 ]
156
+ + ": "
157
+ + diag .message .replace ("\n " , "\n " )
81
158
)
82
- logger .warning (message , exc_info = True )
83
-
84
- logger .debug ("Initializing MLIR with module: %s" , module_name )
85
- if hasattr (m , "register_dialects" ):
86
- logger .debug ("Registering dialects from initializer %r" , m )
87
- m .register_dialects (registry )
88
- if hasattr (m , "context_init_hook" ):
89
- logger .debug ("Adding context init hook from %r" , m )
90
- post_init_hooks .append (m .context_init_hook )
91
- if hasattr (m , "disable_multithreading" ):
92
- if bool (m .disable_multithreading ):
93
- logger .debug ("Disabling multi-threading for context" )
94
- disable_multithreading = True
95
- return True
96
-
97
- # If _mlirRegisterEverything is built, then include it as an initializer
98
- # module.
99
- init_module = None
100
- if process_initializer_module ("_mlirRegisterEverything" ):
101
- init_module = importlib .import_module (f"._mlirRegisterEverything" , __name__ )
102
-
103
- # Load all _site_initialize_{i} modules, where 'i' is a number starting
104
- # at 0.
105
- for i in itertools .count ():
106
- module_name = f"_site_initialize_{ i } "
107
- if not process_initializer_module (module_name ):
108
- break
109
-
110
- class Context (ir ._BaseContext ):
111
- def __init__ (self , * args , ** kwargs ):
112
- super ().__init__ (* args , ** kwargs )
113
- self .append_dialect_registry (registry )
114
- for hook in post_init_hooks :
115
- hook (self )
116
- if not disable_multithreading :
117
- self .enable_multithreading (True )
118
- # TODO: There is some debate about whether we should eagerly load
119
- # all dialects. It is being done here in order to preserve existing
120
- # behavior. See: https://github.com/llvm/llvm-project/issues/56037
121
- self .load_all_available_dialects ()
122
- if init_module :
123
- logger .debug (
124
- "Registering translations from initializer %r" , init_module
125
- )
126
- init_module .register_llvm_translations (self )
127
-
128
- ir .Context = Context
129
-
130
- class MLIRError (Exception ):
131
- """
132
- An exception with diagnostic information. Has the following fields:
133
- message: str
134
- error_diagnostics: List[ir.DiagnosticInfo]
135
- """
136
-
137
- def __init__ (self , message , error_diagnostics ):
138
- self .message = message
139
- self .error_diagnostics = error_diagnostics
140
- super ().__init__ (message , error_diagnostics )
141
-
142
- def __str__ (self ):
143
- s = self .message
144
- if self .error_diagnostics :
145
- s += ":"
146
- for diag in self .error_diagnostics :
159
+ for note in diag .notes :
147
160
s += (
148
- "\n error : "
149
- + str (diag .location )[4 :- 1 ]
161
+ "\n note : "
162
+ + str (note .location )[4 :- 1 ]
150
163
+ ": "
151
- + diag .message .replace ("\n " , "\n " )
164
+ + note .message .replace ("\n " , "\n " )
152
165
)
153
- for note in diag .notes :
154
- s += (
155
- "\n note: "
156
- + str (note .location )[4 :- 1 ]
157
- + ": "
158
- + note .message .replace ("\n " , "\n " )
159
- )
160
- return s
161
-
162
- ir .MLIRError = MLIRError
166
+ return s
163
167
164
168
165
- _site_initialize ()
169
+ ir . MLIRError = MLIRError
0 commit comments