1+ import os
12import pathlib
3+ from collections .abc import Iterable
24
35import yaml
46
5- from substrait import proto
7+ from substrait .gen .proto .type_pb2 import Type as SubstraitType
8+ from substrait .gen .proto .extensions .extensions_pb2 import (
9+ SimpleExtensionURI ,
10+ SimpleExtensionDeclaration ,
11+ )
12+
13+
14+ class RegisteredSubstraitFunction :
15+ """A Substrait function loaded from an extension file.
16+
17+ The FunctionsCatalog will keep a collection of RegisteredSubstraitFunction
18+ and will use them to generate the necessary extension URIs and extensions.
19+ """
20+
21+ def __init__ (self , signature : str , function_anchor : int | None , impl : dict ):
22+ self .signature = signature
23+ self .function_anchor = function_anchor
24+ self .variadic = impl .get ("variadic" , False )
25+
26+ if "return" in impl :
27+ self .return_type = self ._type_from_name (impl ["return" ])
28+ else :
29+ # We do always need a return type
30+ # to know which type to propagate up to the invoker
31+ _ , argtypes = FunctionsCatalog .parse_signature (signature )
32+ # TODO: Is this the right way to handle this?
33+ self .return_type = self ._type_from_name (argtypes [0 ])
34+
35+ @property
36+ def name (self ) -> str :
37+ name , _ = FunctionsCatalog .parse_signature (self .signature )
38+ return name
39+
40+ @property
41+ def arguments (self ) -> list [str ]:
42+ _ , argtypes = FunctionsCatalog .parse_signature (self .signature )
43+ return argtypes
44+
45+ @property
46+ def arguments_type (self ) -> list [SubstraitType | None ]:
47+ return [self ._type_from_name (arg ) for arg in self .arguments ]
48+
49+ def _type_from_name (self , typename : str ) -> SubstraitType | None :
50+ nullable = False
51+ if typename .endswith ("?" ):
52+ nullable = True
53+
54+ typename = typename .strip ("?" )
55+ if typename in ("any" , "any1" ):
56+ return None
57+
58+ if typename == "boolean" :
59+ # For some reason boolean is an exception to the naming convention
60+ typename = "bool"
61+
62+ try :
63+ type_descriptor = SubstraitType .DESCRIPTOR .fields_by_name [
64+ typename
65+ ].message_type
66+ except KeyError :
67+ # TODO: improve resolution of complext type like LIST?<any>
68+ print ("Unsupported type" , typename )
69+ return None
70+
71+ type_class = getattr (SubstraitType , type_descriptor .name )
72+ nullability = (
73+ SubstraitType .Nullability .NULLABILITY_REQUIRED
74+ if not nullable
75+ else SubstraitType .Nullability .NULLABILITY_NULLABLE
76+ )
77+ return SubstraitType (** {typename : type_class (nullability = nullability )})
678
779
880class FunctionsCatalog :
@@ -32,20 +104,21 @@ class FunctionsCatalog:
32104 )
33105
34106 def __init__ (self ):
35- self ._registered_extensions = {}
107+ self ._substrait_extension_uris = {}
108+ self ._substrait_extension_functions = {}
36109 self ._functions = {}
37- self ._functions_return_type = {}
38110
39- def load_standard_extensions (self , dirpath ):
111+ def load_standard_extensions (self , dirpath : str | os .PathLike ):
112+ """Load all standard substrait extensions from the target directory."""
40113 for ext in self .STANDARD_EXTENSIONS :
41114 self .load (dirpath , ext )
42115
43- def load (self , dirpath , filename ):
116+ def load (self , dirpath : str | os .PathLike , filename : str ):
117+ """Load an extension from a YAML file in a target directory."""
44118 with open (pathlib .Path (dirpath ) / filename .strip ("/" )) as f :
45119 sections = yaml .safe_load (f )
46120
47- loaded_functions = set ()
48- functions_return_type = {}
121+ loaded_functions = {}
49122 for functions in sections .values ():
50123 for function in functions :
51124 function_name = function ["name" ]
@@ -56,100 +129,80 @@ def load(self, dirpath, filename):
56129 t .get ("value" , "unknown" ).strip ("?" )
57130 for t in impl .get ("args" , [])
58131 ]
59- if impl .get ("variadic" , False ):
60- # TODO: Variadic functions.
61- argtypes *= 2
62-
63132 if not argtypes :
64133 signature = function_name
65134 else :
66135 signature = f"{ function_name } :{ '_' .join (argtypes )} "
67- loaded_functions .add (signature )
68- print ("Loaded function" , signature )
69- functions_return_type [signature ] = self ._type_from_name (
70- impl ["return" ]
136+ loaded_functions [signature ] = RegisteredSubstraitFunction (
137+ signature , None , impl
71138 )
72139
73- self ._register_extensions (filename , loaded_functions , functions_return_type )
140+ self ._register_extensions (filename , loaded_functions )
74141
75142 def _register_extensions (
76- self , extension_uri , loaded_functions , functions_return_type
143+ self ,
144+ extension_uri : str ,
145+ loaded_functions : dict [str , RegisteredSubstraitFunction ],
77146 ):
78- if extension_uri not in self ._registered_extensions :
79- ext_anchor_id = len (self ._registered_extensions ) + 1
80- self ._registered_extensions [extension_uri ] = proto . SimpleExtensionURI (
147+ if extension_uri not in self ._substrait_extension_uris :
148+ ext_anchor_id = len (self ._substrait_extension_uris ) + 1
149+ self ._substrait_extension_uris [extension_uri ] = SimpleExtensionURI (
81150 extension_uri_anchor = ext_anchor_id , uri = extension_uri
82151 )
83152
84- for function in loaded_functions :
85- if function in self ._functions :
153+ for signature , registered_function in loaded_functions . items () :
154+ if signature in self ._substrait_extension_functions :
86155 extensions_by_anchor = self .extension_uris_by_anchor
87- existing_function = self ._functions [ function ]
156+ existing_function = self ._substrait_extension_functions [ signature ]
88157 function_extension = extensions_by_anchor [
89158 existing_function .extension_uri_reference
90159 ].uri
91160 raise ValueError (
92161 f"Duplicate function definition: { existing_function .name } from { extension_uri } , already loaded from { function_extension } "
93162 )
94- extension_anchor = self ._registered_extensions [
163+ extension_anchor = self ._substrait_extension_uris [
95164 extension_uri
96165 ].extension_uri_anchor
97- function_anchor = len (self ._functions ) + 1
98- self ._functions [ function ] = (
99- proto . SimpleExtensionDeclaration .ExtensionFunction (
166+ function_anchor = len (self ._substrait_extension_functions ) + 1
167+ self ._substrait_extension_functions [ signature ] = (
168+ SimpleExtensionDeclaration .ExtensionFunction (
100169 extension_uri_reference = extension_anchor ,
101- name = function ,
170+ name = signature ,
102171 function_anchor = function_anchor ,
103172 )
104173 )
105- self ._functions_return_type [function ] = functions_return_type [function ]
106-
107- def _type_from_name (self , typename ):
108- nullable = False
109- if typename .endswith ("?" ):
110- nullable = True
111-
112- typename = typename .strip ("?" )
113- if typename in ("any" , "any1" ):
114- return None
115-
116- if typename == "boolean" :
117- # For some reason boolean is an exception to the naming convention
118- typename = "bool"
119-
120- try :
121- type_descriptor = proto .Type .DESCRIPTOR .fields_by_name [
122- typename
123- ].message_type
124- except KeyError :
125- # TODO: improve resolution of complext type like LIST?<any>
126- print ("Unsupported type" , typename )
127- return None
128-
129- type_class = getattr (proto .Type , type_descriptor .name )
130- nullability = (
131- proto .Type .Nullability .NULLABILITY_REQUIRED
132- if not nullable
133- else proto .Type .Nullability .NULLABILITY_NULLABLE
134- )
135- return proto .Type (** {typename : type_class (nullability = nullability )})
174+ registered_function .function_anchor = function_anchor
175+ self ._functions .setdefault (registered_function .name , []).append (
176+ registered_function
177+ )
136178
137179 @property
138- def extension_uris_by_anchor (self ):
180+ def extension_uris_by_anchor (self ) -> dict [ int , SimpleExtensionURI ] :
139181 return {
140182 ext .extension_uri_anchor : ext
141- for ext in self ._registered_extensions .values ()
183+ for ext in self ._substrait_extension_uris .values ()
142184 }
143185
144186 @property
145- def extension_uris (self ):
146- return list (self ._registered_extensions .values ())
187+ def extension_uris (self ) -> list [ SimpleExtensionURI ] :
188+ return list (self ._substrait_extension_uris .values ())
147189
148190 @property
149- def extensions (self ):
150- return list (self ._functions .values ())
191+ def extensions_functions (
192+ self ,
193+ ) -> list [SimpleExtensionDeclaration .ExtensionFunction ]:
194+ return list (self ._substrait_extension_functions .values ())
195+
196+ @classmethod
197+ def make_signature (
198+ cls , function_name : str , proto_argtypes : Iterable [SubstraitType ]
199+ ):
200+ """Create a function signature from a function name and substrait types.
201+
202+ The signature is generated according to Function Signature Compound Names
203+ as described in the Substrait documentation.
204+ """
151205
152- def signature (self , function_name , proto_argtypes ):
153206 def _normalize_arg_types (argtypes ):
154207 for argtype in argtypes :
155208 kind = argtype .WhichOneof ("kind" )
@@ -160,23 +213,73 @@ def _normalize_arg_types(argtypes):
160213
161214 return f"{ function_name } :{ '_' .join (_normalize_arg_types (proto_argtypes ))} "
162215
163- def function_anchor (self , function ):
164- return self ._functions [function ].function_anchor
216+ @classmethod
217+ def parse_signature (cls , signature : str ) -> tuple [str , list [str ]]:
218+ """Parse a function signature and returns name and type names"""
219+ try :
220+ function_name , signature_args = signature .split (":" )
221+ except ValueError :
222+ function_name = signature
223+ argtypes = []
224+ else :
225+ argtypes = signature_args .split ("_" )
226+ return function_name , argtypes
165227
166- def function_return_type (self , function ):
167- return self ._functions_return_type [function ]
228+ def extensions_for_functions (
229+ self , function_signatures : Iterable [str ]
230+ ) -> tuple [list [SimpleExtensionURI ], list [SimpleExtensionDeclaration ]]:
231+ """Given a set of function signatures, return the necessary extensions.
168232
169- def extensions_for_functions (self , functions ):
233+ The function will return the URIs of the extensions and the extension
234+ that have to be declared in the plan to use the functions.
235+ """
170236 uris_anchors = set ()
171237 extensions = []
172- for f in functions :
173- ext = self ._functions [f ]
174- if not ext .extension_uri_reference :
175- # Built-in function
176- continue
238+ for f in function_signatures :
239+ ext = self ._substrait_extension_functions [f ]
177240 uris_anchors .add (ext .extension_uri_reference )
178- extensions .append (proto . SimpleExtensionDeclaration (extension_function = ext ))
241+ extensions .append (SimpleExtensionDeclaration (extension_function = ext ))
179242
180243 uris_by_anchor = self .extension_uris_by_anchor
181244 extension_uris = [uris_by_anchor [uri_anchor ] for uri_anchor in uris_anchors ]
182245 return extension_uris , extensions
246+
247+ def lookup_function (self , signature : str ) -> RegisteredSubstraitFunction | None :
248+ """Given the signature of a function invocation, return the matching function."""
249+ function_name , invocation_argtypes = self .parse_signature (signature )
250+
251+ functions = self ._functions .get (function_name )
252+ if not functions :
253+ # No function with such a name at all.
254+ return None
255+
256+ is_variadic = functions [0 ].variadic
257+ if is_variadic :
258+ # If it's variadic we care about only the first parameter.
259+ invocation_argtypes = invocation_argtypes [:1 ]
260+
261+ found_function = None
262+ for function in functions :
263+ accepted_function_arguments = function .arguments
264+ for argidx , argtype in enumerate (invocation_argtypes ):
265+ try :
266+ accepted_argument = accepted_function_arguments [argidx ]
267+ except IndexError :
268+ # More arguments than available were provided
269+ break
270+ if accepted_argument != argtype and accepted_argument not in (
271+ "any" ,
272+ "any1" ,
273+ ):
274+ break
275+ else :
276+ if argidx < len (accepted_function_arguments ) - 1 :
277+ # Not enough arguments were provided
278+ remainder = accepted_function_arguments [argidx + 1 :]
279+ if all (arg .endswith ("?" ) for arg in remainder ):
280+ # All remaining arguments are optional
281+ found_function = function
282+ else :
283+ found_function = function
284+
285+ return found_function
0 commit comments