diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index b1107ba92..f3c0590de 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -635,11 +635,12 @@ def extract_func_data(self, func): base_globals = self.globals_ref.get(id(func.__globals__), None) if base_globals is None: - # For functions defined in __main__, use vars(__main__) for - # base_global. This is necessary to share the global variables - # across multiple functions in this module. - if func.__module__ == "__main__": - base_globals = "__main__" + # For functions defined in a well behaved module use + # vars(func.__module__) for base_globals. This is necessary to + # share the global variables across multiple pickled functions from + # this module. + if hasattr(func, '__module__') and func.__module__ is not None: + base_globals = func.__module__ else: base_globals = {} self.globals_ref[id(func.__globals__)] = base_globals diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index f1c3ea55b..534993e2c 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -47,6 +47,9 @@ from .testutils import assert_run_python_script +_TEST_GLOBAL_VARIABLE = "default_value" + + class RaiserOnPickle(object): def __init__(self, exc): @@ -887,6 +890,39 @@ def f1(): clone_func=clone_func) assert_run_python_script(textwrap.dedent(code)) + def test_closure_interacting_with_a_global_variable(self): + global _TEST_GLOBAL_VARIABLE + orig_value = _TEST_GLOBAL_VARIABLE + try: + def f0(): + global _TEST_GLOBAL_VARIABLE + _TEST_GLOBAL_VARIABLE = "changed_by_f0" + + def f1(): + return _TEST_GLOBAL_VARIABLE + + cloned_f0 = cloudpickle.loads(cloudpickle.dumps( + f0, protocol=self.protocol)) + cloned_f1 = cloudpickle.loads(cloudpickle.dumps( + f1, protocol=self.protocol)) + pickled_f1 = cloudpickle.dumps(f1, protocol=self.protocol) + + # Change the value of the global variable + cloned_f0() + + # Ensure that the global variable is the same for another function + result_f1 = cloned_f1() + assert result_f1 == "changed_by_f0", result_f1 + assert f1() == result_f1 + + # Ensure that unpickling the global variable does not change its + # value + result_pickled_f1 = cloudpickle.loads(pickled_f1)() + assert result_pickled_f1 == "changed_by_f0", result_pickled_f1 + finally: + _TEST_GLOBAL_VARIABLE = orig_value + + @pytest.mark.skipif(sys.version_info >= (3, 0), reason="hardcoded pickle bytes for 2.7") def test_function_pickle_compat_0_4_0(self):