Skip to content

Commit c91bf74

Browse files
miss-islington and serhiy-storchaka
authored and committed
bpo-28416: Break reference cycles in Pickler and Unpickler subclasses (GH-4080) (#4653)
with the persistent_id() and persistent_load() methods. (cherry picked from commit 986375e)
1 parent 92a2c07 commit c91bf74

File tree

3 files changed

+164
-40
lines changed

3 files changed

+164
-40
lines changed

Lib/test/test_pickle.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import collections
77
import struct
88
import sys
9+
import weakref
910

1011
import unittest
1112
from test import support
@@ -117,6 +118,66 @@ class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
117118
pickler = pickle._Pickler
118119
unpickler = pickle._Unpickler
119120

121+
@support.cpython_only
122+
def test_pickler_reference_cycle(self):
123+
def check(Pickler):
124+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
125+
f = io.BytesIO()
126+
pickler = Pickler(f, proto)
127+
pickler.dump('abc')
128+
self.assertEqual(self.loads(f.getvalue()), 'abc')
129+
pickler = Pickler(io.BytesIO())
130+
self.assertEqual(pickler.persistent_id('def'), 'def')
131+
r = weakref.ref(pickler)
132+
del pickler
133+
self.assertIsNone(r())
134+
135+
class PersPickler(self.pickler):
136+
def persistent_id(subself, obj):
137+
return obj
138+
check(PersPickler)
139+
140+
class PersPickler(self.pickler):
141+
@classmethod
142+
def persistent_id(cls, obj):
143+
return obj
144+
check(PersPickler)
145+
146+
class PersPickler(self.pickler):
147+
@staticmethod
148+
def persistent_id(obj):
149+
return obj
150+
check(PersPickler)
151+
152+
@support.cpython_only
153+
def test_unpickler_reference_cycle(self):
154+
def check(Unpickler):
155+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
156+
unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
157+
self.assertEqual(unpickler.load(), 'abc')
158+
unpickler = Unpickler(io.BytesIO())
159+
self.assertEqual(unpickler.persistent_load('def'), 'def')
160+
r = weakref.ref(unpickler)
161+
del unpickler
162+
self.assertIsNone(r())
163+
164+
class PersUnpickler(self.unpickler):
165+
def persistent_load(subself, pid):
166+
return pid
167+
check(PersUnpickler)
168+
169+
class PersUnpickler(self.unpickler):
170+
@classmethod
171+
def persistent_load(cls, pid):
172+
return pid
173+
check(PersUnpickler)
174+
175+
class PersUnpickler(self.unpickler):
176+
@staticmethod
177+
def persistent_load(pid):
178+
return pid
179+
check(PersUnpickler)
180+
120181

121182
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
122183

@@ -197,7 +258,7 @@ class SizeofTests(unittest.TestCase):
197258
check_sizeof = support.check_sizeof
198259

199260
def test_pickler(self):
200-
basesize = support.calcobjsize('5P2n3i2n3iP')
261+
basesize = support.calcobjsize('6P2n3i2n3iP')
201262
p = _pickle.Pickler(io.BytesIO())
202263
self.assertEqual(object.__sizeof__(p), basesize)
203264
MT_size = struct.calcsize('3nP0n')
@@ -214,7 +275,7 @@ def test_pickler(self):
214275
0) # Write buffer is cleared after every dump().
215276

216277
def test_unpickler(self):
217-
basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i')
278+
basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i')
218279
unpickler = _pickle.Unpickler
219280
P = struct.calcsize('P') # Size of memo table entry.
220281
n = struct.calcsize('n') # Size of mark table entry.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Instances of a pickle.Pickler subclass with the persistent_id() method and of
2+
a pickle.Unpickler subclass with the persistent_load() method no longer create
3+
reference cycles.

Modules/_pickle.c

Lines changed: 98 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,69 @@ _Pickle_FastCall(PyObject *func, PyObject *obj)
353353

354354
/*************************************************************************/
355355

356+
/* Retrieve and deconstruct a method for avoiding a reference cycle
357+
(pickler -> bound method of pickler -> pickler) */
358+
static int
359+
init_method_ref(PyObject *self, _Py_Identifier *name,
360+
PyObject **method_func, PyObject **method_self)
361+
{
362+
PyObject *func, *func2;
363+
364+
/* *method_func and *method_self should be consistent. All refcount decrements
365+
should be occurred after setting *method_self and *method_func. */
366+
func = _PyObject_GetAttrId(self, name);
367+
if (func == NULL) {
368+
*method_self = NULL;
369+
Py_CLEAR(*method_func);
370+
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
371+
return -1;
372+
}
373+
PyErr_Clear();
374+
return 0;
375+
}
376+
377+
if (PyMethod_Check(func) && PyMethod_GET_SELF(func) == self) {
378+
/* Deconstruct a bound Python method */
379+
func2 = PyMethod_GET_FUNCTION(func);
380+
Py_INCREF(func2);
381+
*method_self = self; /* borrowed */
382+
Py_XSETREF(*method_func, func2);
383+
Py_DECREF(func);
384+
return 0;
385+
}
386+
else {
387+
*method_self = NULL;
388+
Py_XSETREF(*method_func, func);
389+
return 0;
390+
}
391+
}
392+
393+
/* Bind a method if it was deconstructed */
394+
static PyObject *
395+
reconstruct_method(PyObject *func, PyObject *self)
396+
{
397+
if (self) {
398+
return PyMethod_New(func, self);
399+
}
400+
else {
401+
Py_INCREF(func);
402+
return func;
403+
}
404+
}
405+
406+
/* Call the (func, self) pair with the single argument `obj`: as
   func(self, obj) when `self` is non-NULL, and as func(obj) otherwise.
   Returns whatever PyObject_CallFunctionObjArgs() returns. */
static PyObject *
call_method(PyObject *func, PyObject *self, PyObject *obj)
{
    if (self == NULL) {
        return PyObject_CallFunctionObjArgs(func, obj, NULL);
    }
    return PyObject_CallFunctionObjArgs(func, self, obj, NULL);
}
416+
417+
/*************************************************************************/
418+
356419
/* Internal data type used as the unpickling stack. */
357420
typedef struct {
358421
PyObject_VAR_HEAD
@@ -545,6 +608,8 @@ typedef struct PicklerObject {
545608
objects to support self-referential objects
546609
pickling. */
547610
PyObject *pers_func; /* persistent_id() method, can be NULL */
611+
PyObject *pers_func_self; /* borrowed reference to self if pers_func
612+
is an unbound method, NULL otherwise */
548613
PyObject *dispatch_table; /* private dispatch_table, can be NULL */
549614

550615
PyObject *write; /* write() method of the output stream. */
@@ -583,6 +648,8 @@ typedef struct UnpicklerObject {
583648
Py_ssize_t memo_len; /* Number of objects in the memo */
584649

585650
PyObject *pers_func; /* persistent_load() method, can be NULL. */
651+
PyObject *pers_func_self; /* borrowed reference to self if pers_func
652+
is an unbound method, NULL otherwise */
586653

587654
Py_buffer buffer;
588655
char *input_buffer;
@@ -3401,16 +3468,15 @@ save_type(PicklerObject *self, PyObject *obj)
34013468
}
34023469

34033470
static int
3404-
save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
3471+
save_pers(PicklerObject *self, PyObject *obj)
34053472
{
34063473
PyObject *pid = NULL;
34073474
int status = 0;
34083475

34093476
const char persid_op = PERSID;
34103477
const char binpersid_op = BINPERSID;
34113478

3412-
Py_INCREF(obj);
3413-
pid = _Pickle_FastCall(func, obj);
3479+
pid = call_method(self->pers_func, self->pers_func_self, obj);
34143480
if (pid == NULL)
34153481
return -1;
34163482

@@ -3788,7 +3854,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
37883854
0 if it did nothing successfully;
37893855
1 if a persistent id was saved.
37903856
*/
3791-
if ((status = save_pers(self, obj, self->pers_func)) != 0)
3857+
if ((status = save_pers(self, obj)) != 0)
37923858
goto done;
37933859
}
37943860

@@ -4203,13 +4269,10 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file,
42034269
self->fast_nesting = 0;
42044270
self->fast_memo = NULL;
42054271

4206-
self->pers_func = _PyObject_GetAttrId((PyObject *)self,
4207-
&PyId_persistent_id);
4208-
if (self->pers_func == NULL) {
4209-
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
4210-
return -1;
4211-
}
4212-
PyErr_Clear();
4272+
if (init_method_ref((PyObject *)self, &PyId_persistent_id,
4273+
&self->pers_func, &self->pers_func_self) < 0)
4274+
{
4275+
return -1;
42134276
}
42144277

42154278
self->dispatch_table = _PyObject_GetAttrId((PyObject *)self,
@@ -4476,11 +4539,11 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj)
44764539
static PyObject *
44774540
Pickler_get_persid(PicklerObject *self)
44784541
{
4479-
if (self->pers_func == NULL)
4542+
if (self->pers_func == NULL) {
44804543
PyErr_SetString(PyExc_AttributeError, "persistent_id");
4481-
else
4482-
Py_INCREF(self->pers_func);
4483-
return self->pers_func;
4544+
return NULL;
4545+
}
4546+
return reconstruct_method(self->pers_func, self->pers_func_self);
44844547
}
44854548

44864549
static int
@@ -4497,6 +4560,7 @@ Pickler_set_persid(PicklerObject *self, PyObject *value)
44974560
return -1;
44984561
}
44994562

4563+
self->pers_func_self = NULL;
45004564
Py_INCREF(value);
45014565
Py_XSETREF(self->pers_func, value);
45024566

@@ -5446,7 +5510,7 @@ load_stack_global(UnpicklerObject *self)
54465510
static int
54475511
load_persid(UnpicklerObject *self)
54485512
{
5449-
PyObject *pid;
5513+
PyObject *pid, *obj;
54505514
Py_ssize_t len;
54515515
char *s;
54525516

@@ -5466,13 +5530,12 @@ load_persid(UnpicklerObject *self)
54665530
return -1;
54675531
}
54685532

5469-
/* This does not leak since _Pickle_FastCall() steals the reference
5470-
to pid first. */
5471-
pid = _Pickle_FastCall(self->pers_func, pid);
5472-
if (pid == NULL)
5533+
obj = call_method(self->pers_func, self->pers_func_self, pid);
5534+
Py_DECREF(pid);
5535+
if (obj == NULL)
54735536
return -1;
54745537

5475-
PDATA_PUSH(self->stack, pid, -1);
5538+
PDATA_PUSH(self->stack, obj, -1);
54765539
return 0;
54775540
}
54785541
else {
@@ -5487,20 +5550,19 @@ load_persid(UnpicklerObject *self)
54875550
static int
54885551
load_binpersid(UnpicklerObject *self)
54895552
{
5490-
PyObject *pid;
5553+
PyObject *pid, *obj;
54915554

54925555
if (self->pers_func) {
54935556
PDATA_POP(self->stack, pid);
54945557
if (pid == NULL)
54955558
return -1;
54965559

5497-
/* This does not leak since _Pickle_FastCall() steals the
5498-
reference to pid first. */
5499-
pid = _Pickle_FastCall(self->pers_func, pid);
5500-
if (pid == NULL)
5560+
obj = call_method(self->pers_func, self->pers_func_self, pid);
5561+
Py_DECREF(pid);
5562+
if (obj == NULL)
55015563
return -1;
55025564

5503-
PDATA_PUSH(self->stack, pid, -1);
5565+
PDATA_PUSH(self->stack, obj, -1);
55045566
return 0;
55055567
}
55065568
else {
@@ -6637,13 +6699,10 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file,
66376699

66386700
self->fix_imports = fix_imports;
66396701

6640-
self->pers_func = _PyObject_GetAttrId((PyObject *)self,
6641-
&PyId_persistent_load);
6642-
if (self->pers_func == NULL) {
6643-
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
6644-
return -1;
6645-
}
6646-
PyErr_Clear();
6702+
if (init_method_ref((PyObject *)self, &PyId_persistent_load,
6703+
&self->pers_func, &self->pers_func_self) < 0)
6704+
{
6705+
return -1;
66476706
}
66486707

66496708
self->stack = (Pdata *)Pdata_New();
@@ -6930,11 +6989,11 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj)
69306989
static PyObject *
69316990
Unpickler_get_persload(UnpicklerObject *self)
69326991
{
6933-
if (self->pers_func == NULL)
6992+
if (self->pers_func == NULL) {
69346993
PyErr_SetString(PyExc_AttributeError, "persistent_load");
6935-
else
6936-
Py_INCREF(self->pers_func);
6937-
return self->pers_func;
6994+
return NULL;
6995+
}
6996+
return reconstruct_method(self->pers_func, self->pers_func_self);
69386997
}
69396998

69406999
static int
@@ -6952,6 +7011,7 @@ Unpickler_set_persload(UnpicklerObject *self, PyObject *value)
69527011
return -1;
69537012
}
69547013

7014+
self->pers_func_self = NULL;
69557015
Py_INCREF(value);
69567016
Py_XSETREF(self->pers_func, value);
69577017

0 commit comments

Comments
 (0)