Skip to content

Commit cdd2b03

Browse files
committed
[skip-ci] Added lock on PyGlobals::get and PyMlirContext liveContexts
WIP on adding multithreaded_tests
1 parent a5ed9d0 commit cdd2b03

File tree

8 files changed

+251
-30
lines changed

8 files changed

+251
-30
lines changed

mlir/docs/Bindings/Python.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ class ConstantOp(_ods_ir.OpView):
10351035
...
10361036
```
10371037

1038-
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
1038+
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
10391039
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
10401040

10411041
```python
@@ -1179,7 +1179,7 @@ make the passes available along with the dialect.
11791179
Dialect functionality other than IR objects or passes, such as helper functions,
11801180
can be exposed to Python similarly to attributes and types. C API is expected to
11811181
exist for this functionality, which can then be wrapped using pybind11 and
1182-
`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`
1182+
[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)
11831183
utilities to connect to the rest of Python API. The bindings can be located in a
11841184
separate pybind11 module or in the same module as attributes and types, and
11851185
loaded along with the dialect.

mlir/examples/standalone/python/StandaloneExtension.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#include "Standalone-c/Dialects.h"
1010
#include "mlir/Bindings/Python/PybindAdaptors.h"
1111

12+
namespace py = pybind11;
13+
1214
using namespace mlir::python::adaptors;
1315

14-
PYBIND11_MODULE(_standaloneDialects, m) {
16+
PYBIND11_MODULE(_standaloneDialects, m, py::mod_gil_not_used()) {
1517
//===--------------------------------------------------------------------===//
1618
// standalone dialect
1719
//===--------------------------------------------------------------------===//

mlir/lib/Bindings/Python/Globals.h

+22
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ class PyGlobals {
3636
return *instance;
3737
}
3838

39+
template<typename F>
40+
static inline auto withInstance(const F& cb) -> decltype(cb(get())) {
41+
auto &instance = get();
42+
#ifdef Py_GIL_DISABLED
43+
auto &lock = getLock();
44+
PyMutex_Lock(&lock);
45+
#endif
46+
auto result = cb(instance);
47+
#ifdef Py_GIL_DISABLED
48+
PyMutex_Unlock(&lock);
49+
#endif
50+
return result;
51+
}
52+
3953
/// Get and set the list of parent modules to search for dialect
4054
/// implementation classes.
4155
std::vector<std::string> &getDialectSearchPrefixes() {
@@ -125,6 +139,14 @@ class PyGlobals {
125139
/// Set of dialect namespaces that we have attempted to import implementation
126140
/// modules for.
127141
llvm::StringSet<> loadedDialectModules;
142+
143+
#ifdef Py_GIL_DISABLED
144+
static PyMutex &getLock() {
145+
static PyMutex lock;
146+
return lock;
147+
}
148+
#endif
149+
128150
};
129151

130152
} // namespace python

mlir/lib/Bindings/Python/IRCore.cpp

+47-25
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ py::object classmethod(Func f, Args... args) {
192192
static py::object
193193
createCustomDialectWrapper(const std::string &dialectNamespace,
194194
py::object dialectDescriptor) {
195-
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
195+
auto dialectClass = PyGlobals::withInstance([&](PyGlobals& instance) {
196+
return instance.lookupDialectClass(dialectNamespace);
197+
});
196198
if (!dialectClass) {
197199
// Use the base class.
198200
return py::cast(PyDialect(std::move(dialectDescriptor)));
@@ -595,16 +597,23 @@ class PyOpOperandIterator {
595597

596598
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
597599
py::gil_scoped_acquire acquire;
598-
auto &liveContexts = getLiveContexts();
599-
liveContexts[context.ptr] = this;
600+
withLiveContexts([&](LiveContextMap& liveContexts) {
601+
liveContexts[context.ptr] = this;
602+
return this;
603+
});
600604
}
601605

602606
PyMlirContext::~PyMlirContext() {
603607
// Note that the only public way to construct an instance is via the
604608
// forContext method, which always puts the associated handle into
605609
// liveContexts.
606610
py::gil_scoped_acquire acquire;
607-
getLiveContexts().erase(context.ptr);
611+
612+
withLiveContexts([&](LiveContextMap& liveContexts) {
613+
liveContexts.erase(context.ptr);
614+
return this;
615+
});
616+
608617
mlirContextDestroy(context);
609618
}
610619

@@ -626,27 +635,32 @@ PyMlirContext *PyMlirContext::createNewContextForInit() {
626635

627636
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
628637
py::gil_scoped_acquire acquire;
629-
auto &liveContexts = getLiveContexts();
630-
auto it = liveContexts.find(context.ptr);
631-
if (it == liveContexts.end()) {
632-
// Create.
633-
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
634-
py::object pyRef = py::cast(unownedContextWrapper);
635-
assert(pyRef && "cast to py::object failed");
636-
liveContexts[context.ptr] = unownedContextWrapper;
637-
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
638-
}
639-
// Use existing.
640-
py::object pyRef = py::cast(it->second);
641-
return PyMlirContextRef(it->second, std::move(pyRef));
638+
return withLiveContexts([&](LiveContextMap& liveContexts) {
639+
auto it = liveContexts.find(context.ptr);
640+
if (it == liveContexts.end()) {
641+
// Create.
642+
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
643+
py::object pyRef = py::cast(unownedContextWrapper);
644+
assert(pyRef && "cast to py::object failed");
645+
liveContexts[context.ptr] = unownedContextWrapper;
646+
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
647+
}
648+
// Use existing.
649+
py::object pyRef = py::cast(it->second);
650+
return PyMlirContextRef(it->second, std::move(pyRef));
651+
});
642652
}
643653

644654
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
645655
static LiveContextMap liveContexts;
646656
return liveContexts;
647657
}
648658

649-
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
659+
size_t PyMlirContext::getLiveCount() {
660+
return withLiveContexts([&](LiveContextMap& liveContexts) {
661+
return liveContexts.size();
662+
});
663+
}
650664

651665
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
652666

@@ -1550,8 +1564,10 @@ py::object PyOperation::createOpView() {
15501564
checkValid();
15511565
MlirIdentifier ident = mlirOperationGetName(get());
15521566
MlirStringRef identStr = mlirIdentifierStr(ident);
1553-
auto operationCls = PyGlobals::get().lookupOperationClass(
1554-
StringRef(identStr.data, identStr.length));
1567+
auto operationCls = PyGlobals::withInstance([&](PyGlobals& instance){
1568+
return instance.lookupOperationClass(
1569+
StringRef(identStr.data, identStr.length));
1570+
});
15551571
if (operationCls)
15561572
return PyOpView::constructDerived(*operationCls, *getRef().get());
15571573
return py::cast(PyOpView(getRef().getObject()));
@@ -2002,7 +2018,9 @@ pybind11::object PyValue::maybeDownCast() {
20022018
assert(!mlirTypeIDIsNull(mlirTypeID) &&
20032019
"mlirTypeID was expected to be non-null.");
20042020
std::optional<pybind11::function> valueCaster =
2005-
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2021+
PyGlobals::withInstance([&](PyGlobals& instance) {
2022+
return instance.lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2023+
});
20062024
// py::return_value_policy::move means use std::move to move the return value
20072025
// contents into a new instance that will be owned by Python.
20082026
py::object thisObj = py::cast(this, py::return_value_policy::move);
@@ -3481,8 +3499,10 @@ void mlir::python::populateIRCore(py::module &m) {
34813499
assert(!mlirTypeIDIsNull(mlirTypeID) &&
34823500
"mlirTypeID was expected to be non-null.");
34833501
std::optional<pybind11::function> typeCaster =
3484-
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3485-
mlirAttributeGetDialect(self));
3502+
PyGlobals::withInstance([&](PyGlobals& instance){
3503+
return instance.lookupTypeCaster(mlirTypeID,
3504+
mlirAttributeGetDialect(self));
3505+
});
34863506
if (!typeCaster)
34873507
return py::cast(self);
34883508
return typeCaster.value()(self);
@@ -3579,9 +3599,11 @@ void mlir::python::populateIRCore(py::module &m) {
35793599
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
35803600
assert(!mlirTypeIDIsNull(mlirTypeID) &&
35813601
"mlirTypeID was expected to be non-null.");
3582-
std::optional<pybind11::function> typeCaster =
3583-
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3602+
std::optional<pybind11::function> typeCaster =
3603+
PyGlobals::withInstance([&](PyGlobals& instance){
3604+
return instance.lookupTypeCaster(mlirTypeID,
35843605
mlirTypeGetDialect(self));
3606+
});
35853607
if (!typeCaster)
35863608
return py::cast(self);
35873609
return typeCaster.value()(self);

mlir/lib/Bindings/Python/IRModule.h

+21
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,27 @@ class PyMlirContext {
263263
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
264264
static LiveContextMap &getLiveContexts();
265265

266+
#ifdef Py_GIL_DISABLED
267+
static PyMutex &getLock() {
268+
static PyMutex lock;
269+
return lock;
270+
}
271+
#endif
272+
273+
template<typename F>
274+
static inline auto withLiveContexts(const F& cb) -> decltype(cb(getLiveContexts())) {
275+
auto &liveContexts = getLiveContexts();
276+
#ifdef Py_GIL_DISABLED
277+
auto &lock = getLock();
278+
PyMutex_Lock(&lock);
279+
#endif
280+
auto result = cb(liveContexts);
281+
#ifdef Py_GIL_DISABLED
282+
PyMutex_Unlock(&lock);
283+
#endif
284+
return result;
285+
}
286+
266287
// Interns all live modules associated with this context. Modules tracked
267288
// in this map are valid. When a module is invalidated, it is removed
268289
// from this map, and while it still exists as an instance, any

mlir/test/python/execution_engine.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def callback(a):
306306
log(arr)
307307

308308
with Context():
309-
# The module takes a subview of the argument memref, casts it to an unranked memref and
309+
# The module takes a subview of the argument memref, casts it to an unranked memref and
310310
# calls the callback with it.
311311
module = Module.parse(
312312
r"""

mlir/test/python/lib/PythonTestModule.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
2121
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
2222
}
2323

24-
PYBIND11_MODULE(_mlirPythonTest, m) {
24+
PYBIND11_MODULE(_mlirPythonTest, m, py::mod_gil_not_used()) {
2525
m.def(
2626
"register_python_test_dialect",
2727
[](MlirContext context, bool load) {

0 commit comments

Comments
 (0)