@@ -49,77 +49,88 @@ class CompilerCache:
49
49
when we first boot up, which could be slow. We could improve this substantially
50
50
by making it possible to determine if a given function is in the cache by organizing
51
51
the manifests by, say, function name.
52
+
53
+ Due to the potential for race conditions, we must distinguish between the following:
54
+ func_name - The identifier for the function, based on its identity hash.
55
+ link_name - The identifier for the specific realization of that function, which lives in a specific
56
+ cache module.
52
57
"""
53
58
def __init__ (self , cacheDir ):
54
59
self .cacheDir = cacheDir
55
60
56
61
ensureDirExists (cacheDir )
57
62
58
63
self .loadedBinarySharedObjects = Dict (str , LoadedBinarySharedObject )()
59
- self .nameToModuleHash = Dict (str , str )()
60
-
64
+ self .link_name_to_module_hash = Dict (str , str )()
61
65
self .moduleManifestsLoaded = set ()
62
-
66
+ # link_names with an associated module in loadedBinarySharedObjects
67
+ self .targetsLoaded : Dict [str , TypedCallTarget ] = {}
68
+ # the set of link_names for functions with linked and validated globals (i.e. ready to be run).
69
+ self .targetsValidated = set ()
70
+ # link_name -> link_name
71
+ self .function_dependency_graph = DirectedGraph ()
72
+ # dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
73
+ self .global_dependencies = Dict (str , ListOf (str ))()
74
+ self .func_name_to_link_names = Dict (str , ListOf (str ))()
63
75
for moduleHash in os .listdir (self .cacheDir ):
64
76
if len (moduleHash ) == 40 :
65
77
self .loadNameManifestFromStoredModuleByHash (moduleHash )
66
78
67
- # the set of functions with an associated module in loadedBinarySharedObjects
68
- self . targetsLoaded : Dict [ str , TypedCallTarget ] = {}
79
+ def hasSymbol ( self , func_name : str ) -> bool :
80
+ """Returns true if there are any versions of `func_name` in the cache.
69
81
70
- # the set of functions with linked and validated globals (i.e. ready to be run).
71
- self .targetsValidated = set ()
82
+ There may be multiple copies in different modules with different link_names.
83
+ """
84
+ return any (link_name in self .link_name_to_module_hash for link_name in self .func_name_to_link_names .get (func_name , []))
72
85
73
- self .function_dependency_graph = DirectedGraph ()
74
- # dict from function linkname to list of global names (should be llvm keys in serialisedGlobalDefinitions)
75
- self .global_dependencies = Dict (str , ListOf (str ))()
86
+ def getTarget (self , func_name : str ) -> TypedCallTarget :
87
+ if not self .hasSymbol (func_name ):
88
+ raise ValueError (f'symbol not found for func_name { func_name } ' )
89
+ link_name = self ._select_link_name (func_name )
90
+ self .loadForSymbol (link_name )
91
+ return self .targetsLoaded [link_name ]
76
92
77
- def hasSymbol (self , linkName : str ) -> bool :
78
- """NB this will return True even if the linkName is ultimately unretrievable."""
79
- return linkName in self .nameToModuleHash
93
+ def _generate_link_name (self , func_name : str , module_hash : str ) -> str :
94
+ return func_name + "." + module_hash
80
95
81
- def getTarget (self , linkName : str ) -> TypedCallTarget :
82
- if not self .hasSymbol (linkName ):
83
- raise ValueError (f'symbol not found for linkName { linkName } ' )
84
- self .loadForSymbol (linkName )
85
- return self .targetsLoaded [linkName ]
96
+ def _select_link_name (self , func_name ) -> str :
97
+ """choose a link name for a given func name.
86
98
87
- def dependencies (self , linkName : str ) -> Optional [List [str ]]:
88
- """Returns all the function names that `linkName` depends on"""
89
- return list (self .function_dependency_graph .outgoing (linkName ))
99
+ Currently we just choose the first available option.
100
+ Throws a KeyError if func_name isn't in the cache.
101
+ """
102
+ link_name_candidates = self .func_name_to_link_names [func_name ]
103
+ return link_name_candidates [0 ]
104
+
105
+ def dependencies (self , link_name : str ) -> Optional [List [str ]]:
106
+ """Returns all the function names that `link_name` depends on"""
107
+ return list (self .function_dependency_graph .outgoing (link_name ))
90
108
91
109
def loadForSymbol (self , linkName : str ) -> None :
92
- """Loads the whole module, and any submodules , into LoadedBinarySharedObjects"""
93
- moduleHash = self .nameToModuleHash [linkName ]
110
+ """Loads the whole module, and any dependant modules , into LoadedBinarySharedObjects"""
111
+ moduleHash = self .link_name_to_module_hash [linkName ]
94
112
95
113
self .loadModuleByHash (moduleHash )
96
114
97
115
if linkName not in self .targetsValidated :
98
- dependantFuncs = self .dependencies (linkName ) + [linkName ]
99
- globalsToLink = {} # dict from modulehash to list of globals.
100
- for funcName in dependantFuncs :
101
- if funcName not in self .targetsValidated :
102
- funcModuleHash = self .nameToModuleHash [funcName ]
103
- # append to the list of globals to link for a given module. TODO: optimise this, don't double-link.
104
- globalsToLink [funcModuleHash ] = globalsToLink .get (funcModuleHash , []) + self .global_dependencies .get (funcName , [])
105
-
106
- for moduleHash , globs in globalsToLink .items (): # this works because loadModuleByHash loads submodules too.
107
- if globs :
108
- definitionsToLink = {x : self .loadedBinarySharedObjects [moduleHash ].serializedGlobalVariableDefinitions [x ]
109
- for x in globs
110
- }
111
- self .loadedBinarySharedObjects [moduleHash ].linkGlobalVariables (definitionsToLink )
112
- if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
113
- raise RuntimeError ('failed to validate globals when loading:' , linkName )
114
-
115
- self .targetsValidated .update (dependantFuncs )
116
+ self .targetsValidated .add (linkName )
117
+ for dependant_func in self .dependencies (linkName ):
118
+ self .loadForSymbol (dependant_func )
119
+
120
+ globalsToLink = self .global_dependencies .get (linkName , [])
121
+ if globalsToLink :
122
+ definitionsToLink = {x : self .loadedBinarySharedObjects [moduleHash ].serializedGlobalVariableDefinitions [x ]
123
+ for x in globalsToLink
124
+ }
125
+ self .loadedBinarySharedObjects [moduleHash ].linkGlobalVariables (definitionsToLink )
126
+ if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
127
+ raise RuntimeError ('failed to validate globals when loading:' , linkName )
116
128
117
129
def loadModuleByHash (self , moduleHash : str ) -> None :
118
130
"""Load a module by name.
119
131
120
- As we load, place all the newly imported typed call targets into
121
- 'nameToTypedCallTarget' so that the rest of the system knows what functions
122
- have been uncovered.
132
+ Add the module contents to targetsLoaded, generate a LoadedBinarySharedObject,
133
+ and update the function and global dependency graphs.
123
134
"""
124
135
if moduleHash in self .loadedBinarySharedObjects :
125
136
return
@@ -128,6 +139,7 @@ def loadModuleByHash(self, moduleHash: str) -> None:
128
139
129
140
# TODO (Will) - store these names as module consts, use one .dat only
130
141
with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
142
+ # func_name -> typedcalltarget
131
143
callTargets = SerializationContext ().deserialize (f .read ())
132
144
133
145
with open (os .path .join (targetDir , "globals_manifest.dat" ), "rb" ) as f :
@@ -156,45 +168,68 @@ def loadModuleByHash(self, moduleHash: str) -> None:
156
168
serializedGlobalVarDefs ,
157
169
functionNameToNativeType ,
158
170
globalDependencies
159
-
160
171
).loadFromPath (modulePath )
161
172
162
173
self .loadedBinarySharedObjects [moduleHash ] = loaded
163
174
164
- self .targetsLoaded .update (callTargets )
175
+ for func_name , callTarget in callTargets .items ():
176
+ link_name = self ._generate_link_name (func_name , moduleHash )
177
+ assert link_name not in self .targetsLoaded
178
+ self .targetsLoaded [link_name ] = callTarget
165
179
166
- assert not any (key in self .global_dependencies for key in globalDependencies ) # should only happen if there's a hash collision.
167
- self .global_dependencies .update (globalDependencies )
180
+ link_name_global_dependencies = {self ._generate_link_name (x , moduleHash ): y for x , y in globalDependencies .items ()}
168
181
182
+ assert not any (key in self .global_dependencies for key in link_name_global_dependencies )
183
+
184
+ self .global_dependencies .update (link_name_global_dependencies )
169
185
# update the cache's dependency graph with our new edges.
170
186
for function_name , dependant_function_name in dependency_edgelist :
171
187
self .function_dependency_graph .addEdge (source = function_name , dest = dependant_function_name )
172
188
173
189
def addModule (self , binarySharedObject , nameToTypedCallTarget , linkDependencies , dependencyEdgelist ):
174
190
"""Add new code to the compiler cache.
175
191
192
+
176
193
Args:
177
194
binarySharedObject: a BinarySharedObject containing the actual assembler
178
195
we've compiled.
179
- nameToTypedCallTarget: a dict from linkname to TypedCallTarget telling us
196
+ nameToTypedCallTarget: a dict from func_name to TypedCallTarget telling us
180
197
the formal python types for all the objects.
181
- linkDependencies: a set of linknames we depend on directly.
198
+ linkDependencies: a set of func_names we depend on directly. (this becomes submodules)
182
199
dependencyEdgelist (list): a list of source, dest pairs giving the set of dependency graph for the
183
200
module.
201
+
202
+ TODO (Will): the notion of submodules/linkDependencies can be refactored out.
184
203
"""
185
- dependentHashes = set ()
186
204
205
+ hashToUse = SerializationContext ().sha_hash (str (uuid .uuid4 ())).hexdigest
206
+
207
+ # the linkDependencies and dependencyEdgelist are in terms of func_name.
208
+ dependentHashes = set ()
187
209
for name in linkDependencies :
188
- dependentHashes .add (self .nameToModuleHash [name ])
210
+ link_name = self ._select_link_name (name )
211
+ dependentHashes .add (self .link_name_to_module_hash [link_name ])
212
+
213
+ link_name_dependency_edgelist = []
214
+ for source , dest in dependencyEdgelist :
215
+ assert source in binarySharedObject .definedSymbols
216
+ source_link_name = self ._generate_link_name (source , hashToUse )
217
+ if dest in binarySharedObject .definedSymbols :
218
+ dest_link_name = self ._generate_link_name (dest , hashToUse )
219
+ else :
220
+ dest_link_name = self ._select_link_name (dest )
221
+ link_name_dependency_edgelist .append ([source_link_name , dest_link_name ])
189
222
190
- path , hashToUse = self .writeModuleToDisk (binarySharedObject , nameToTypedCallTarget , dependentHashes , dependencyEdgelist )
223
+ path = self .writeModuleToDisk (binarySharedObject , hashToUse , nameToTypedCallTarget , dependentHashes , link_name_dependency_edgelist )
191
224
192
225
self .loadedBinarySharedObjects [hashToUse ] = (
193
226
binarySharedObject .loadFromPath (os .path .join (path , "module.so" ))
194
227
)
195
228
196
- for n in binarySharedObject .definedSymbols :
197
- self .nameToModuleHash [n ] = hashToUse
229
+ for func_name in binarySharedObject .definedSymbols :
230
+ link_name = self ._generate_link_name (func_name , hashToUse )
231
+ self .link_name_to_module_hash [link_name ] = hashToUse
232
+ self .func_name_to_link_names .setdefault (func_name , []).append (link_name )
198
233
199
234
# link & validate all globals for the new module
200
235
self .loadedBinarySharedObjects [hashToUse ].linkGlobalVariables ()
@@ -208,20 +243,18 @@ def loadNameManifestFromStoredModuleByHash(self, moduleHash) -> None:
208
243
209
244
targetDir = os .path .join (self .cacheDir , moduleHash )
210
245
211
- with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
212
- submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
213
-
214
- for subHash in submodules :
215
- self .loadNameManifestFromStoredModuleByHash (subHash )
216
-
246
+ # TODO (Will) the name_manifest module_hash is the same throughout so this doesn't need to be a dict.
217
247
with open (os .path .join (targetDir , "name_manifest.dat" ), "rb" ) as f :
218
- self .nameToModuleHash .update (
219
- SerializationContext ().deserialize (f .read (), Dict (str , str ))
220
- )
248
+ func_name_to_module_hash = SerializationContext ().deserialize (f .read (), Dict (str , str ))
249
+
250
+ for func_name , module_hash in func_name_to_module_hash .items ():
251
+ link_name = self ._generate_link_name (func_name , module_hash )
252
+ self .func_name_to_link_names .setdefault (func_name , []).append (link_name )
253
+ self .link_name_to_module_hash [link_name ] = module_hash
221
254
222
255
self .moduleManifestsLoaded .add (moduleHash )
223
256
224
- def writeModuleToDisk (self , binarySharedObject , nameToTypedCallTarget , submodules , dependencyEdgelist ):
257
+ def writeModuleToDisk (self , binarySharedObject , hashToUse , nameToTypedCallTarget , submodules , dependencyEdgelist ):
225
258
"""Write out a disk representation of this module.
226
259
227
260
This includes writing both the shared object, a manifest of the function names
@@ -235,7 +268,6 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
235
268
to interact with the compiler cache simultaneously without relying on
236
269
individual file-level locking.
237
270
"""
238
- hashToUse = SerializationContext ().sha_hash (str (uuid .uuid4 ())).hexdigest
239
271
240
272
targetDir = os .path .join (
241
273
self .cacheDir ,
@@ -264,23 +296,20 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
264
296
for sourceName in manifest :
265
297
f .write (sourceName + "\n " )
266
298
267
- # write the type manifest
268
299
with open (os .path .join (tempTargetDir , "type_manifest.dat" ), "wb" ) as f :
269
300
f .write (SerializationContext ().serialize (nameToTypedCallTarget ))
270
301
271
- # write the nativetype manifest
272
302
with open (os .path .join (tempTargetDir , "native_type_manifest.dat" ), "wb" ) as f :
273
303
f .write (SerializationContext ().serialize (binarySharedObject .functionTypes ))
274
304
275
- # write the type manifest
276
305
with open (os .path .join (tempTargetDir , "globals_manifest.dat" ), "wb" ) as f :
277
306
f .write (SerializationContext ().serialize (binarySharedObject .serializedGlobalVariableDefinitions ))
278
307
279
308
with open (os .path .join (tempTargetDir , "submodules.dat" ), "wb" ) as f :
280
309
f .write (SerializationContext ().serialize (ListOf (str )(submodules ), ListOf (str )))
281
310
282
311
with open (os .path .join (tempTargetDir , "function_dependencies.dat" ), "wb" ) as f :
283
- f .write (SerializationContext ().serialize (dependencyEdgelist )) # might need a listof
312
+ f .write (SerializationContext ().serialize (dependencyEdgelist ))
284
313
285
314
with open (os .path .join (tempTargetDir , "global_dependencies.dat" ), "wb" ) as f :
286
315
f .write (SerializationContext ().serialize (binarySharedObject .globalDependencies ))
@@ -293,14 +322,15 @@ def writeModuleToDisk(self, binarySharedObject, nameToTypedCallTarget, submodule
293
322
else :
294
323
shutil .rmtree (tempTargetDir )
295
324
296
- return targetDir , hashToUse
325
+ return targetDir
297
326
298
- def function_pointer_by_name (self , linkName ):
299
- moduleHash = self .nameToModuleHash .get (linkName )
327
+ def function_pointer_by_name (self , func_name ):
328
+ linkName = self ._select_link_name (func_name )
329
+ moduleHash = self .link_name_to_module_hash .get (linkName )
300
330
if moduleHash is None :
301
331
raise Exception ("Can't find a module for " + linkName )
302
332
303
333
if moduleHash not in self .loadedBinarySharedObjects :
304
334
self .loadForSymbol (linkName )
305
335
306
- return self .loadedBinarySharedObjects [moduleHash ].functionPointers [linkName ]
336
+ return self .loadedBinarySharedObjects [moduleHash ].functionPointers [func_name ]
0 commit comments