@@ -192,7 +192,9 @@ py::object classmethod(Func f, Args... args) {
192
192
static py::object
193
193
createCustomDialectWrapper (const std::string &dialectNamespace,
194
194
py::object dialectDescriptor) {
195
- auto dialectClass = PyGlobals::get ().lookupDialectClass (dialectNamespace);
195
+ auto dialectClass = PyGlobals::withInstance ([&](PyGlobals& instance) {
196
+ return instance.lookupDialectClass (dialectNamespace);
197
+ });
196
198
if (!dialectClass) {
197
199
// Use the base class.
198
200
return py::cast (PyDialect (std::move (dialectDescriptor)));
@@ -595,16 +597,23 @@ class PyOpOperandIterator {
595
597
596
598
PyMlirContext::PyMlirContext (MlirContext context) : context(context) {
597
599
py::gil_scoped_acquire acquire;
598
- auto &liveContexts = getLiveContexts ();
599
- liveContexts[context.ptr ] = this ;
600
+ withLiveContexts ([&](LiveContextMap& liveContexts) {
601
+ liveContexts[context.ptr ] = this ;
602
+ return this ;
603
+ });
600
604
}
601
605
602
606
PyMlirContext::~PyMlirContext () {
603
607
// Note that the only public way to construct an instance is via the
604
608
// forContext method, which always puts the associated handle into
605
609
// liveContexts.
606
610
py::gil_scoped_acquire acquire;
607
- getLiveContexts ().erase (context.ptr );
611
+
612
+ withLiveContexts ([&](LiveContextMap& liveContexts) {
613
+ liveContexts.erase (context.ptr );
614
+ return this ;
615
+ });
616
+
608
617
mlirContextDestroy (context);
609
618
}
610
619
@@ -626,27 +635,32 @@ PyMlirContext *PyMlirContext::createNewContextForInit() {
626
635
627
636
PyMlirContextRef PyMlirContext::forContext (MlirContext context) {
628
637
py::gil_scoped_acquire acquire;
629
- auto &liveContexts = getLiveContexts ();
630
- auto it = liveContexts.find (context.ptr );
631
- if (it == liveContexts.end ()) {
632
- // Create.
633
- PyMlirContext *unownedContextWrapper = new PyMlirContext (context);
634
- py::object pyRef = py::cast (unownedContextWrapper);
635
- assert (pyRef && " cast to py::object failed" );
636
- liveContexts[context.ptr ] = unownedContextWrapper;
637
- return PyMlirContextRef (unownedContextWrapper, std::move (pyRef));
638
- }
639
- // Use existing.
640
- py::object pyRef = py::cast (it->second );
641
- return PyMlirContextRef (it->second , std::move (pyRef));
638
+ return withLiveContexts ([&](LiveContextMap& liveContexts) {
639
+ auto it = liveContexts.find (context.ptr );
640
+ if (it == liveContexts.end ()) {
641
+ // Create.
642
+ PyMlirContext *unownedContextWrapper = new PyMlirContext (context);
643
+ py::object pyRef = py::cast (unownedContextWrapper);
644
+ assert (pyRef && " cast to py::object failed" );
645
+ liveContexts[context.ptr ] = unownedContextWrapper;
646
+ return PyMlirContextRef (unownedContextWrapper, std::move (pyRef));
647
+ }
648
+ // Use existing.
649
+ py::object pyRef = py::cast (it->second );
650
+ return PyMlirContextRef (it->second , std::move (pyRef));
651
+ });
642
652
}
643
653
644
654
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts () {
645
655
static LiveContextMap liveContexts;
646
656
return liveContexts;
647
657
}
648
658
649
- size_t PyMlirContext::getLiveCount () { return getLiveContexts ().size (); }
659
+ size_t PyMlirContext::getLiveCount () {
660
+ return withLiveContexts ([&](LiveContextMap& liveContexts) {
661
+ return liveContexts.size ();
662
+ });
663
+ }
650
664
651
665
size_t PyMlirContext::getLiveOperationCount () { return liveOperations.size (); }
652
666
@@ -1550,8 +1564,10 @@ py::object PyOperation::createOpView() {
1550
1564
checkValid ();
1551
1565
MlirIdentifier ident = mlirOperationGetName (get ());
1552
1566
MlirStringRef identStr = mlirIdentifierStr (ident);
1553
- auto operationCls = PyGlobals::get ().lookupOperationClass (
1554
- StringRef (identStr.data , identStr.length ));
1567
+ auto operationCls = PyGlobals::withInstance ([&](PyGlobals& instance){
1568
+ return instance.lookupOperationClass (
1569
+ StringRef (identStr.data , identStr.length ));
1570
+ });
1555
1571
if (operationCls)
1556
1572
return PyOpView::constructDerived (*operationCls, *getRef ().get ());
1557
1573
return py::cast (PyOpView (getRef ().getObject ()));
@@ -2002,7 +2018,9 @@ pybind11::object PyValue::maybeDownCast() {
2002
2018
assert (!mlirTypeIDIsNull (mlirTypeID) &&
2003
2019
" mlirTypeID was expected to be non-null." );
2004
2020
std::optional<pybind11::function> valueCaster =
2005
- PyGlobals::get ().lookupValueCaster (mlirTypeID, mlirTypeGetDialect (type));
2021
+ PyGlobals::withInstance ([&](PyGlobals& instance) {
2022
+ return instance.lookupValueCaster (mlirTypeID, mlirTypeGetDialect (type));
2023
+ });
2006
2024
// py::return_value_policy::move means use std::move to move the return value
2007
2025
// contents into a new instance that will be owned by Python.
2008
2026
py::object thisObj = py::cast (this , py::return_value_policy::move);
@@ -3481,8 +3499,10 @@ void mlir::python::populateIRCore(py::module &m) {
3481
3499
assert (!mlirTypeIDIsNull (mlirTypeID) &&
3482
3500
" mlirTypeID was expected to be non-null." );
3483
3501
std::optional<pybind11::function> typeCaster =
3484
- PyGlobals::get ().lookupTypeCaster (mlirTypeID,
3485
- mlirAttributeGetDialect (self));
3502
+ PyGlobals::withInstance ([&](PyGlobals& instance){
3503
+ return instance.lookupTypeCaster (mlirTypeID,
3504
+ mlirAttributeGetDialect (self));
3505
+ });
3486
3506
if (!typeCaster)
3487
3507
return py::cast (self);
3488
3508
return typeCaster.value ()(self);
@@ -3579,9 +3599,11 @@ void mlir::python::populateIRCore(py::module &m) {
3579
3599
MlirTypeID mlirTypeID = mlirTypeGetTypeID (self);
3580
3600
assert (!mlirTypeIDIsNull (mlirTypeID) &&
3581
3601
" mlirTypeID was expected to be non-null." );
3582
- std::optional<pybind11::function> typeCaster =
3583
- PyGlobals::get ().lookupTypeCaster (mlirTypeID,
3602
+ std::optional<pybind11::function> typeCaster =
3603
+ PyGlobals::withInstance ([&](PyGlobals& instance){
3604
+ return instance.lookupTypeCaster (mlirTypeID,
3584
3605
mlirTypeGetDialect (self));
3606
+ });
3585
3607
if (!typeCaster)
3586
3608
return py::cast (self);
3587
3609
return typeCaster.value ()(self);
0 commit comments