Skip to content

Commit 3f1486f

Browse files
committed
Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)"
Breaks on 3.8, rolling back to avoid breakage while fixing. This reverts commit 9dee7c4.
1 parent 1d2eea9 commit 3f1486f

File tree

9 files changed

+16
-649
lines changed

9 files changed

+16
-649
lines changed

mlir/cmake/modules/AddMLIRPython.cmake

+1-20
Original file line numberDiff line numberDiff line change
@@ -668,31 +668,12 @@ function(add_mlir_python_extension libname extname)
668668
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
669669
nanobind_add_module(${libname}
670670
NB_DOMAIN mlir
671-
FREE_THREADED
672671
${ARG_SOURCES}
673672
)
674673

675674
if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
676675
# Avoids warnings from upstream nanobind.
677-
set(nanobind_target "nanobind-static")
678-
if (NOT TARGET ${nanobind_target})
679-
# Get correct nanobind target name: nanobind-static-ft or something else
680-
# It is set by nanobind_add_module function according to the passed options
681-
get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
682-
683-
# Iterate over the list of targets
684-
foreach(target ${all_targets})
685-
# Check if the target name matches the given string
686-
if("${target}" MATCHES "nanobind-")
687-
set(nanobind_target "${target}")
688-
endif()
689-
endforeach()
690-
691-
if (NOT TARGET ${nanobind_target})
692-
message(FATAL_ERROR "Could not find nanobind target to set compile options to")
693-
endif()
694-
endif()
695-
target_compile_options(${nanobind_target}
676+
target_compile_options(nanobind-static
696677
PRIVATE
697678
-Wno-cast-qual
698679
-Wno-zero-length-array

mlir/docs/Bindings/Python.md

-40
Original file line numberDiff line numberDiff line change
@@ -1187,43 +1187,3 @@ or nanobind and
11871187
utilities to connect to the rest of Python API. The bindings can be located in a
11881188
separate module or in the same module as attributes and types, and
11891189
loaded along with the dialect.
1190-
1191-
## Free-threading (No-GIL) support
1192-
1193-
Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/).
1194-
1195-
MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage:
1196-
1197-
```python
1198-
# python3.13t example.py
1199-
import concurrent.futures
1200-
1201-
import mlir.dialects.arith as arith
1202-
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
1203-
1204-
1205-
def func(py_value):
1206-
with Context() as ctx:
1207-
module = Module.create(loc=Location.file("foo.txt", 0, 0))
1208-
1209-
dtype = IntegerType.get_signless(64)
1210-
with InsertionPoint(module.body), Location.name("a"):
1211-
arith.constant(dtype, py_value)
1212-
1213-
return module
1214-
1215-
1216-
num_workers = 8
1217-
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
1218-
futures = []
1219-
for i in range(num_workers):
1220-
futures.append(executor.submit(func, i))
1221-
assert len(list(f.result() for f in futures)) == num_workers
1222-
```
1223-
1224-
The exceptions to the free-threading compatibility:
1225-
- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
1226-
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
1227-
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
1228-
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
1229-
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.

mlir/lib/Bindings/Python/Globals.h

+1-11
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ namespace mlir {
2424
namespace python {
2525

2626
/// Globals that are always accessible once the extension has been initialized.
27-
/// Methods of this class are thread-safe.
2827
class PyGlobals {
2928
public:
3029
PyGlobals();
@@ -38,18 +37,12 @@ class PyGlobals {
3837

3938
/// Get and set the list of parent modules to search for dialect
4039
/// implementation classes.
41-
std::vector<std::string> getDialectSearchPrefixes() {
42-
nanobind::ft_lock_guard lock(mutex);
40+
std::vector<std::string> &getDialectSearchPrefixes() {
4341
return dialectSearchPrefixes;
4442
}
4543
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
46-
nanobind::ft_lock_guard lock(mutex);
4744
dialectSearchPrefixes.swap(newValues);
4845
}
49-
void addDialectSearchPrefix(std::string value) {
50-
nanobind::ft_lock_guard lock(mutex);
51-
dialectSearchPrefixes.push_back(std::move(value));
52-
}
5346

5447
/// Loads a python module corresponding to the given dialect namespace.
5548
/// No-ops if the module has already been loaded or is not found. Raises
@@ -116,9 +109,6 @@ class PyGlobals {
116109

117110
private:
118111
static PyGlobals *instance;
119-
120-
nanobind::ft_mutex mutex;
121-
122112
/// Module name prefixes to search under for dialect implementation modules.
123113
std::vector<std::string> dialectSearchPrefixes;
124114
/// Map of dialect namespace to external dialect class object.

mlir/lib/Bindings/Python/IRCore.cpp

+4-27
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,9 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
243243

244244
/// Wrapper for the global LLVM debugging flag.
245245
struct PyGlobalDebugFlag {
246-
static void set(nb::object &o, bool enable) {
247-
nb::ft_lock_guard lock(mutex);
248-
mlirEnableGlobalDebug(enable);
249-
}
246+
static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
250247

251-
static bool get(const nb::object &) {
252-
nb::ft_lock_guard lock(mutex);
253-
return mlirIsGlobalDebugEnabled();
254-
}
248+
static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
255249

256250
static void bind(nb::module_ &m) {
257251
// Debug flags.
@@ -261,7 +255,6 @@ struct PyGlobalDebugFlag {
261255
.def_static(
262256
"set_types",
263257
[](const std::string &type) {
264-
nb::ft_lock_guard lock(mutex);
265258
mlirSetGlobalDebugType(type.c_str());
266259
},
267260
"types"_a, "Sets specific debug types to be produced by LLVM")
@@ -270,17 +263,11 @@ struct PyGlobalDebugFlag {
270263
pointers.reserve(types.size());
271264
for (const std::string &str : types)
272265
pointers.push_back(str.c_str());
273-
nb::ft_lock_guard lock(mutex);
274266
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
275267
});
276268
}
277-
278-
private:
279-
static nb::ft_mutex mutex;
280269
};
281270

282-
nb::ft_mutex PyGlobalDebugFlag::mutex;
283-
284271
struct PyAttrBuilderMap {
285272
static bool dunderContains(const std::string &attributeKind) {
286273
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@@ -619,7 +606,6 @@ class PyOpOperandIterator {
619606

620607
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
621608
nb::gil_scoped_acquire acquire;
622-
nb::ft_lock_guard lock(live_contexts_mutex);
623609
auto &liveContexts = getLiveContexts();
624610
liveContexts[context.ptr] = this;
625611
}
@@ -629,10 +615,7 @@ PyMlirContext::~PyMlirContext() {
629615
// forContext method, which always puts the associated handle into
630616
// liveContexts.
631617
nb::gil_scoped_acquire acquire;
632-
{
633-
nb::ft_lock_guard lock(live_contexts_mutex);
634-
getLiveContexts().erase(context.ptr);
635-
}
618+
getLiveContexts().erase(context.ptr);
636619
mlirContextDestroy(context);
637620
}
638621

@@ -649,7 +632,6 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
649632

650633
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
651634
nb::gil_scoped_acquire acquire;
652-
nb::ft_lock_guard lock(live_contexts_mutex);
653635
auto &liveContexts = getLiveContexts();
654636
auto it = liveContexts.find(context.ptr);
655637
if (it == liveContexts.end()) {
@@ -665,17 +647,12 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
665647
return PyMlirContextRef(it->second, std::move(pyRef));
666648
}
667649

668-
nb::ft_mutex PyMlirContext::live_contexts_mutex;
669-
670650
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
671651
static LiveContextMap liveContexts;
672652
return liveContexts;
673653
}
674654

675-
size_t PyMlirContext::getLiveCount() {
676-
nb::ft_lock_guard lock(live_contexts_mutex);
677-
return getLiveContexts().size();
678-
}
655+
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
679656

680657
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
681658

mlir/lib/Bindings/Python/IRModule.cpp

+2-16
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,8 @@ PyGlobals::PyGlobals() {
3838
PyGlobals::~PyGlobals() { instance = nullptr; }
3939

4040
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
41-
{
42-
nb::ft_lock_guard lock(mutex);
43-
if (loadedDialectModules.contains(dialectNamespace))
44-
return true;
45-
}
41+
if (loadedDialectModules.contains(dialectNamespace))
42+
return true;
4643
// Since re-entrancy is possible, make a copy of the search prefixes.
4744
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
4845
nb::object loaded = nb::none();
@@ -65,14 +62,12 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
6562
return false;
6663
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
6764
// may have occurred, which may do anything.
68-
nb::ft_lock_guard lock(mutex);
6965
loadedDialectModules.insert(dialectNamespace);
7066
return true;
7167
}
7268

7369
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
7470
nb::callable pyFunc, bool replace) {
75-
nb::ft_lock_guard lock(mutex);
7671
nb::object &found = attributeBuilderMap[attributeKind];
7772
if (found && !replace) {
7873
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
@@ -86,7 +81,6 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
8681

8782
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
8883
nb::callable typeCaster, bool replace) {
89-
nb::ft_lock_guard lock(mutex);
9084
nb::object &found = typeCasterMap[mlirTypeID];
9185
if (found && !replace)
9286
throw std::runtime_error("Type caster is already registered with caster: " +
@@ -96,7 +90,6 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
9690

9791
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
9892
nb::callable valueCaster, bool replace) {
99-
nb::ft_lock_guard lock(mutex);
10093
nb::object &found = valueCasterMap[mlirTypeID];
10194
if (found && !replace)
10295
throw std::runtime_error("Value caster is already registered: " +
@@ -106,7 +99,6 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
10699

107100
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
108101
nb::object pyClass) {
109-
nb::ft_lock_guard lock(mutex);
110102
nb::object &found = dialectClassMap[dialectNamespace];
111103
if (found) {
112104
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
@@ -118,7 +110,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
118110

119111
void PyGlobals::registerOperationImpl(const std::string &operationName,
120112
nb::object pyClass, bool replace) {
121-
nb::ft_lock_guard lock(mutex);
122113
nb::object &found = operationClassMap[operationName];
123114
if (found && !replace) {
124115
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
@@ -130,7 +121,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
130121

131122
std::optional<nb::callable>
132123
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
133-
nb::ft_lock_guard lock(mutex);
134124
const auto foundIt = attributeBuilderMap.find(attributeKind);
135125
if (foundIt != attributeBuilderMap.end()) {
136126
assert(foundIt->second && "attribute builder is defined");
@@ -143,7 +133,6 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
143133
MlirDialect dialect) {
144134
// Try to load dialect module.
145135
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
146-
nb::ft_lock_guard lock(mutex);
147136
const auto foundIt = typeCasterMap.find(mlirTypeID);
148137
if (foundIt != typeCasterMap.end()) {
149138
assert(foundIt->second && "type caster is defined");
@@ -156,7 +145,6 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
156145
MlirDialect dialect) {
157146
// Try to load dialect module.
158147
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
159-
nb::ft_lock_guard lock(mutex);
160148
const auto foundIt = valueCasterMap.find(mlirTypeID);
161149
if (foundIt != valueCasterMap.end()) {
162150
assert(foundIt->second && "value caster is defined");
@@ -170,7 +158,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
170158
// Make sure dialect module is loaded.
171159
if (!loadDialectModule(dialectNamespace))
172160
return std::nullopt;
173-
nb::ft_lock_guard lock(mutex);
174161
const auto foundIt = dialectClassMap.find(dialectNamespace);
175162
if (foundIt != dialectClassMap.end()) {
176163
assert(foundIt->second && "dialect class is defined");
@@ -188,7 +175,6 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
188175
if (!loadDialectModule(dialectNamespace))
189176
return std::nullopt;
190177

191-
nb::ft_lock_guard lock(mutex);
192178
auto foundIt = operationClassMap.find(operationName);
193179
if (foundIt != operationClassMap.end()) {
194180
assert(foundIt->second && "OpView is defined");

mlir/lib/Bindings/Python/IRModule.h

-1
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,6 @@ class PyMlirContext {
260260
// Note that this holds a handle, which does not imply ownership.
261261
// Mappings will be removed when the context is destructed.
262262
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
263-
static nanobind::ft_mutex live_contexts_mutex;
264263
static LiveContextMap &getLiveContexts();
265264

266265
// Interns all live modules associated with this context. Modules tracked

mlir/lib/Bindings/Python/MainModule.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ NB_MODULE(_mlir, m) {
3030
.def_prop_rw("dialect_search_modules",
3131
&PyGlobals::getDialectSearchPrefixes,
3232
&PyGlobals::setDialectSearchPrefixes)
33-
.def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
34-
"module_name"_a)
33+
.def(
34+
"append_dialect_search_prefix",
35+
[](PyGlobals &self, std::string moduleName) {
36+
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
37+
},
38+
"module_name"_a)
3539
.def(
3640
"_check_dialect_module_loaded",
3741
[](PyGlobals &self, const std::string &dialectNamespace) {
@@ -72,6 +76,7 @@ NB_MODULE(_mlir, m) {
7276
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
7377
PyGlobals::get().registerOperationImpl(operationName, opClass,
7478
replace);
79+
7580
// Dict-stuff the new opClass by name onto the dialect class.
7681
nb::object opClassName = opClass.attr("__name__");
7782
dialectClass.attr(opClassName) = opClass;

mlir/python/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ nanobind>=2.4, <3.0
22
numpy>=1.19.5, <=2.1.2
33
pybind11>=2.10.0, <=2.13.6
44
PyYAML>=5.4.0, <=6.0.1
5-
ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16
5+
ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16

0 commit comments

Comments
 (0)