Skip to content

Commit f9d1e1a

Browse files
committed
fix memory leak with recursive definitions creating reference cycles
1 parent bec63db commit f9d1e1a

File tree

3 files changed

+100
-92
lines changed

3 files changed

+100
-92
lines changed

src/definitions.rs

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::{
88
fmt::Debug,
99
sync::{
1010
atomic::{AtomicBool, Ordering},
11-
Arc, OnceLock,
11+
Arc, OnceLock, Weak,
1212
},
1313
};
1414

@@ -28,47 +28,50 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};
2828
/// They get indexed by a ReferenceId, which are integer identifiers
2929
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
3030
/// gets build.
31-
#[derive(Clone)]
3231
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);
3332

34-
/// Internal type which contains a definition to be filled
35-
pub struct Definition<T>(Arc<DefinitionInner<T>>);
36-
37-
struct DefinitionInner<T> {
38-
value: OnceLock<T>,
39-
name: LazyName,
33+
struct Definition<T> {
34+
value: Arc<OnceLock<T>>,
35+
name: Arc<LazyName>,
4036
}
4137

4238
/// Reference to a definition.
4339
pub struct DefinitionRef<T> {
44-
name: Arc<String>,
45-
value: Definition<T>,
40+
reference: Arc<String>,
41+
// We use a weak reference to the definition to avoid a reference cycle
42+
// when recursive definitions are used.
43+
value: Weak<OnceLock<T>>,
44+
name: Arc<LazyName>,
4645
}
4746

4847
// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
4948
impl<T> Clone for DefinitionRef<T> {
5049
fn clone(&self) -> Self {
5150
Self {
52-
name: self.name.clone(),
51+
reference: self.reference.clone(),
5352
value: self.value.clone(),
53+
name: self.name.clone(),
5454
}
5555
}
5656
}
5757

5858
impl<T> DefinitionRef<T> {
5959
pub fn id(&self) -> usize {
60-
Arc::as_ptr(&self.value.0) as usize
60+
Weak::as_ptr(&self.value) as usize
6161
}
6262

6363
pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
64-
match self.value.0.value.get() {
65-
Some(value) => self.value.0.name.get_or_init(|| init(value)),
64+
let Some(definition) = self.value.upgrade() else {
65+
return "...";
66+
};
67+
match definition.get() {
68+
Some(value) => self.name.get_or_init(|| init(value)),
6669
None => "...",
6770
}
6871
}
6972

70-
pub fn get(&self) -> Option<&T> {
71-
self.value.0.value.get()
73+
pub fn read<R>(&self, f: impl FnOnce(Option<&T>) -> R) -> R {
74+
f(self.value.upgrade().as_ref().and_then(|value| value.get()))
7275
}
7376
}
7477

@@ -96,15 +99,9 @@ impl<T: Debug> Debug for Definitions<T> {
9699
}
97100
}
98101

99-
impl<T> Clone for Definition<T> {
100-
fn clone(&self) -> Self {
101-
Self(self.0.clone())
102-
}
103-
}
104-
105102
impl<T: Debug> Debug for Definition<T> {
106103
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107-
match self.0.value.get() {
104+
match self.value.get() {
108105
Some(value) => value.fmt(f),
109106
None => "...".fmt(f),
110107
}
@@ -113,7 +110,7 @@ impl<T: Debug> Debug for Definition<T> {
113110

114111
impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
115112
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
116-
if let Some(value) = self.value.0.value.get() {
113+
if let Some(value) = self.value.upgrade().as_ref().and_then(|v| v.get()) {
117114
value.py_gc_traverse(visit)?;
118115
}
119116
Ok(())
@@ -123,15 +120,15 @@ impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
123120
impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
124121
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
125122
for value in self.0.values() {
126-
if let Some(value) = value.0.value.get() {
123+
if let Some(value) = value.value.get() {
127124
value.py_gc_traverse(visit)?;
128125
}
129126
}
130127
Ok(())
131128
}
132129
}
133130

134-
#[derive(Clone, Debug)]
131+
#[derive(Debug)]
135132
pub struct DefinitionsBuilder<T> {
136133
definitions: Definitions<T>,
137134
}
@@ -148,45 +145,48 @@ impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
148145
// We either need a String copy or two hashmap lookups
149146
// Neither is better than the other
150147
// We opted for the easier outward facing API
151-
let name = Arc::new(reference.to_string());
152-
let value = match self.definitions.0.entry(name.clone()) {
148+
let reference = Arc::new(reference.to_string());
149+
let value = match self.definitions.0.entry(reference.clone()) {
153150
Entry::Occupied(entry) => entry.into_mut(),
154-
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
155-
value: OnceLock::new(),
156-
name: LazyName::new(),
157-
}))),
151+
Entry::Vacant(entry) => entry.insert(Definition {
152+
value: Arc::new(OnceLock::new()),
153+
name: Arc::new(LazyName::new()),
154+
}),
158155
};
159156
DefinitionRef {
160-
name,
161-
value: value.clone(),
157+
reference,
158+
value: Arc::downgrade(&value.value),
159+
name: value.name.clone(),
162160
}
163161
}
164162

165163
/// Add a definition, returning the ReferenceId that maps to it
166164
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
167-
let name = Arc::new(reference);
168-
let value = match self.definitions.0.entry(name.clone()) {
165+
let reference = Arc::new(reference);
166+
let value = match self.definitions.0.entry(reference.clone()) {
169167
Entry::Occupied(entry) => {
170168
let definition = entry.into_mut();
171-
match definition.0.value.set(value) {
172-
Ok(()) => definition.clone(),
173-
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
169+
match definition.value.set(value) {
170+
Ok(()) => definition,
171+
Err(_) => return py_schema_err!("Duplicate ref: `{}`", reference),
174172
}
175173
}
176-
Entry::Vacant(entry) => entry
177-
.insert(Definition(Arc::new(DefinitionInner {
178-
value: OnceLock::from(value),
179-
name: LazyName::new(),
180-
})))
181-
.clone(),
174+
Entry::Vacant(entry) => entry.insert(Definition {
175+
value: Arc::new(OnceLock::from(value)),
176+
name: Arc::new(LazyName::new()),
177+
}),
182178
};
183-
Ok(DefinitionRef { name, value })
179+
Ok(DefinitionRef {
180+
reference,
181+
value: Arc::downgrade(&value.value),
182+
name: value.name.clone(),
183+
})
184184
}
185185

186186
/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
187187
pub fn finish(self) -> PyResult<Definitions<T>> {
188188
for (reference, def) in &self.definitions.0 {
189-
if def.0.value.get().is_none() {
189+
if def.value.get().is_none() {
190190
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
191191
}
192192
}

src/serializers/type_serializers/definitions.rs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,17 @@ impl TypeSerializer for DefinitionRefSerializer {
6868
exclude: Option<&PyAny>,
6969
extra: &Extra,
7070
) -> PyResult<PyObject> {
71-
let comb_serializer = self.definition.get().unwrap();
72-
let value_id = extra.rec_guard.add(value, self.definition.id())?;
73-
let r = comb_serializer.to_python(value, include, exclude, extra);
74-
extra.rec_guard.pop(value_id, self.definition.id());
75-
r
71+
self.definition.read(|comb_serializer| {
72+
let comb_serializer = comb_serializer.unwrap();
73+
let value_id = extra.rec_guard.add(value, self.definition.id())?;
74+
let r = comb_serializer.to_python(value, include, exclude, extra);
75+
extra.rec_guard.pop(value_id, self.definition.id());
76+
r
77+
})
7678
}
7779

7880
fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
79-
self.definition.get().unwrap().json_key(key, extra)
81+
self.definition.read(|s| s.unwrap().json_key(key, extra))
8082
}
8183

8284
fn serde_serialize<S: serde::ser::Serializer>(
@@ -87,21 +89,23 @@ impl TypeSerializer for DefinitionRefSerializer {
8789
exclude: Option<&PyAny>,
8890
extra: &Extra,
8991
) -> Result<S::Ok, S::Error> {
90-
let comb_serializer = self.definition.get().unwrap();
91-
let value_id = extra
92-
.rec_guard
93-
.add(value, self.definition.id())
94-
.map_err(py_err_se_err)?;
95-
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
96-
extra.rec_guard.pop(value_id, self.definition.id());
97-
r
92+
self.definition.read(|comb_serializer| {
93+
let comb_serializer = comb_serializer.unwrap();
94+
let value_id = extra
95+
.rec_guard
96+
.add(value, self.definition.id())
97+
.map_err(py_err_se_err)?;
98+
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
99+
extra.rec_guard.pop(value_id, self.definition.id());
100+
r
101+
})
98102
}
99103

100104
fn get_name(&self) -> &str {
101105
Self::EXPECTED_TYPE
102106
}
103107

104108
fn retry_with_lax_check(&self) -> bool {
105-
self.definition.get().unwrap().retry_with_lax_check()
109+
self.definition.read(|s| s.unwrap().retry_with_lax_check())
106110
}
107111
}

src/validators/definitions.rs

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,25 @@ impl Validator for DefinitionRefValidator {
7373
input: &'data impl Input<'data>,
7474
state: &mut ValidationState,
7575
) -> ValResult<PyObject> {
76-
let validator = self.definition.get().unwrap();
77-
if let Some(id) = input.identity() {
78-
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
79-
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
80-
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
81-
} else {
82-
if state.recursion_guard.incr_depth() {
83-
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
76+
self.definition.read(|validator| {
77+
let validator = validator.unwrap();
78+
if let Some(id) = input.identity() {
79+
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
80+
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
81+
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
82+
} else {
83+
if state.recursion_guard.incr_depth() {
84+
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
85+
}
86+
let output = validator.validate(py, input, state);
87+
state.recursion_guard.remove(id, self.definition.id());
88+
state.recursion_guard.decr_depth();
89+
output
8490
}
85-
let output = validator.validate(py, input, state);
86-
state.recursion_guard.remove(id, self.definition.id());
87-
state.recursion_guard.decr_depth();
88-
output
91+
} else {
92+
validator.validate(py, input, state)
8993
}
90-
} else {
91-
validator.validate(py, input, state)
92-
}
94+
})
9395
}
9496

9597
fn validate_assignment<'data>(
@@ -100,23 +102,25 @@ impl Validator for DefinitionRefValidator {
100102
field_value: &'data PyAny,
101103
state: &mut ValidationState,
102104
) -> ValResult<PyObject> {
103-
let validator = self.definition.get().unwrap();
104-
if let Some(id) = obj.identity() {
105-
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
106-
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
107-
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
108-
} else {
109-
if state.recursion_guard.incr_depth() {
110-
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
105+
self.definition.read(|validator| {
106+
let validator = validator.unwrap();
107+
if let Some(id) = obj.identity() {
108+
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
109+
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
110+
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
111+
} else {
112+
if state.recursion_guard.incr_depth() {
113+
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
114+
}
115+
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
116+
state.recursion_guard.remove(id, self.definition.id());
117+
state.recursion_guard.decr_depth();
118+
output
111119
}
112-
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
113-
state.recursion_guard.remove(id, self.definition.id());
114-
state.recursion_guard.decr_depth();
115-
output
120+
} else {
121+
validator.validate_assignment(py, obj, field_name, field_value, state)
116122
}
117-
} else {
118-
validator.validate_assignment(py, obj, field_name, field_value, state)
119-
}
123+
})
120124
}
121125

122126
fn get_name(&self) -> &str {

0 commit comments

Comments
 (0)