Skip to content

Commit 8eaf637

Browse files
pierreglaser authored and
ogrisel committed
FIX Handling of global variables by locally defined functions (#198)
1 parent 81a7a8c commit 8eaf637

File tree

3 files changed

+123
-9
lines changed

3 files changed

+123
-9
lines changed

CHANGES.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
21
master
32
======
43

4+
- Ensure that unpickling a locally defined function that accesses the global variables
5+
of a module does not reset the values of the global variables if they are already initialized.
6+
([issue #187](https://github.com/cloudpipe/cloudpickle/issues/187))
7+
58

69
0.5.5
710
=====
@@ -19,7 +22,6 @@ master
1922
variables ([issue #187](
2023
https://github.com/cloudpipe/cloudpickle/issues/187)).
2124

22-
2325
0.5.3
2426
=====
2527
- Fixed a crash in Python 2 when serializing non-hashable instancemethods of built-in

cloudpickle/cloudpickle.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -635,11 +635,12 @@ def extract_func_data(self, func):
635635

636636
base_globals = self.globals_ref.get(id(func.__globals__), None)
637637
if base_globals is None:
638-
# For functions defined in __main__, use vars(__main__) for
639-
# base_global. This is necessary to share the global variables
640-
# across multiple functions in this module.
641-
if func.__module__ == "__main__":
642-
base_globals = "__main__"
638+
# For functions defined in a well behaved module use
639+
# vars(func.__module__) for base_globals. This is necessary to
640+
# share the global variables across multiple pickled functions from
641+
# this module.
642+
if hasattr(func, '__module__') and func.__module__ is not None:
643+
base_globals = func.__module__
643644
else:
644645
base_globals = {}
645646
self.globals_ref[id(func.__globals__)] = base_globals
@@ -934,7 +935,6 @@ def subimport(name):
934935
def dynamic_subimport(name, vars):
935936
mod = imp.new_module(name)
936937
mod.__dict__.update(vars)
937-
sys.modules[name] = mod
938938
return mod
939939

940940

@@ -1090,7 +1090,13 @@ def _make_skel_func(code, cell_count, base_globals=None):
10901090
if base_globals is None:
10911091
base_globals = {}
10921092
elif isinstance(base_globals, str):
1093-
base_globals = vars(sys.modules[base_globals])
1093+
if sys.modules.get(base_globals, None) is not None:
1094+
# this checks whether the module the function originally lived in is
1095+
# already loaded in this process (i.e. present in sys.modules)
1096+
base_globals = vars(sys.modules[base_globals])
1097+
else:
1098+
base_globals = {}
1099+
10941100
base_globals['__builtins__'] = __builtins__
10951101

10961102
closure = (

tests/cloudpickle_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import textwrap
1818
import unittest
1919
import weakref
20+
import os
2021

2122
try:
2223
from StringIO import StringIO
@@ -47,6 +48,9 @@
4748
from .testutils import assert_run_python_script
4849

4950

51+
_TEST_GLOBAL_VARIABLE = "default_value"
52+
53+
5054
class RaiserOnPickle(object):
5155

5256
def __init__(self, exc):
@@ -436,6 +440,74 @@ def method(self, x):
436440
mod1, mod2 = pickle_depickle([mod, mod])
437441
self.assertEqual(id(mod1), id(mod2))
438442

443+
def test_load_dynamic_module_in_grandchild_process(self):
444+
# Make sure that when loaded, a dynamic module preserves its dynamic
445+
# property. Otherwise, this will lead to an ImportError if pickled in
446+
# the child process and reloaded in another one.
447+
448+
# We create a new dynamic module
449+
mod = imp.new_module('mod')
450+
code = '''
451+
x = 1
452+
'''
453+
exec(textwrap.dedent(code), mod.__dict__)
454+
455+
# This script will be run in a separate child process. It will import
456+
# the pickled dynamic module, and then re-pickle it under a new name.
457+
# Finally, it will create a child process that will load the re-pickled
458+
# dynamic module.
459+
parent_process_module_file = 'dynamic_module_from_parent_process.pkl'
460+
child_process_module_file = 'dynamic_module_from_child_process.pkl'
461+
child_process_script = '''
462+
import pickle
463+
import textwrap
464+
465+
import cloudpickle
466+
from testutils import assert_run_python_script
467+
468+
469+
child_of_child_process_script = {child_of_child_process_script}
470+
471+
with open('{parent_process_module_file}', 'rb') as f:
472+
mod = pickle.load(f)
473+
474+
with open('{child_process_module_file}', 'wb') as f:
475+
cloudpickle.dump(mod, f)
476+
477+
assert_run_python_script(textwrap.dedent(child_of_child_process_script))
478+
'''
479+
480+
# The script run by the process created by the child process
481+
child_of_child_process_script = """ '''
482+
import pickle
483+
with open('{child_process_module_file}','rb') as fid:
484+
mod = pickle.load(fid)
485+
''' """
486+
487+
# Filling the two scripts with the pickled modules filepaths and,
488+
# for the first child process, the script to be executed by its
489+
# own child process.
490+
child_of_child_process_script = child_of_child_process_script.format(
491+
child_process_module_file=child_process_module_file)
492+
493+
child_process_script = child_process_script.format(
494+
parent_process_module_file=parent_process_module_file,
495+
child_process_module_file=child_process_module_file,
496+
child_of_child_process_script=child_of_child_process_script)
497+
498+
try:
499+
with open(parent_process_module_file, 'wb') as fid:
500+
cloudpickle.dump(mod, fid)
501+
502+
assert_run_python_script(textwrap.dedent(child_process_script))
503+
504+
finally:
505+
# Remove temporarily created files
506+
if os.path.exists(parent_process_module_file):
507+
os.unlink(parent_process_module_file)
508+
if os.path.exists(child_process_module_file):
509+
os.unlink(child_process_module_file)
510+
439511
def test_find_module(self):
440512
import pickle # ensure this test is decoupled from global imports
441513
_find_module('pickle')
@@ -887,6 +959,40 @@ def f1():
887959
clone_func=clone_func)
888960
assert_run_python_script(textwrap.dedent(code))
889961

962+
def test_closure_interacting_with_a_global_variable(self):
963+
global _TEST_GLOBAL_VARIABLE
964+
assert _TEST_GLOBAL_VARIABLE == "default_value"
965+
orig_value = _TEST_GLOBAL_VARIABLE
966+
try:
967+
def f0():
968+
global _TEST_GLOBAL_VARIABLE
969+
_TEST_GLOBAL_VARIABLE = "changed_by_f0"
970+
971+
def f1():
972+
return _TEST_GLOBAL_VARIABLE
973+
974+
cloned_f0 = cloudpickle.loads(cloudpickle.dumps(
975+
f0, protocol=self.protocol))
976+
cloned_f1 = cloudpickle.loads(cloudpickle.dumps(
977+
f1, protocol=self.protocol))
978+
pickled_f1 = cloudpickle.dumps(f1, protocol=self.protocol)
979+
980+
# Change the value of the global variable
981+
cloned_f0()
982+
assert _TEST_GLOBAL_VARIABLE == "changed_by_f0"
983+
984+
# Ensure that the global variable is the same for another function
985+
result_cloned_f1 = cloned_f1()
986+
assert result_cloned_f1 == "changed_by_f0", result_cloned_f1
987+
assert f1() == result_cloned_f1
988+
989+
# Ensure that unpickling a function that reads the global variable
990+
# does not reset the variable's current value
991+
result_pickled_f1 = cloudpickle.loads(pickled_f1)()
992+
assert result_pickled_f1 == "changed_by_f0", result_pickled_f1
993+
finally:
994+
_TEST_GLOBAL_VARIABLE = orig_value
995+
890996
@pytest.mark.skipif(sys.version_info >= (3, 0),
891997
reason="hardcoded pickle bytes for 2.7")
892998
def test_function_pickle_compat_0_4_0(self):

0 commit comments

Comments
 (0)