Skip to content

Commit 2de7155

Browse files
committed
fix races when initializing #[pyclass] type objects
1 parent 038ea4e commit 2de7155

File tree

2 files changed

+157
-138
lines changed

2 files changed

+157
-138
lines changed

src/impl_/pyclass/lazy_type_object.rs

Lines changed: 134 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
use std::{
2-
ffi::CStr,
32
marker::PhantomData,
43
thread::{self, ThreadId},
54
};
65

6+
use pyo3_ffi::PyTypeObject;
7+
78
#[cfg(Py_3_14)]
89
use crate::err::error_on_minusone;
9-
#[allow(deprecated)]
10-
use crate::sync::GILOnceCell;
1110
#[cfg(Py_3_14)]
1211
use crate::types::PyTypeMethods;
1312
use crate::{
1413
exceptions::PyRuntimeError,
15-
ffi,
1614
impl_::{pyclass::MaybeRuntimePyMethodDef, pymethods::PyMethodDefType},
1715
pyclass::{create_type_object, PyClassTypeObject},
18-
types::PyType,
19-
Bound, Py, PyAny, PyClass, PyErr, PyResult, Python,
16+
sync::PyOnceLock,
17+
type_object::PyTypeInfo,
18+
types::{PyAnyMethods, PyType},
19+
Bound, Py, PyClass, PyErr, PyResult, Python,
2020
};
2121

2222
use std::sync::Mutex;
@@ -29,13 +29,9 @@ pub struct LazyTypeObject<T>(LazyTypeObjectInner, PhantomData<T>);
2929

3030
// Non-generic inner of LazyTypeObject to keep code size down
3131
struct LazyTypeObjectInner {
32-
#[allow(deprecated)]
33-
value: GILOnceCell<PyClassTypeObject>,
34-
// Threads which have begun initialization of the `tp_dict`. Used for
35-
// reentrant initialization detection.
36-
initializing_threads: Mutex<Vec<ThreadId>>,
37-
#[allow(deprecated)]
38-
fully_initialized_type: GILOnceCell<Py<PyType>>,
32+
value: PyOnceLock<PyClassTypeObject>,
33+
initializing_thread: Mutex<Option<ThreadId>>,
34+
fully_initialized_type: PyOnceLock<Py<PyType>>,
3935
}
4036

4137
impl<T> LazyTypeObject<T> {
@@ -44,11 +40,9 @@ impl<T> LazyTypeObject<T> {
4440
pub const fn new() -> Self {
4541
LazyTypeObject(
4642
LazyTypeObjectInner {
47-
#[allow(deprecated)]
48-
value: GILOnceCell::new(),
49-
initializing_threads: Mutex::new(Vec::new()),
50-
#[allow(deprecated)]
51-
fully_initialized_type: GILOnceCell::new(),
43+
value: PyOnceLock::new(),
44+
initializing_thread: Mutex::new(None),
45+
fully_initialized_type: PyOnceLock::new(),
5246
},
5347
PhantomData,
5448
)
@@ -69,8 +63,13 @@ impl<T: PyClass> LazyTypeObject<T> {
6963

7064
#[cold]
7165
fn try_init<'py>(&self, py: Python<'py>) -> PyResult<&Bound<'py, PyType>> {
72-
self.0
73-
.get_or_try_init(py, create_type_object::<T>, T::NAME, T::items_iter())
66+
self.0.get_or_try_init(
67+
py,
68+
<T::BaseType as PyTypeInfo>::type_object_raw,
69+
create_type_object::<T>,
70+
T::NAME,
71+
T::items_iter(),
72+
)
7473
}
7574
}
7675

@@ -81,18 +80,28 @@ impl LazyTypeObjectInner {
8180
fn get_or_try_init<'py>(
8281
&self,
8382
py: Python<'py>,
83+
base_init: fn(Python<'py>) -> *mut PyTypeObject,
8484
init: fn(Python<'py>) -> PyResult<PyClassTypeObject>,
8585
name: &str,
8686
items_iter: PyClassItemsIter,
8787
) -> PyResult<&Bound<'py, PyType>> {
8888
(|| -> PyResult<_> {
89+
// ensure that base is fully initialized before entering the `PyOnceLock`
90+
// initialization; that could otherwise deadlock if the base type needs
91+
// to load the subtype as an attribute.
92+
//
93+
// don't try to synchronize this; assume that `base_init` handles concurrency and
94+
// re-entrancy in the same way this function does
95+
base_init(py);
96+
// at this point the base type object is guaranteed to have been created; we may, however, be
97+
// inside `fill_tp_dict` of the base type if this subtype is an attribute on the base
8998
let PyClassTypeObject {
9099
type_object,
91100
is_immutable_type,
92101
..
93102
} = self.value.get_or_try_init(py, || init(py))?;
94103
let type_object = type_object.bind(py);
95-
self.ensure_init(type_object, *is_immutable_type, name, items_iter)?;
104+
self.fill_tp_dict(type_object, *is_immutable_type, name, items_iter)?;
96105
Ok(type_object)
97106
})()
98107
.map_err(|err| {
@@ -104,154 +113,141 @@ impl LazyTypeObjectInner {
104113
})
105114
}
106115

107-
fn ensure_init(
116+
fn fill_tp_dict(
108117
&self,
109118
type_object: &Bound<'_, PyType>,
110119
#[allow(unused_variables)] is_immutable_type: bool,
111120
name: &str,
112121
items_iter: PyClassItemsIter,
113122
) -> PyResult<()> {
114-
let py = type_object.py();
123+
let py: Python<'_> = type_object.py();
115124

116125
// We might want to fill the `tp_dict` with python instances of `T`
117126
// itself. In order to do so, we must first initialize the type object
118127
// with an empty `tp_dict`: now we can create instances of `T`.
119128
//
120-
// Then we fill the `tp_dict`. Multiple threads may try to fill it at
121-
// the same time, but only one of them will succeed.
122-
//
123129
// More importantly, if a thread is performing initialization of the
124130
// `tp_dict`, it can still request the type object through `get_or_init`,
125131
// but the `tp_dict` may appear empty of course.
126132

127-
if self.fully_initialized_type.get(py).is_some() {
128-
// `tp_dict` is already filled: ok.
133+
let Some(guard) = InitializationGuard::new(&self.initializing_thread) else {
134+
// We are re-entrant with `tp_dict` initialization on this thread, so we should
135+
// just return Ok and allow the init to proceed; whatever is accessing the type
136+
// object will just see the class without all attributes present.
129137
return Ok(());
130-
}
131-
132-
let thread_id = thread::current().id();
133-
{
134-
let mut threads = self.initializing_threads.lock().unwrap();
135-
if threads.contains(&thread_id) {
136-
// Reentrant call: just return the type object, even if the
137-
// `tp_dict` is not filled yet.
138-
return Ok(());
139-
}
140-
threads.push(thread_id);
141-
}
142-
143-
struct InitializationGuard<'a> {
144-
initializing_threads: &'a Mutex<Vec<ThreadId>>,
145-
thread_id: ThreadId,
146-
}
147-
impl Drop for InitializationGuard<'_> {
148-
fn drop(&mut self) {
149-
let mut threads = self.initializing_threads.lock().unwrap();
150-
threads.retain(|id| *id != self.thread_id);
151-
}
152-
}
153-
154-
let guard = InitializationGuard {
155-
initializing_threads: &self.initializing_threads,
156-
thread_id,
157138
};
158139

159-
// Pre-compute the class attribute objects: this can temporarily
160-
// release the GIL since we're calling into arbitrary user code. It
161-
// means that another thread can continue the initialization in the
162-
// meantime: at worst, we'll just make a useless computation.
163-
let mut items = vec![];
164-
for class_items in items_iter {
165-
for def in class_items.methods {
166-
let built_method;
167-
let method = match def {
168-
MaybeRuntimePyMethodDef::Runtime(builder) => {
169-
built_method = builder();
170-
&built_method
171-
}
172-
MaybeRuntimePyMethodDef::Static(method) => method,
173-
};
174-
if let PyMethodDefType::ClassAttribute(attr) = method {
175-
match (attr.meth)(py) {
176-
Ok(val) => items.push((attr.name, val)),
177-
Err(err) => {
178-
return Err(wrap_in_runtime_error(
179-
py,
180-
err,
181-
format!(
182-
"An error occurred while initializing `{}.{}`",
183-
name,
184-
attr.name.to_str().unwrap()
185-
),
186-
))
140+
// Only one thread will now proceed to set the type attributes.
141+
self.fully_initialized_type
142+
.get_or_try_init(py, move || -> PyResult<_> {
143+
guard.start_init();
144+
145+
for class_items in items_iter {
146+
for def in class_items.methods {
147+
let built_method;
148+
let method = match def {
149+
MaybeRuntimePyMethodDef::Runtime(builder) => {
150+
built_method = builder();
151+
&built_method
152+
}
153+
MaybeRuntimePyMethodDef::Static(method) => method,
154+
};
155+
if let PyMethodDefType::ClassAttribute(attr) = method {
156+
(attr.meth)(py)
157+
.and_then(|val| {
158+
type_object.setattr(
159+
// FIXME: add `IntoPyObject` for `&CStr`?
160+
attr.name.to_str().expect("attribute name should be UTF8"),
161+
val,
162+
)
163+
})
164+
.map_err(|err| {
165+
wrap_in_runtime_error(
166+
py,
167+
err,
168+
format!(
169+
"An error occurred while initializing `{}.{}`",
170+
name,
171+
attr.name.to_str().unwrap()
172+
),
173+
)
174+
})?;
187175
}
188176
}
189177
}
190-
}
191-
}
192178

193-
// Now we hold the GIL and we can assume it won't be released until we
194-
// return from the function.
195-
let result = self.fully_initialized_type.get_or_try_init(py, move || {
196-
initialize_tp_dict(py, type_object.as_ptr(), items)?;
197-
#[cfg(Py_3_14)]
198-
if is_immutable_type {
199-
// freeze immutable types after __dict__ is initialized
200-
let res = unsafe { ffi::PyType_Freeze(type_object.as_type_ptr()) };
201-
error_on_minusone(py, res)?;
202-
}
203-
#[cfg(all(Py_3_10, not(Py_LIMITED_API), not(Py_3_14)))]
204-
if is_immutable_type {
205-
use crate::types::PyTypeMethods as _;
206-
#[cfg(not(Py_GIL_DISABLED))]
207-
unsafe {
208-
(*type_object.as_type_ptr()).tp_flags |= ffi::Py_TPFLAGS_IMMUTABLETYPE
209-
};
210-
#[cfg(Py_GIL_DISABLED)]
211-
unsafe {
212-
(*type_object.as_type_ptr()).tp_flags.fetch_or(
213-
ffi::Py_TPFLAGS_IMMUTABLETYPE,
214-
std::sync::atomic::Ordering::Relaxed,
215-
)
216-
};
217-
unsafe { ffi::PyType_Modified(type_object.as_type_ptr()) };
218-
}
179+
#[cfg(Py_3_14)]
180+
if is_immutable_type {
181+
// freeze immutable types after __dict__ is initialized
182+
let res = unsafe { crate::ffi::PyType_Freeze(type_object.as_type_ptr()) };
183+
error_on_minusone(py, res)?;
184+
}
185+
#[cfg(all(Py_3_10, not(Py_LIMITED_API), not(Py_3_14)))]
186+
if is_immutable_type {
187+
use crate::types::PyTypeMethods as _;
188+
#[cfg(not(Py_GIL_DISABLED))]
189+
unsafe {
190+
(*type_object.as_type_ptr()).tp_flags |=
191+
crate::ffi::Py_TPFLAGS_IMMUTABLETYPE
192+
};
193+
#[cfg(Py_GIL_DISABLED)]
194+
unsafe {
195+
(*type_object.as_type_ptr()).tp_flags.fetch_or(
196+
crate::ffi::Py_TPFLAGS_IMMUTABLETYPE,
197+
std::sync::atomic::Ordering::Relaxed,
198+
)
199+
};
200+
unsafe { crate::ffi::PyType_Modified(type_object.as_type_ptr()) };
201+
}
219202

220-
// Initialization successfully complete, can clear the thread list.
221-
// (No further calls to get_or_init() will try to init, on any thread.)
222-
let mut threads = {
223203
drop(guard);
224-
self.initializing_threads.lock().unwrap()
225-
};
226-
threads.clear();
227-
Ok(type_object.clone().unbind())
228-
});
204+
Ok(type_object.clone().unbind())
205+
})?;
229206

230-
if let Err(err) = result {
231-
return Err(wrap_in_runtime_error(
232-
py,
233-
err,
234-
format!("An error occurred while initializing `{name}.__dict__`"),
235-
));
207+
Ok(())
208+
}
209+
}
210+
211+
struct InitializationGuard<'a> {
212+
initializing_thread: &'a Mutex<Option<ThreadId>>,
213+
thread_id: ThreadId,
214+
}
215+
216+
impl<'a> InitializationGuard<'a> {
217+
/// Attempt to create a new `InitializationGuard`.
218+
///
219+
/// Returns `None` if this call would be re-entrant.
220+
///
221+
/// The guard will not protect against re-entrancy until `start_init` is called.
222+
fn new(initializing_thread: &'a Mutex<Option<ThreadId>>) -> Option<Self> {
223+
let thread_id = thread::current().id();
224+
let thread = initializing_thread.lock().expect("no poisoning");
225+
if thread.is_some_and(|id| id == thread_id) {
226+
None
227+
} else {
228+
Some(Self {
229+
initializing_thread,
230+
thread_id,
231+
})
236232
}
233+
}
237234

238-
Ok(())
235+
/// Starts the initialization process. From this point forward `InitializationGuard::new` will protect against re-entrancy.
236+
fn start_init(&self) {
237+
let mut thread = self.initializing_thread.lock().expect("no poisoning");
238+
assert!(thread.is_none(), "overlapping use of `InitializationGuard`");
239+
*thread = Some(self.thread_id);
239240
}
240241
}
241242

242-
fn initialize_tp_dict(
243-
py: Python<'_>,
244-
type_object: *mut ffi::PyObject,
245-
items: Vec<(&'static CStr, Py<PyAny>)>,
246-
) -> PyResult<()> {
247-
// We hold the GIL: the dictionary update can be considered atomic from
248-
// the POV of other threads.
249-
for (key, val) in items {
250-
crate::err::error_on_minusone(py, unsafe {
251-
ffi::PyObject_SetAttrString(type_object, key.as_ptr(), val.into_ptr())
252-
})?;
243+
impl Drop for InitializationGuard<'_> {
244+
fn drop(&mut self) {
245+
let mut thread = self.initializing_thread.lock().unwrap();
246+
// only clear the thread if this was the thread which called `start_init`
247+
if thread.is_some_and(|id| id == self.thread_id) {
248+
*thread = None;
249+
}
253250
}
254-
Ok(())
255251
}
256252

257253
// This is necessary for making static `LazyTypeObject`s

tests/test_class_attributes.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,26 @@ test_case!(
319319
"FIELDONE",
320320
test_rename_all_uppercase
321321
);
322+
323+
#[test]
324+
fn test_class_attribute_reentrancy() {
325+
#[pyclass(subclass)]
326+
struct Base;
327+
328+
#[pymethods]
329+
impl Base {
330+
#[classattr]
331+
#[allow(non_snake_case)]
332+
fn DERIVED(py: Python<'_>) -> Py<Derived> {
333+
Py::new(py, (Derived, Base)).unwrap()
334+
}
335+
}
336+
337+
#[pyclass(extends = Base)]
338+
struct Derived;
339+
340+
Python::attach(|py| {
341+
let derived_class = py.get_type::<Derived>();
342+
assert!(derived_class.getattr("DERIVED").is_ok());
343+
})
344+
}

0 commit comments

Comments
 (0)