@@ -35,57 +35,35 @@ use std::sync::Arc;
3535/// functions you supply such as name, type signature, return type, and actual
3636/// implementation.
3737///
38- ///
3938/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
4039///
4140/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
4241///
42+ /// # API Note
43+ ///
44+ /// This is a separate struct from `ScalarUDFImpl` to maintain backwards
45+ /// compatibility with the older API.
46+ ///
4347/// [`create_udf`]: crate::expr_fn::create_udf
4448/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
4549/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
46- #[ derive( Clone ) ]
50+ #[ derive( Debug , Clone ) ]
4751pub struct ScalarUDF {
48- /// The name of the function
49- name : String ,
50- /// The signature (the types of arguments that are supported)
51- signature : Signature ,
52- /// Function that returns the return type given the argument types
53- return_type : ReturnTypeFunction ,
54- /// actual implementation
55- ///
56- /// The fn param is the wrapped function but be aware that the function will
57- /// be passed with the slice / vec of columnar values (either scalar or array)
58- /// with the exception of zero param function, where a singular element vec
59- /// will be passed. In that case the single element is a null array to indicate
60- /// the batch's row count (so that the generative zero-argument function can know
61- /// the result array size).
62- fun : ScalarFunctionImplementation ,
63- /// Optional aliases for the function. This list should NOT include the value of `name` as well
64- aliases : Vec < String > ,
65- }
66-
67- impl Debug for ScalarUDF {
68- fn fmt ( & self , f : & mut Formatter ) -> fmt:: Result {
69- f. debug_struct ( "ScalarUDF" )
70- . field ( "name" , & self . name )
71- . field ( "signature" , & self . signature )
72- . field ( "fun" , & "<FUNC>" )
73- . finish ( )
74- }
52+ inner : Arc < dyn ScalarUDFImpl > ,
7553}
7654
7755impl PartialEq for ScalarUDF {
7856 fn eq ( & self , other : & Self ) -> bool {
79- self . name == other. name && self . signature == other. signature
57+ self . name ( ) == other. name ( ) && self . signature ( ) == other. signature ( )
8058 }
8159}
8260
8361impl Eq for ScalarUDF { }
8462
8563impl std:: hash:: Hash for ScalarUDF {
8664 fn hash < H : std:: hash:: Hasher > ( & self , state : & mut H ) {
87- self . name . hash ( state) ;
88- self . signature . hash ( state) ;
65+ self . name ( ) . hash ( state) ;
66+ self . signature ( ) . hash ( state) ;
8967 }
9068}
9169
@@ -101,51 +79,37 @@ impl ScalarUDF {
10179 return_type : & ReturnTypeFunction ,
10280 fun : & ScalarFunctionImplementation ,
10381 ) -> Self {
104- Self {
82+ Self :: new_from_impl ( ScalarUdfLegacyWrapper {
10583 name : name. to_owned ( ) ,
10684 signature : signature. clone ( ) ,
10785 return_type : return_type. clone ( ) ,
10886 fun : fun. clone ( ) ,
109- aliases : vec ! [ ] ,
110- }
87+ } )
11188 }
11289
11390 /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
11491 ///
11592 /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
11693 pub fn new_from_impl < F > ( fun : F ) -> ScalarUDF
11794 where
118- F : ScalarUDFImpl + Send + Sync + ' static ,
95+ F : ScalarUDFImpl + ' static ,
11996 {
120- // TODO change the internal implementation to use the trait object
121- let arc_fun = Arc :: new ( fun) ;
122- let captured_self = arc_fun. clone ( ) ;
123- let return_type: ReturnTypeFunction = Arc :: new ( move |arg_types| {
124- let return_type = captured_self. return_type ( arg_types) ?;
125- Ok ( Arc :: new ( return_type) )
126- } ) ;
127-
128- let captured_self = arc_fun. clone ( ) ;
129- let func: ScalarFunctionImplementation =
130- Arc :: new ( move |args| captured_self. invoke ( args) ) ;
131-
13297 Self {
133- name : arc_fun. name ( ) . to_string ( ) ,
134- signature : arc_fun. signature ( ) . clone ( ) ,
135- return_type : return_type. clone ( ) ,
136- fun : func,
137- aliases : arc_fun. aliases ( ) . to_vec ( ) ,
98+ inner : Arc :: new ( fun) ,
13899 }
139100 }
140101
141- /// Adds additional names that can be used to invoke this function, in addition to `name`
142- pub fn with_aliases (
143- mut self ,
144- aliases : impl IntoIterator < Item = & ' static str > ,
145- ) -> Self {
146- self . aliases
147- . extend ( aliases. into_iter ( ) . map ( |s| s. to_string ( ) ) ) ;
148- self
102+ /// Return the underlying [`ScalarUDFImpl`] trait object for this function
103+ pub fn inner ( & self ) -> Arc < dyn ScalarUDFImpl > {
104+ self . inner . clone ( )
105+ }
106+
107+ /// Adds additional names that can be used to invoke this function, in
108+ /// addition to `name`
109+ ///
110+ /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
111+ pub fn with_aliases ( self , aliases : impl IntoIterator < Item = & ' static str > ) -> Self {
112+ Self :: new_from_impl ( AliasedScalarUDFImpl :: new ( self , aliases) )
149113 }
150114
151115 /// Returns a [`Expr`] logical expression to call this UDF with specified
@@ -159,31 +123,46 @@ impl ScalarUDF {
159123 ) )
160124 }
161125
162- /// Returns this function's name
126+ /// Returns this function's name.
127+ ///
128+ /// See [`ScalarUDFImpl::name`] for more details.
163129 pub fn name ( & self ) -> & str {
164- & self . name
130+ self . inner . name ( )
165131 }
166132
167- /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
133+ /// Returns the aliases for this function.
134+ ///
135+ /// See [`ScalarUDF::with_aliases`] for more details
168136 pub fn aliases ( & self ) -> & [ String ] {
169- & self . aliases
137+ self . inner . aliases ( )
170138 }
171139
172- /// Returns this function's [`Signature`] (what input types are accepted)
140+ /// Returns this function's [`Signature`] (what input types are accepted).
141+ ///
142+ /// See [`ScalarUDFImpl::signature`] for more details.
173143 pub fn signature ( & self ) -> & Signature {
174- & self . signature
144+ self . inner . signature ( )
175145 }
176146
177- /// The datatype this function returns given the input argument input types
147+ /// The datatype this function returns given the input argument input types.
148+ ///
149+ /// See [`ScalarUDFImpl::return_type`] for more details.
178150 pub fn return_type ( & self , args : & [ DataType ] ) -> Result < DataType > {
179- // Old API returns an Arc of the datatype for some reason
180- let res = ( self . return_type ) ( args) ?;
181- Ok ( res. as_ref ( ) . clone ( ) )
151+ self . inner . return_type ( args)
152+ }
153+
154+ /// Invoke the function on `args`, returning the appropriate result.
155+ ///
156+ /// See [`ScalarUDFImpl::invoke`] for more details.
157+ pub fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
158+ self . inner . invoke ( args)
182159 }
183160
184- /// Return an [`Arc`] to the function implementation
161+ /// Returns a `ScalarFunctionImplementation` that can invoke the function
162+ /// during execution
185163 pub fn fun ( & self ) -> ScalarFunctionImplementation {
186- self . fun . clone ( )
164+ let captured = self . inner . clone ( ) ;
165+ Arc :: new ( move |args| captured. invoke ( args) )
187166 }
188167}
189168
@@ -213,6 +192,7 @@ where
213192/// # use datafusion_common::{DataFusionError, plan_err, Result};
214193/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
215194/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
195+ /// #[derive(Debug)]
216196/// struct AddOne {
217197/// signature: Signature
218198/// };
@@ -246,7 +226,7 @@ where
246226/// // Call the function `add_one(col)`
247227/// let expr = add_one.call(vec![col("a")]);
248228/// ```
249- pub trait ScalarUDFImpl {
229+ pub trait ScalarUDFImpl : Debug + Send + Sync {
250230 /// Returns this object as an [`Any`] trait object
251231 fn as_any ( & self ) -> & dyn Any ;
252232
@@ -292,3 +272,106 @@ pub trait ScalarUDFImpl {
292272 & [ ]
293273 }
294274}
275+
276+ /// ScalarUDF that adds an alias to the underlying function. It is better to
277+ /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
278+ #[ derive( Debug ) ]
279+ struct AliasedScalarUDFImpl {
280+ inner : ScalarUDF ,
281+ aliases : Vec < String > ,
282+ }
283+
284+ impl AliasedScalarUDFImpl {
285+ pub fn new (
286+ inner : ScalarUDF ,
287+ new_aliases : impl IntoIterator < Item = & ' static str > ,
288+ ) -> Self {
289+ let mut aliases = inner. aliases ( ) . to_vec ( ) ;
290+ aliases. extend ( new_aliases. into_iter ( ) . map ( |s| s. to_string ( ) ) ) ;
291+
292+ Self { inner, aliases }
293+ }
294+ }
295+
296+ impl ScalarUDFImpl for AliasedScalarUDFImpl {
297+ fn as_any ( & self ) -> & dyn Any {
298+ self
299+ }
300+ fn name ( & self ) -> & str {
301+ self . inner . name ( )
302+ }
303+
304+ fn signature ( & self ) -> & Signature {
305+ self . inner . signature ( )
306+ }
307+
308+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
309+ self . inner . return_type ( arg_types)
310+ }
311+
312+ fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
313+ self . inner . invoke ( args)
314+ }
315+
316+ fn aliases ( & self ) -> & [ String ] {
317+ & self . aliases
318+ }
319+ }
320+
321+ /// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers
322+ /// of the older API (see <https://github.com/apache/arrow-datafusion/pull/8578>
323+ /// for more details)
324+ struct ScalarUdfLegacyWrapper {
325+ /// The name of the function
326+ name : String ,
327+ /// The signature (the types of arguments that are supported)
328+ signature : Signature ,
329+ /// Function that returns the return type given the argument types
330+ return_type : ReturnTypeFunction ,
331+ /// actual implementation
332+ ///
333+ /// The fn param is the wrapped function but be aware that the function will
334+ /// be passed with the slice / vec of columnar values (either scalar or array)
335+ /// with the exception of zero param function, where a singular element vec
336+ /// will be passed. In that case the single element is a null array to indicate
337+ /// the batch's row count (so that the generative zero-argument function can know
338+ /// the result array size).
339+ fun : ScalarFunctionImplementation ,
340+ }
341+
342+ impl Debug for ScalarUdfLegacyWrapper {
343+ fn fmt ( & self , f : & mut Formatter ) -> fmt:: Result {
344+ f. debug_struct ( "ScalarUDF" )
345+ . field ( "name" , & self . name )
346+ . field ( "signature" , & self . signature )
347+ . field ( "fun" , & "<FUNC>" )
348+ . finish ( )
349+ }
350+ }
351+
352+ impl ScalarUDFImpl for ScalarUdfLegacyWrapper {
353+ fn as_any ( & self ) -> & dyn Any {
354+ self
355+ }
356+ fn name ( & self ) -> & str {
357+ & self . name
358+ }
359+
360+ fn signature ( & self ) -> & Signature {
361+ & self . signature
362+ }
363+
364+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
365+ // Old API returns an Arc of the datatype for some reason
366+ let res = ( self . return_type ) ( arg_types) ?;
367+ Ok ( res. as_ref ( ) . clone ( ) )
368+ }
369+
370+ fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
371+ ( self . fun ) ( args)
372+ }
373+
374+ fn aliases ( & self ) -> & [ String ] {
375+ & [ ]
376+ }
377+ }
0 commit comments