Skip to content

Commit 12e35a8

Browse files
author
William Grant
committed
Ensure that serializer can use the 'name' pickle protocol.
Pickle supports a protocol where __reduce__returns a string giving the global name. Implementing this behaviour lets us serialize numpy ufuncs. Also adjust installInflightFunctions to handle new load behaviour, fix an instability caused by not leaving LoadedModule objects in memory, and adjust alternative test. Also ensure that the new pickle protocol support works with 'local names' (e.g. dotted method names).
1 parent f0747ba commit 12e35a8

8 files changed

+136
-19
lines changed

typed_python/SerializationContext.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,31 @@
3131
import types
3232
import traceback
3333
import logging
34+
import numpy
35+
import pickle
3436

3537

3638
_badModuleCache = set()
3739

3840

41+
def pickledByStr(module_name: str, name: str) -> None:
42+
"""Generate the object given the module_name and name.
43+
44+
This mimics pickle's behavior when given a string from __reduce__. The
45+
string is interpreted as the name of a global variable, and pickle.whichmodules
46+
is used to search the module namespace, generating module_name.
47+
48+
Note that 'name' might contain '.' inside of it, since its a 'local name'.
49+
"""
50+
module = importlib.import_module(module_name)
51+
52+
instance = module
53+
for subName in name.split('.'):
54+
instance = getattr(instance, subName)
55+
56+
return instance
57+
58+
3959
def createFunctionWithLocalsAndGlobals(code, globals):
4060
if globals is None:
4161
globals = {}
@@ -708,26 +728,30 @@ def walkCodeObject(code):
708728
return (createFunctionWithLocalsAndGlobals, args, representation)
709729

710730
if not isinstance(inst, type) and hasattr(type(inst), '__reduce_ex__'):
711-
res = inst.__reduce_ex__(4)
731+
if isinstance(inst, numpy.ufunc):
732+
res = inst.__name__
733+
else:
734+
res = inst.__reduce_ex__(4)
712735

713-
# pickle supports a protocol where __reduce__ can return a string
714-
# giving a global name. We'll already find that separately, so we
715-
# don't want to handle it here. We ought to look at this in more detail
716-
# however
736+
# mimic pickle's behaviour when a string is received.
717737
if isinstance(res, str):
718-
return None
738+
name_tuple = (inst, res)
739+
module_name = pickle.whichmodule(*name_tuple)
740+
res = (pickledByStr, (module_name, res,), pickledByStr)
719741

720742
return res
721743

722744
if not isinstance(inst, type) and hasattr(type(inst), '__reduce__'):
723-
res = inst.__reduce__()
745+
if isinstance(inst, numpy.ufunc):
746+
res = inst.__name__
747+
else:
748+
res = inst.__reduce()
724749

725-
# pickle supports a protocol where __reduce__ can return a string
726-
# giving a global name. We'll already find that separately, so we
727-
# don't want to handle it here. We ought to look at this in more detail
728-
# however
750+
# mimic pickle's behaviour when a string is received.
729751
if isinstance(res, str):
730-
return None
752+
name_tuple = (inst, res)
753+
module_name = pickle.whichmodule(*name_tuple)
754+
res = (pickledByStr, (module_name, res,), pickledByStr)
731755

732756
return res
733757

@@ -736,6 +760,9 @@ def walkCodeObject(code):
736760
def setInstanceStateFromRepresentation(
737761
self, instance, representation=None, itemIt=None, kvPairIt=None, setStateFun=None
738762
):
763+
if representation is pickledByStr:
764+
return
765+
739766
if representation is reconstructTypeFunctionType:
740767
return
741768

typed_python/compiler/global_variable_definition.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,4 @@ def __eq__(self, other):
8787
return self.name == other.name and self.type == other.type and self.metadata == other.metadata
8888

8989
def __str__(self):
90-
metadata_str = str(self.metadata) if len(str(self.metadata)) < 100 else str(self.metadata)[:100] + "..."
91-
return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={metadata_str})"
90+
return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={pad(str(self.metadata))})"

typed_python/compiler/llvm_compiler_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from typed_python.compiler.module_definition import ModuleDefinition
2121
from typed_python.compiler.global_variable_definition import GlobalVariableMetadata
2222

23+
from typed_python.test_util import evaluateExprInFreshProcess
24+
2325
import pytest
2426
import ctypes
2527

@@ -131,3 +133,28 @@ def test_create_binary_shared_object():
131133
pointers[0].set(5)
132134

133135
assert loaded.functionPointers['__test_f_2']() == 5
136+
137+
138+
@pytest.mark.skipif('sys.platform=="darwin"')
139+
def test_loaded_modules_persist():
140+
"""
141+
Make sure that loaded modules are persisted in the converter state.
142+
143+
We have to maintain these references to avoid surprise segfaults - if this test fails,
144+
it should be because the GlobalVariableDefinition memory management has been refactored.
145+
"""
146+
147+
# compile a module
148+
xmodule = "\n".join([
149+
"@Entrypoint",
150+
"def f(x):",
151+
" return x + 1",
152+
"@Entrypoint",
153+
"def g(x):",
154+
" return f(x) * 100",
155+
"g(1000)",
156+
"def get_loaded_modules():",
157+
" return len(Runtime.singleton().converter.loadedUncachedModules)"
158+
])
159+
VERSION1 = {'x.py': xmodule}
160+
assert evaluateExprInFreshProcess(VERSION1, 'x.get_loaded_modules()') == 1

typed_python/compiler/loaded_module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def __init__(self, functionPointers, serializedGlobalVariableDefinitions):
2828

2929
self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](self.pointers.pointerUnsafe(0))
3030

31+
self.installedGlobalVariableDefinitions = {}
32+
3133
@staticmethod
3234
def validateGlobalVariables(serializedGlobalVariableDefinitions: Dict[str, bytes]) -> bool:
3335
"""Check that each global variable definition is sensible.
@@ -83,6 +85,8 @@ def linkGlobalVariables(self, variable_names: List[str] = None) -> None:
8385

8486
meta = SerializationContext().deserialize(self.orderedDefs[i]).metadata
8587

88+
self.installedGlobalVariableDefinitions[i] = meta
89+
8690
if meta.matches.StringConstant:
8791
self.pointers[i].cast(str).initialize(meta.value)
8892

typed_python/compiler/python_to_native_converter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ def __init__(self, llvmCompiler, compilerCache):
125125
self.llvmCompiler = llvmCompiler
126126
self.compilerCache = compilerCache
127127

128+
# all LoadedModule objects that we have created. We need to keep them alive so
129+
# that any python metadata objects the've created stay alive as well. Ultimately, this
130+
# may not be the place we put these objects (for instance, you could imagine a
131+
# 'dummy' compiler cache or something). But for now, we need to keep them alive.
132+
self.loadedUncachedModules = []
133+
128134
# if True, then insert additional code to check for undefined behavior.
129135
self.generateDebugChecks = False
130136

@@ -191,6 +197,7 @@ def buildAndLinkNewModule(self):
191197
if self.compilerCache is None:
192198
loadedModule = self.llvmCompiler.buildModule(targets)
193199
loadedModule.linkGlobalVariables()
200+
self.loadedUncachedModules.append(loadedModule)
194201
return
195202

196203
# get a set of function names that we depend on
@@ -926,7 +933,11 @@ def _installInflightFunctions(self):
926933
outboundTargets = []
927934
for outboundFuncId in self._dependencies.getNamesDependedOn(identifier):
928935
name = self._link_name_for_identity[outboundFuncId]
929-
outboundTargets.append(self._targets[name])
936+
target = self.getTarget(name)
937+
if target is not None:
938+
outboundTargets.append(target)
939+
else:
940+
raise RuntimeError(f'dependency not found for {name}.')
930941

931942
nativeFunction, actual_output_type = self._inflight_definitions.get(identifier)
932943

typed_python/compiler/tests/numpy_interaction_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typed_python import ListOf, Entrypoint
1+
from typed_python import ListOf, Entrypoint, SerializationContext
22
import numpy
33
import numpy.linalg
44

@@ -44,3 +44,12 @@ def test_listof_from_sliced_numpy_array():
4444
y = x[::2]
4545

4646
assert ListOf(int)(y) == [0, 2]
47+
48+
49+
def test_can_serialize_numpy_ufunc():
50+
assert numpy.sin == SerializationContext().deserialize(SerializationContext().serialize(numpy.sin))
51+
52+
53+
def test_can_serialize_numpy_array():
54+
x = numpy.ones(10)
55+
assert (x == SerializationContext().deserialize(SerializationContext().serialize(x))).all()

typed_python/compiler/tests/type_of_instances_compilation_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ def typeOfArg(x: C):
1717

1818
def test_type_of_alternative_is_specific():
1919
for members in [{}, {'a': int}]:
20-
A = Alternative("A", A=members)
20+
Alt = Alternative("Alt", A=members)
2121

2222
@Entrypoint
23-
def typeOfArg(x: A):
23+
def typeOfArg(x: Alt):
2424
return type(x)
2525

26-
assert typeOfArg(A.A()) is A.A
26+
assert typeOfArg(Alt.A()) is Alt.A
2727

2828

2929
def test_type_of_concrete_alternative_is_specific():

typed_python/types_serialization_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import sys
1616
import os
1717
import importlib
18+
from functools import lru_cache
19+
1820
from abc import ABC, abstractmethod, ABCMeta
1921
from typed_python.test_util import callFunctionInFreshProcess
2022
import typed_python.compiler.python_ast_util as python_ast_util
@@ -57,6 +59,13 @@
5759
module_level_testfun = dummy_test_module.testfunction
5860

5961

62+
class GlobalClassWithLruCache:
63+
@staticmethod
64+
@lru_cache(maxsize=None)
65+
def f(x):
66+
return x
67+
68+
6069
def moduleLevelFunctionUsedByExactlyOneSerializationTest():
6170
return "please don't touch me"
6271

@@ -3061,3 +3070,34 @@ def f(self):
30613070
print(x)
30623071
# TODO: make this True
30633072
# assert x[0].f.__closure__[0].cell_contents is x
3073+
3074+
def test_serialize_pyobj_with_custom_reduce(self):
3075+
class CustomReduceObject:
3076+
def __reduce__(self):
3077+
return 'CustomReduceObject'
3078+
3079+
assert CustomReduceObject == SerializationContext().deserialize(SerializationContext().serialize(CustomReduceObject))
3080+
3081+
def test_serialize_pyobj_in_MRTG_with_custom_reduce(self):
3082+
def getX():
3083+
class InnerCustomReduceObject:
3084+
def __reduce__(self):
3085+
return 'InnerCustomReduceObject'
3086+
3087+
def f(self):
3088+
return x
3089+
3090+
x = (InnerCustomReduceObject, InnerCustomReduceObject)
3091+
3092+
return x
3093+
3094+
x = callFunctionInFreshProcess(getX, (), showStdout=True)
3095+
3096+
assert x == SerializationContext().deserialize(SerializationContext().serialize(x))
3097+
3098+
def test_serialize_class_static_lru_cache(self):
3099+
s = SerializationContext()
3100+
3101+
assert (
3102+
s.deserialize(s.serialize(GlobalClassWithLruCache.f)) is GlobalClassWithLruCache.f
3103+
)

0 commit comments

Comments
 (0)