@@ -35,57 +35,35 @@ use std::sync::Arc;
3535/// functions you supply such as name, type signature, return type, and actual
3636/// implementation.
3737///
38- ///
3938/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
4039///
4140/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
4241///
42+ /// # API Note
43+ ///
44+ /// This struct is kept separate from the [`ScalarUDFImpl`] trait to maintain
45+ /// backwards compatibility with the older API.
46+ ///
4347/// [`create_udf`]: crate::expr_fn::create_udf
4448/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
4549/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
46- #[ derive( Clone ) ]
50+ #[ derive( Debug , Clone ) ]
4751pub struct ScalarUDF {
48- /// The name of the function
49- name : String ,
50- /// The signature (the types of arguments that are supported)
51- signature : Signature ,
52- /// Function that returns the return type given the argument types
53- return_type : ReturnTypeFunction ,
54- /// actual implementation
55- ///
56- /// The fn param is the wrapped function but be aware that the function will
57- /// be passed with the slice / vec of columnar values (either scalar or array)
58- /// with the exception of zero param function, where a singular element vec
59- /// will be passed. In that case the single element is a null array to indicate
60- /// the batch's row count (so that the generative zero-argument function can know
61- /// the result array size).
62- fun : ScalarFunctionImplementation ,
63- /// Optional aliases for the function. This list should NOT include the value of `name` as well
64- aliases : Vec < String > ,
65- }
66-
67- impl Debug for ScalarUDF {
68- fn fmt ( & self , f : & mut Formatter ) -> fmt:: Result {
69- f. debug_struct ( "ScalarUDF" )
70- . field ( "name" , & self . name )
71- . field ( "signature" , & self . signature )
72- . field ( "fun" , & "<FUNC>" )
73- . finish ( )
74- }
52+ inner : Arc < dyn ScalarUDFImpl > ,
7553}
7654
7755impl PartialEq for ScalarUDF {
7856 fn eq ( & self , other : & Self ) -> bool {
79- self . name == other. name && self . signature == other. signature
57+ self . name ( ) == other. name ( ) && self . signature ( ) == other. signature ( )
8058 }
8159}
8260
8361impl Eq for ScalarUDF { }
8462
8563impl std:: hash:: Hash for ScalarUDF {
8664 fn hash < H : std:: hash:: Hasher > ( & self , state : & mut H ) {
87- self . name . hash ( state) ;
88- self . signature . hash ( state) ;
65+ self . name ( ) . hash ( state) ;
66+ self . signature ( ) . hash ( state) ;
8967 }
9068}
9169
@@ -101,51 +79,37 @@ impl ScalarUDF {
10179 return_type : & ReturnTypeFunction ,
10280 fun : & ScalarFunctionImplementation ,
10381 ) -> Self {
104- Self {
82+ Self :: new_from_impl ( ScalarUdfLegacyWrapper {
10583 name : name. to_owned ( ) ,
10684 signature : signature. clone ( ) ,
10785 return_type : return_type. clone ( ) ,
10886 fun : fun. clone ( ) ,
109- aliases : vec ! [ ] ,
110- }
87+ } )
11188 }
11289
11390 /// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] trait object
11491 ///
11592 /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
11693 pub fn new_from_impl < F > ( fun : F ) -> ScalarUDF
11794 where
118- F : ScalarUDFImpl + Send + Sync + ' static ,
95+ F : ScalarUDFImpl + ' static ,
11996 {
120- // TODO change the internal implementation to use the trait object
121- let arc_fun = Arc :: new ( fun) ;
122- let captured_self = arc_fun. clone ( ) ;
123- let return_type: ReturnTypeFunction = Arc :: new ( move |arg_types| {
124- let return_type = captured_self. return_type ( arg_types) ?;
125- Ok ( Arc :: new ( return_type) )
126- } ) ;
127-
128- let captured_self = arc_fun. clone ( ) ;
129- let func: ScalarFunctionImplementation =
130- Arc :: new ( move |args| captured_self. invoke ( args) ) ;
131-
13297 Self {
133- name : arc_fun. name ( ) . to_string ( ) ,
134- signature : arc_fun. signature ( ) . clone ( ) ,
135- return_type : return_type. clone ( ) ,
136- fun : func,
137- aliases : arc_fun. aliases ( ) . to_vec ( ) ,
98+ inner : Arc :: new ( fun) ,
13899 }
139100 }
140101
141- /// Adds additional names that can be used to invoke this function, in addition to `name`
142- pub fn with_aliases (
143- mut self ,
144- aliases : impl IntoIterator < Item = & ' static str > ,
145- ) -> Self {
146- self . aliases
147- . extend ( aliases. into_iter ( ) . map ( |s| s. to_string ( ) ) ) ;
148- self
102+ /// Return the underlying [`ScalarUDFImpl`] trait object for this function
103+ pub fn inner ( & self ) -> Arc < dyn ScalarUDFImpl > {
104+ self . inner . clone ( )
105+ }
106+
107+ /// Adds additional names that can be used to invoke this function, in
108+ /// addition to `name`
109+ ///
110+ /// If you implement [`ScalarUDFImpl`] directly, return the aliases from [`ScalarUDFImpl::aliases`] instead.
111+ pub fn with_aliases ( self , aliases : impl IntoIterator < Item = & ' static str > ) -> Self {
112+ Self :: new_from_impl ( AliasedScalarUDFImpl :: new ( self , aliases) )
149113 }
150114
151115 /// Returns an [`Expr`] logical expression to call this UDF with specified
@@ -159,31 +123,46 @@ impl ScalarUDF {
159123 ) )
160124 }
161125
162- /// Returns this function's name
126+ /// Returns this function's name.
127+ ///
128+ /// See [`ScalarUDFImpl::name`] for more details.
163129 pub fn name ( & self ) -> & str {
164- & self . name
130+ self . inner . name ( )
165131 }
166132
167- /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
133+ /// Returns the aliases for this function.
134+ ///
135+ /// See [`ScalarUDF::with_aliases`] for more details.
168136 pub fn aliases ( & self ) -> & [ String ] {
169- & self . aliases
137+ self . inner . aliases ( )
170138 }
171139
172- /// Returns this function's [`Signature`] (what input types are accepted)
140+ /// Returns this function's [`Signature`] (what input types are accepted).
141+ ///
142+ /// See [`ScalarUDFImpl::signature`] for more details.
173143 pub fn signature ( & self ) -> & Signature {
174- & self . signature
144+ self . inner . signature ( )
175145 }
176146
177- /// The datatype this function returns given the input argument input types
147+ /// The datatype this function returns given the input argument types.
148+ ///
149+ /// See [`ScalarUDFImpl::return_type`] for more details.
178150 pub fn return_type ( & self , args : & [ DataType ] ) -> Result < DataType > {
179- // Old API returns an Arc of the datatype for some reason
180- let res = ( self . return_type ) ( args) ?;
181- Ok ( res. as_ref ( ) . clone ( ) )
151+ self . inner . return_type ( args)
152+ }
153+
154+ /// Invoke the function on `args`, returning the appropriate result.
155+ ///
156+ /// See [`ScalarUDFImpl::invoke`] for more details.
157+ pub fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
158+ self . inner . invoke ( args)
182159 }
183160
184- /// Return an [`Arc`] to the function implementation
161+ /// Returns a `ScalarFunctionImplementation` that can invoke the function
162+ /// during execution
185163 pub fn fun ( & self ) -> ScalarFunctionImplementation {
186- self . fun . clone ( )
164+ let captured = self . inner . clone ( ) ;
165+ Arc :: new ( move |args| captured. invoke ( args) )
187166 }
188167}
189168
@@ -246,7 +225,7 @@ where
246225/// // Call the function `add_one(col)`
247226/// let expr = add_one.call(vec![col("a")]);
248227/// ```
249- pub trait ScalarUDFImpl {
228+ pub trait ScalarUDFImpl : Debug + Send + Sync {
250229 /// Returns this object as an [`Any`] trait object
251230 fn as_any ( & self ) -> & dyn Any ;
252231
@@ -292,3 +271,105 @@ pub trait ScalarUDFImpl {
292271 & [ ]
293272 }
294273}
274+
275+ /// ScalarUDF that adds an alias to the underlying function. It is better to
276+ /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
277+ #[ derive( Debug ) ]
278+ struct AliasedScalarUDFImpl {
279+ inner : ScalarUDF ,
280+ aliases : Vec < String > ,
281+ }
282+
283+ impl AliasedScalarUDFImpl {
284+ pub fn new (
285+ inner : ScalarUDF ,
286+ new_aliases : impl IntoIterator < Item = & ' static str > ,
287+ ) -> Self {
288+ let mut aliases = inner. aliases ( ) . to_vec ( ) ;
289+ aliases. extend ( new_aliases. into_iter ( ) . map ( |s| s. to_string ( ) ) ) ;
290+
291+ Self { inner, aliases }
292+ }
293+ }
294+
295+ impl ScalarUDFImpl for AliasedScalarUDFImpl {
296+ fn as_any ( & self ) -> & dyn Any {
297+ self
298+ }
299+ fn name ( & self ) -> & str {
300+ self . inner . name ( )
301+ }
302+
303+ fn signature ( & self ) -> & Signature {
304+ self . inner . signature ( )
305+ }
306+
307+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
308+ self . inner . return_type ( arg_types)
309+ }
310+
311+ fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
312+ self . inner . invoke ( args)
313+ }
314+
315+ fn aliases ( & self ) -> & [ String ] {
316+ & self . aliases
317+ }
318+ }
319+
320+ /// Implementation of [`ScalarUDFImpl`] that wraps the function-style pointers of the older API
321+ /// (see <https://github.com/apache/arrow-datafusion/pull/8578>)
322+ struct ScalarUdfLegacyWrapper {
323+ /// The name of the function
324+ name : String ,
325+ /// The signature (the types of arguments that are supported)
326+ signature : Signature ,
327+ /// Function that returns the return type given the argument types
328+ return_type : ReturnTypeFunction ,
329+ /// actual implementation
330+ ///
331+ /// The fn param is the wrapped function, but be aware that the function will
332+ /// be passed the slice / vec of columnar values (either scalar or array),
333+ /// with the exception of zero-parameter functions, where a single-element vec
334+ /// will be passed. In that case the single element is a null array that indicates
335+ /// the batch's row count (so that the generative zero-argument function can know
336+ /// the result array size).
337+ fun : ScalarFunctionImplementation ,
338+ }
339+
340+ impl Debug for ScalarUdfLegacyWrapper {
341+ fn fmt ( & self , f : & mut Formatter ) -> fmt:: Result {
342+ f. debug_struct ( "ScalarUDF" )
343+ . field ( "name" , & self . name )
344+ . field ( "signature" , & self . signature )
345+ . field ( "fun" , & "<FUNC>" )
346+ . finish ( )
347+ }
348+ }
349+
350+ impl ScalarUDFImpl for ScalarUdfLegacyWrapper {
351+ fn as_any ( & self ) -> & dyn Any {
352+ self
353+ }
354+ fn name ( & self ) -> & str {
355+ & self . name
356+ }
357+
358+ fn signature ( & self ) -> & Signature {
359+ & self . signature
360+ }
361+
362+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
363+ // Old API returns an Arc of the datatype for some reason
364+ let res = ( self . return_type ) ( arg_types) ?;
365+ Ok ( res. as_ref ( ) . clone ( ) )
366+ }
367+
368+ fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
369+ ( self . fun ) ( args)
370+ }
371+
372+ fn aliases ( & self ) -> & [ String ] {
373+ & [ ]
374+ }
375+ }
0 commit comments