Skip to content

Commit 12d4b42

Browse files
author
William Grant
committed
Allow for multiple copies of the same function in different modules.
Previous iterations of the cache assumed a one-to-one mapping from function name to module; however, it's possible to end up with many modules that contain the same function (e.g. due to a race condition when using multiple processes, though other scenarios could exist in the future). This commit separates the func_name (the id for the function) from the link_name (the unique id for a given function in a given module). This distinction is not exposed outside the cache: when asked for a target, the cache chooses which version to return (currently just the first one it sees).
1 parent 12e35a8 commit 12d4b42

File tree

3 files changed

+372
-90
lines changed

3 files changed

+372
-90
lines changed

typed_python/compiler/compiler_cache.py

Lines changed: 104 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -49,77 +49,88 @@ class CompilerCache:
4949
when we first boot up, which could be slow. We could improve this substantially
5050
by making it possible to determine if a given function is in the cache by organizing
5151
the manifests by, say, function name.
52+
53+
Due to the potential for race conditions, we must distinguish between the following:
54+
func_name - The identifier for the function, based on its identity hash.
55+
link_name - The identifier for the specific realization of that function, which lives in a specific
56+
cache module.
5257
"""
5358
def __init__(self, cacheDir):
5459
self.cacheDir = cacheDir
5560

5661
ensureDirExists(cacheDir)
5762

5863
self.loadedBinarySharedObjects = Dict(str, LoadedBinarySharedObject)()
59-
self.nameToModuleHash = Dict(str, str)()
60-
64+
self.link_name_to_module_hash = Dict(str, str)()
6165
self.moduleManifestsLoaded = set()
62-
66+
# link_names with an associated module in loadedBinarySharedObjects
67+
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
68+
# the set of link_names for functions with linked and validated globals (i.e. ready to be run).
69+
self.targetsValidated = set()
70+
# link_name -> link_name
71+
self.function_dependency_graph = DirectedGraph()
72+
# dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
73+
self.global_dependencies = Dict(str, ListOf(str))()
74+
self.func_name_to_link_names = Dict(str, ListOf(str))()
6375
for moduleHash in os.listdir(self.cacheDir):
6476
if len(moduleHash) == 40:
6577
self.loadNameManifestFromStoredModuleByHash(moduleHash)
6678

67-
# the set of functions with an associated module in loadedBinarySharedObjects
68-
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
79+
def hasSymbol(self, func_name: str) -> bool:
80+
"""Returns true if there are any versions of `func_name` in the cache.
6981
70-
# the set of functions with linked and validated globals (i.e. ready to be run).
71-
self.targetsValidated = set()
82+
There may be multiple copies in different modules with different link_names.
83+
"""
84+
return any(link_name in self.link_name_to_module_hash for link_name in self.func_name_to_link_names.get(func_name, []))
7285

73-
self.function_dependency_graph = DirectedGraph()
74-
# dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75-
self.global_dependencies = Dict(str, ListOf(str))()
86+
def getTarget(self, func_name: str) -> TypedCallTarget:
87+
if not self.hasSymbol(func_name):
88+
raise ValueError(f'symbol not found for func_name {func_name}')
89+
link_name = self._select_link_name(func_name)
90+
self.loadForSymbol(link_name)
91+
return self.targetsLoaded[link_name]
7692

77-
def hasSymbol(self, linkName: str) -> bool:
78-
"""NB this will return True even if the linkName is ultimately unretrievable."""
79-
return linkName in self.nameToModuleHash
93+
def _generate_link_name(self, func_name: str, module_hash: str) -> str:
94+
return func_name + "." + module_hash
8095

81-
def getTarget(self, linkName: str) -> TypedCallTarget:
82-
if not self.hasSymbol(linkName):
83-
raise ValueError(f'symbol not found for linkName {linkName}')
84-
self.loadForSymbol(linkName)
85-
return self.targetsLoaded[linkName]
96+
def _select_link_name(self, func_name) -> str:
97+
"""choose a link name for a given func name.
8698
87-
def dependencies(self, linkName: str) -> Optional[List[str]]:
88-
"""Returns all the function names that `linkName` depends on"""
89-
return list(self.function_dependency_graph.outgoing(linkName))
99+
Currently we just choose the first available option.
100+
Throws a KeyError if func_name isn't in the cache.
101+
"""
102+
link_name_candidates = self.func_name_to_link_names[func_name]
103+
return link_name_candidates[0]
104+
105+
def dependencies(self, link_name: str) -> Optional[List[str]]:
106+
"""Returns all the function names that `link_name` depends on"""
107+
return list(self.function_dependency_graph.outgoing(link_name))
90108

91109
def loadForSymbol(self, linkName: str) -> None:
92-
"""Loads the whole module, and any submodules, into LoadedBinarySharedObjects"""
93-
moduleHash = self.nameToModuleHash[linkName]
110+
"""Loads the whole module, and any dependant modules, into LoadedBinarySharedObjects"""
111+
moduleHash = self.link_name_to_module_hash[linkName]
94112

95113
self.loadModuleByHash(moduleHash)
96114

97115
if linkName not in self.targetsValidated:
98-
dependantFuncs = self.dependencies(linkName) + [linkName]
99-
globalsToLink = {} # dict from modulehash to list of globals.
100-
for funcName in dependantFuncs:
101-
if funcName not in self.targetsValidated:
102-
funcModuleHash = self.nameToModuleHash[funcName]
103-
# append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
104-
globalsToLink[funcModuleHash] = globalsToLink.get(funcModuleHash, []) + self.global_dependencies.get(funcName, [])
105-
106-
for moduleHash, globs in globalsToLink.items(): # this works because loadModuleByHash loads submodules too.
107-
if globs:
108-
definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x]
109-
for x in globs
110-
}
111-
self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink)
112-
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
113-
raise RuntimeError('failed to validate globals when loading:', linkName)
114-
115-
self.targetsValidated.update(dependantFuncs)
116+
self.targetsValidated.add(linkName)
117+
for dependant_func in self.dependencies(linkName):
118+
self.loadForSymbol(dependant_func)
119+
120+
globalsToLink = self.global_dependencies.get(linkName, [])
121+
if globalsToLink:
122+
definitionsToLink = {x: self.loadedBinarySharedObjects[moduleHash].serializedGlobalVariableDefinitions[x]
123+
for x in globalsToLink
124+
}
125+
self.loadedBinarySharedObjects[moduleHash].linkGlobalVariables(definitionsToLink)
126+
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
127+
raise RuntimeError('failed to validate globals when loading:', linkName)
116128

117129
def loadModuleByHash(self, moduleHash: str) -> None:
118130
"""Load a module by name.
119131
120-
As we load, place all the newly imported typed call targets into
121-
'nameToTypedCallTarget' so that the rest of the system knows what functions
122-
have been uncovered.
132+
Add the module contents to targetsLoaded, generate a LoadedBinarySharedObject,
133+
and update the function and global dependency graphs.
123134
"""
124135
if moduleHash in self.loadedBinarySharedObjects:
125136
return
@@ -128,6 +139,7 @@ def loadModuleByHash(self, moduleHash: str) -> None:
128139

129140
# TODO (Will) - store these names as module consts, use one .dat only
130141
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
142+
# func_name -> typedcalltarget
131143
callTargets = SerializationContext().deserialize(f.read())
132144

133145
with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
@@ -156,45 +168,68 @@ def loadModuleByHash(self, moduleHash: str) -> None:
156168
serializedGlobalVarDefs,
157169
functionNameToNativeType,
158170
globalDependencies
159-
160171
).loadFromPath(modulePath)
161172

162173
self.loadedBinarySharedObjects[moduleHash] = loaded
163174

164-
self.targetsLoaded.update(callTargets)
175+
for func_name, callTarget in callTargets.items():
176+
link_name = self._generate_link_name(func_name, moduleHash)
177+
assert link_name not in self.targetsLoaded
178+
self.targetsLoaded[link_name] = callTarget
165179

166-
assert not any(key in self.global_dependencies for key in globalDependencies) # should only happen if there's a hash collision.
167-
self.global_dependencies.update(globalDependencies)
180+
link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()}
168181

182+
assert not any(key in self.global_dependencies for key in link_name_global_dependencies)
183+
184+
self.global_dependencies.update(link_name_global_dependencies)
169185
# update the cache's dependency graph with our new edges.
170186
for function_name, dependant_function_name in dependency_edgelist:
171187
self.function_dependency_graph.addEdge(source=function_name, dest=dependant_function_name)
172188

173189
def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, dependencyEdgelist):
174190
"""Add new code to the compiler cache.
175191
192+
176193
Args:
177194
binarySharedObject: a BinarySharedObject containing the actual assembler
178195
we've compiled.
179-
nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
196+
nameToTypedCallTarget: a dict from func_name to TypedCallTarget telling us
180197
the formal python types for all the objects.
181-
linkDependencies: a set of linknames we depend on directly.
198+
linkDependencies: a set of func_names we depend on directly. (this becomes submodules)
182199
dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
183200
module.
201+
202+
TODO (Will): the notion of submodules/linkDependencies can be refactored out.
184203
"""
185-
dependentHashes = set()
186204

205+
hashToUse = SerializationContext().sha_hash(str(uuid.uuid4())).hexdigest
206+
207+
# the linkDependencies and dependencyEdgelist are in terms of func_name.
208+
dependentHashes = set()
187209
for name in linkDependencies:
188-
dependentHashes.add(self.nameToModuleHash[name])
210+
link_name = self._select_link_name(name)
211+
dependentHashes.add(self.link_name_to_module_hash[link_name])
212+
213+
link_name_dependency_edgelist = []
214+
for source, dest in dependencyEdgelist:
215+
assert source in binarySharedObject.definedSymbols
216+
source_link_name = self._generate_link_name(source, hashToUse)
217+
if dest in binarySharedObject.definedSymbols:
218+
dest_link_name = self._generate_link_name(dest, hashToUse)
219+
else:
220+
dest_link_name = self._select_link_name(dest)
221+
link_name_dependency_edgelist.append([source_link_name, dest_link_name])
189222

190-
path, hashToUse = self.writeModuleToDisk(binarySharedObject, nameToTypedCallTarget, dependentHashes, dependencyEdgelist)
223+
path = self.writeModuleToDisk(binarySharedObject, hashToUse, nameToTypedCallTarget, dependentHashes, link_name_dependency_edgelist)
191224

192225
self.loadedBinarySharedObjects[hashToUse] = (
193226
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
194227
)
195228

196-
for n in binarySharedObject.definedSymbols:
197-
self.nameToModuleHash[n] = hashToUse
229+
for func_name in binarySharedObject.definedSymbols:
230+
link_name = self._generate_link_name(func_name, hashToUse)
231+
self.link_name_to_module_hash[link_name] = hashToUse
232+
self.func_name_to_link_names.setdefault(func_name, []).append(link_name)
198233

199234
# link & validate all globals for the new module
200235
self.loadedBinarySharedObjects[hashToUse].linkGlobalVariables()
@@ -208,20 +243,18 @@ def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
208243

209244
targetDir = os.path.join(self.cacheDir, moduleHash)
210245

211-
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
212-
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
213-
214-
for subHash in submodules:
215-
self.loadNameManifestFromStoredModuleByHash(subHash)
216-
246+
# TODO (Will) the name_manifest module_hash is the same throughout so this doesn't need to be a dict.
217247
with open(os.path.join(targetDir, "name_manifest.dat"), "rb") as f:
218-
self.nameToModuleHash.update(
219-
SerializationContext().deserialize(f.read(), Dict(str, str))
220-
)
248+
func_name_to_module_hash = SerializationContext().deserialize(f.read(), Dict(str, str))
249+
250+
for func_name, module_hash in func_name_to_module_hash.items():
251+
link_name = self._generate_link_name(func_name, module_hash)
252+
self.func_name_to_link_names.setdefault(func_name, []).append(link_name)
253+
self.link_name_to_module_hash[link_name] = module_hash
221254

222255
self.moduleManifestsLoaded.add(moduleHash)
223256

224-
def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodules, dependencyEdgelist):
257+
def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget, submodules, dependencyEdgelist):
225258
"""Write out a disk representation of this module.
226259
227260
This includes writing both the shared object, a manifest of the function names
@@ -235,7 +268,6 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
235268
to interact with the compiler cache simultaneously without relying on
236269
individual file-level locking.
237270
"""
238-
hashToUse = SerializationContext().sha_hash(str(uuid.uuid4())).hexdigest
239271

240272
targetDir = os.path.join(
241273
self.cacheDir,
@@ -264,23 +296,20 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
264296
for sourceName in manifest:
265297
f.write(sourceName + "\n")
266298

267-
# write the type manifest
268299
with open(os.path.join(tempTargetDir, "type_manifest.dat"), "wb") as f:
269300
f.write(SerializationContext().serialize(nameToTypedCallTarget))
270301

271-
# write the nativetype manifest
272302
with open(os.path.join(tempTargetDir, "native_type_manifest.dat"), "wb") as f:
273303
f.write(SerializationContext().serialize(binarySharedObject.functionTypes))
274304

275-
# write the type manifest
276305
with open(os.path.join(tempTargetDir, "globals_manifest.dat"), "wb") as f:
277306
f.write(SerializationContext().serialize(binarySharedObject.serializedGlobalVariableDefinitions))
278307

279308
with open(os.path.join(tempTargetDir, "submodules.dat"), "wb") as f:
280309
f.write(SerializationContext().serialize(ListOf(str)(submodules), ListOf(str)))
281310

282311
with open(os.path.join(tempTargetDir, "function_dependencies.dat"), "wb") as f:
283-
f.write(SerializationContext().serialize(dependencyEdgelist)) # might need a listof
312+
f.write(SerializationContext().serialize(dependencyEdgelist))
284313

285314
with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f:
286315
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies))
@@ -293,14 +322,15 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
293322
else:
294323
shutil.rmtree(tempTargetDir)
295324

296-
return targetDir, hashToUse
325+
return targetDir
297326

298-
def function_pointer_by_name(self, linkName):
299-
moduleHash = self.nameToModuleHash.get(linkName)
327+
def function_pointer_by_name(self, func_name):
328+
linkName = self._select_link_name(func_name)
329+
moduleHash = self.link_name_to_module_hash.get(linkName)
300330
if moduleHash is None:
301331
raise Exception("Can't find a module for " + linkName)
302332

303333
if moduleHash not in self.loadedBinarySharedObjects:
304334
self.loadForSymbol(linkName)
305335

306-
return self.loadedBinarySharedObjects[moduleHash].functionPointers[linkName]
336+
return self.loadedBinarySharedObjects[moduleHash].functionPointers[func_name]

0 commit comments

Comments
 (0)