Skip to content

Commit 986375e

Browse files
bpo-28416: Break reference cycles in Pickler and Unpickler subclasses (#4080)
with the persistent_id() and persistent_load() methods.
1 parent bc8ac6b commit 986375e

File tree

3 files changed

+164
-40
lines changed

3 files changed

+164
-40
lines changed

Lib/test/test_pickle.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import collections
77
import struct
88
import sys
9+
import weakref
910

1011
import unittest
1112
from test import support
@@ -117,6 +118,66 @@ class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
117118
pickler = pickle._Pickler
118119
unpickler = pickle._Unpickler
119120

121+
@support.cpython_only
122+
def test_pickler_reference_cycle(self):
123+
def check(Pickler):
124+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
125+
f = io.BytesIO()
126+
pickler = Pickler(f, proto)
127+
pickler.dump('abc')
128+
self.assertEqual(self.loads(f.getvalue()), 'abc')
129+
pickler = Pickler(io.BytesIO())
130+
self.assertEqual(pickler.persistent_id('def'), 'def')
131+
r = weakref.ref(pickler)
132+
del pickler
133+
self.assertIsNone(r())
134+
135+
class PersPickler(self.pickler):
136+
def persistent_id(subself, obj):
137+
return obj
138+
check(PersPickler)
139+
140+
class PersPickler(self.pickler):
141+
@classmethod
142+
def persistent_id(cls, obj):
143+
return obj
144+
check(PersPickler)
145+
146+
class PersPickler(self.pickler):
147+
@staticmethod
148+
def persistent_id(obj):
149+
return obj
150+
check(PersPickler)
151+
152+
@support.cpython_only
153+
def test_unpickler_reference_cycle(self):
154+
def check(Unpickler):
155+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
156+
unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
157+
self.assertEqual(unpickler.load(), 'abc')
158+
unpickler = Unpickler(io.BytesIO())
159+
self.assertEqual(unpickler.persistent_load('def'), 'def')
160+
r = weakref.ref(unpickler)
161+
del unpickler
162+
self.assertIsNone(r())
163+
164+
class PersUnpickler(self.unpickler):
165+
def persistent_load(subself, pid):
166+
return pid
167+
check(PersUnpickler)
168+
169+
class PersUnpickler(self.unpickler):
170+
@classmethod
171+
def persistent_load(cls, pid):
172+
return pid
173+
check(PersUnpickler)
174+
175+
class PersUnpickler(self.unpickler):
176+
@staticmethod
177+
def persistent_load(pid):
178+
return pid
179+
check(PersUnpickler)
180+
120181

121182
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
122183

@@ -197,7 +258,7 @@ class SizeofTests(unittest.TestCase):
197258
check_sizeof = support.check_sizeof
198259

199260
def test_pickler(self):
200-
basesize = support.calcobjsize('5P2n3i2n3iP')
261+
basesize = support.calcobjsize('6P2n3i2n3iP')
201262
p = _pickle.Pickler(io.BytesIO())
202263
self.assertEqual(object.__sizeof__(p), basesize)
203264
MT_size = struct.calcsize('3nP0n')
@@ -214,7 +275,7 @@ def test_pickler(self):
214275
0) # Write buffer is cleared after every dump().
215276

216277
def test_unpickler(self):
217-
basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i')
278+
basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i')
218279
unpickler = _pickle.Unpickler
219280
P = struct.calcsize('P') # Size of memo table entry.
220281
n = struct.calcsize('n') # Size of mark table entry.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Instances of pickle.Pickler subclass with the persistent_id() method and
2+
pickle.Unpickler subclass with the persistent_load() method no longer create
3+
reference cycles.

Modules/_pickle.c

Lines changed: 98 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,69 @@ _Pickle_FastCall(PyObject *func, PyObject *obj)
360360

361361
/*************************************************************************/
362362

363+
/* Retrieve and deconstruct a method for avoiding a reference cycle
364+
(pickler -> bound method of pickler -> pickler) */
365+
static int
366+
init_method_ref(PyObject *self, _Py_Identifier *name,
367+
PyObject **method_func, PyObject **method_self)
368+
{
369+
PyObject *func, *func2;
370+
371+
/* *method_func and *method_self should be consistent. All refcount decrements
372+
should be occurred after setting *method_self and *method_func. */
373+
func = _PyObject_GetAttrId(self, name);
374+
if (func == NULL) {
375+
*method_self = NULL;
376+
Py_CLEAR(*method_func);
377+
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
378+
return -1;
379+
}
380+
PyErr_Clear();
381+
return 0;
382+
}
383+
384+
if (PyMethod_Check(func) && PyMethod_GET_SELF(func) == self) {
385+
/* Deconstruct a bound Python method */
386+
func2 = PyMethod_GET_FUNCTION(func);
387+
Py_INCREF(func2);
388+
*method_self = self; /* borrowed */
389+
Py_XSETREF(*method_func, func2);
390+
Py_DECREF(func);
391+
return 0;
392+
}
393+
else {
394+
*method_self = NULL;
395+
Py_XSETREF(*method_func, func);
396+
return 0;
397+
}
398+
}
399+
400+
/* Bind a method if it was deconstructed */
401+
static PyObject *
402+
reconstruct_method(PyObject *func, PyObject *self)
403+
{
404+
if (self) {
405+
return PyMethod_New(func, self);
406+
}
407+
else {
408+
Py_INCREF(func);
409+
return func;
410+
}
411+
}
412+
413+
static PyObject *
414+
call_method(PyObject *func, PyObject *self, PyObject *obj)
415+
{
416+
if (self) {
417+
return PyObject_CallFunctionObjArgs(func, self, obj, NULL);
418+
}
419+
else {
420+
return PyObject_CallFunctionObjArgs(func, obj, NULL);
421+
}
422+
}
423+
424+
/*************************************************************************/
425+
363426
/* Internal data type used as the unpickling stack. */
364427
typedef struct {
365428
PyObject_VAR_HEAD
@@ -552,6 +615,8 @@ typedef struct PicklerObject {
552615
objects to support self-referential objects
553616
pickling. */
554617
PyObject *pers_func; /* persistent_id() method, can be NULL */
618+
PyObject *pers_func_self; /* borrowed reference to self if pers_func
619+
is an unbound method, NULL otherwise */
555620
PyObject *dispatch_table; /* private dispatch_table, can be NULL */
556621

557622
PyObject *write; /* write() method of the output stream. */
@@ -590,6 +655,8 @@ typedef struct UnpicklerObject {
590655
Py_ssize_t memo_len; /* Number of objects in the memo */
591656

592657
PyObject *pers_func; /* persistent_load() method, can be NULL. */
658+
PyObject *pers_func_self; /* borrowed reference to self if pers_func
659+
is an unbound method, NULL otherwise */
593660

594661
Py_buffer buffer;
595662
char *input_buffer;
@@ -3444,16 +3511,15 @@ save_type(PicklerObject *self, PyObject *obj)
34443511
}
34453512

34463513
static int
3447-
save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
3514+
save_pers(PicklerObject *self, PyObject *obj)
34483515
{
34493516
PyObject *pid = NULL;
34503517
int status = 0;
34513518

34523519
const char persid_op = PERSID;
34533520
const char binpersid_op = BINPERSID;
34543521

3455-
Py_INCREF(obj);
3456-
pid = _Pickle_FastCall(func, obj);
3522+
pid = call_method(self->pers_func, self->pers_func_self, obj);
34573523
if (pid == NULL)
34583524
return -1;
34593525

@@ -3831,7 +3897,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
38313897
0 if it did nothing successfully;
38323898
1 if a persistent id was saved.
38333899
*/
3834-
if ((status = save_pers(self, obj, self->pers_func)) != 0)
3900+
if ((status = save_pers(self, obj)) != 0)
38353901
goto done;
38363902
}
38373903

@@ -4246,13 +4312,10 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file,
42464312
self->fast_nesting = 0;
42474313
self->fast_memo = NULL;
42484314

4249-
self->pers_func = _PyObject_GetAttrId((PyObject *)self,
4250-
&PyId_persistent_id);
4251-
if (self->pers_func == NULL) {
4252-
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
4253-
return -1;
4254-
}
4255-
PyErr_Clear();
4315+
if (init_method_ref((PyObject *)self, &PyId_persistent_id,
4316+
&self->pers_func, &self->pers_func_self) < 0)
4317+
{
4318+
return -1;
42564319
}
42574320

42584321
self->dispatch_table = _PyObject_GetAttrId((PyObject *)self,
@@ -4519,11 +4582,11 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj)
45194582
static PyObject *
45204583
Pickler_get_persid(PicklerObject *self)
45214584
{
4522-
if (self->pers_func == NULL)
4585+
if (self->pers_func == NULL) {
45234586
PyErr_SetString(PyExc_AttributeError, "persistent_id");
4524-
else
4525-
Py_INCREF(self->pers_func);
4526-
return self->pers_func;
4587+
return NULL;
4588+
}
4589+
return reconstruct_method(self->pers_func, self->pers_func_self);
45274590
}
45284591

45294592
static int
@@ -4540,6 +4603,7 @@ Pickler_set_persid(PicklerObject *self, PyObject *value)
45404603
return -1;
45414604
}
45424605

4606+
self->pers_func_self = NULL;
45434607
Py_INCREF(value);
45444608
Py_XSETREF(self->pers_func, value);
45454609

@@ -5489,7 +5553,7 @@ load_stack_global(UnpicklerObject *self)
54895553
static int
54905554
load_persid(UnpicklerObject *self)
54915555
{
5492-
PyObject *pid;
5556+
PyObject *pid, *obj;
54935557
Py_ssize_t len;
54945558
char *s;
54955559

@@ -5509,13 +5573,12 @@ load_persid(UnpicklerObject *self)
55095573
return -1;
55105574
}
55115575

5512-
/* This does not leak since _Pickle_FastCall() steals the reference
5513-
to pid first. */
5514-
pid = _Pickle_FastCall(self->pers_func, pid);
5515-
if (pid == NULL)
5576+
obj = call_method(self->pers_func, self->pers_func_self, pid);
5577+
Py_DECREF(pid);
5578+
if (obj == NULL)
55165579
return -1;
55175580

5518-
PDATA_PUSH(self->stack, pid, -1);
5581+
PDATA_PUSH(self->stack, obj, -1);
55195582
return 0;
55205583
}
55215584
else {
@@ -5530,20 +5593,19 @@ load_persid(UnpicklerObject *self)
55305593
static int
55315594
load_binpersid(UnpicklerObject *self)
55325595
{
5533-
PyObject *pid;
5596+
PyObject *pid, *obj;
55345597

55355598
if (self->pers_func) {
55365599
PDATA_POP(self->stack, pid);
55375600
if (pid == NULL)
55385601
return -1;
55395602

5540-
/* This does not leak since _Pickle_FastCall() steals the
5541-
reference to pid first. */
5542-
pid = _Pickle_FastCall(self->pers_func, pid);
5543-
if (pid == NULL)
5603+
obj = call_method(self->pers_func, self->pers_func_self, pid);
5604+
Py_DECREF(pid);
5605+
if (obj == NULL)
55445606
return -1;
55455607

5546-
PDATA_PUSH(self->stack, pid, -1);
5608+
PDATA_PUSH(self->stack, obj, -1);
55475609
return 0;
55485610
}
55495611
else {
@@ -6690,13 +6752,10 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file,
66906752

66916753
self->fix_imports = fix_imports;
66926754

6693-
self->pers_func = _PyObject_GetAttrId((PyObject *)self,
6694-
&PyId_persistent_load);
6695-
if (self->pers_func == NULL) {
6696-
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
6697-
return -1;
6698-
}
6699-
PyErr_Clear();
6755+
if (init_method_ref((PyObject *)self, &PyId_persistent_load,
6756+
&self->pers_func, &self->pers_func_self) < 0)
6757+
{
6758+
return -1;
67006759
}
67016760

67026761
self->stack = (Pdata *)Pdata_New();
@@ -6983,11 +7042,11 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj)
69837042
static PyObject *
69847043
Unpickler_get_persload(UnpicklerObject *self)
69857044
{
6986-
if (self->pers_func == NULL)
7045+
if (self->pers_func == NULL) {
69877046
PyErr_SetString(PyExc_AttributeError, "persistent_load");
6988-
else
6989-
Py_INCREF(self->pers_func);
6990-
return self->pers_func;
7047+
return NULL;
7048+
}
7049+
return reconstruct_method(self->pers_func, self->pers_func_self);
69917050
}
69927051

69937052
static int
@@ -7005,6 +7064,7 @@ Unpickler_set_persload(UnpicklerObject *self, PyObject *value)
70057064
return -1;
70067065
}
70077066

7067+
self->pers_func_self = NULL;
70087068
Py_INCREF(value);
70097069
Py_XSETREF(self->pers_func, value);
70107070

0 commit comments

Comments
 (0)