Skip to content

Commit 746988a

Browse files
authored
Implement ScalarUDF in terms of ScalarUDFImpl trait (#8713)
1 parent cc42894 commit 746988a

File tree

6 files changed

+171
-75
lines changed

6 files changed

+171
-75
lines changed

datafusion-examples/examples/advanced_udf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use std::sync::Arc;
4040
/// the power of the second argument `a^b`.
4141
///
4242
/// To do so, we must implement the `ScalarUDFImpl` trait.
43+
#[derive(Debug, Clone)]
4344
struct PowUdf {
4445
signature: Signature,
4546
aliases: Vec<String>,

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,6 +1948,7 @@ mod test {
19481948
);
19491949

19501950
// UDF
1951+
#[derive(Debug)]
19511952
struct TestScalarUDF {
19521953
signature: Signature,
19531954
}

datafusion/expr/src/expr_fn.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,16 @@ pub struct SimpleScalarUDF {
984984
fun: ScalarFunctionImplementation,
985985
}
986986

987+
impl Debug for SimpleScalarUDF {
988+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
989+
f.debug_struct("ScalarUDF")
990+
.field("name", &self.name)
991+
.field("signature", &self.signature)
992+
.field("fun", &"<FUNC>")
993+
.finish()
994+
}
995+
}
996+
987997
impl SimpleScalarUDF {
988998
/// Create a new `SimpleScalarUDF` from a name, input types, return type and
989999
/// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility

datafusion/expr/src/udf.rs

Lines changed: 157 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -35,57 +35,35 @@ use std::sync::Arc;
3535
/// functions you supply such name, type signature, return type, and actual
3636
/// implementation.
3737
///
38-
///
3938
/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
4039
///
4140
/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
4241
///
42+
/// # API Note
43+
///
44+
/// This is a separate struct from `ScalarUDFImpl` to maintain backwards
45+
/// compatibility with the older API.
46+
///
4347
/// [`create_udf`]: crate::expr_fn::create_udf
4448
/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
4549
/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
46-
#[derive(Clone)]
50+
#[derive(Debug, Clone)]
4751
pub struct ScalarUDF {
48-
/// The name of the function
49-
name: String,
50-
/// The signature (the types of arguments that are supported)
51-
signature: Signature,
52-
/// Function that returns the return type given the argument types
53-
return_type: ReturnTypeFunction,
54-
/// actual implementation
55-
///
56-
/// The fn param is the wrapped function but be aware that the function will
57-
/// be passed with the slice / vec of columnar values (either scalar or array)
58-
/// with the exception of zero param function, where a singular element vec
59-
/// will be passed. In that case the single element is a null array to indicate
60-
/// the batch's row count (so that the generative zero-argument function can know
61-
/// the result array size).
62-
fun: ScalarFunctionImplementation,
63-
/// Optional aliases for the function. This list should NOT include the value of `name` as well
64-
aliases: Vec<String>,
65-
}
66-
67-
impl Debug for ScalarUDF {
68-
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
69-
f.debug_struct("ScalarUDF")
70-
.field("name", &self.name)
71-
.field("signature", &self.signature)
72-
.field("fun", &"<FUNC>")
73-
.finish()
74-
}
52+
inner: Arc<dyn ScalarUDFImpl>,
7553
}
7654

7755
impl PartialEq for ScalarUDF {
7856
fn eq(&self, other: &Self) -> bool {
79-
self.name == other.name && self.signature == other.signature
57+
self.name() == other.name() && self.signature() == other.signature()
8058
}
8159
}
8260

8361
impl Eq for ScalarUDF {}
8462

8563
impl std::hash::Hash for ScalarUDF {
8664
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
87-
self.name.hash(state);
88-
self.signature.hash(state);
65+
self.name().hash(state);
66+
self.signature().hash(state);
8967
}
9068
}
9169

@@ -101,51 +79,37 @@ impl ScalarUDF {
10179
return_type: &ReturnTypeFunction,
10280
fun: &ScalarFunctionImplementation,
10381
) -> Self {
104-
Self {
82+
Self::new_from_impl(ScalarUdfLegacyWrapper {
10583
name: name.to_owned(),
10684
signature: signature.clone(),
10785
return_type: return_type.clone(),
10886
fun: fun.clone(),
109-
aliases: vec![],
110-
}
87+
})
11188
}
11289

11390
/// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
11491
///
11592
/// Note this is the same as using the `From` impl (`ScalarUDF::from`)
11693
pub fn new_from_impl<F>(fun: F) -> ScalarUDF
11794
where
118-
F: ScalarUDFImpl + Send + Sync + 'static,
95+
F: ScalarUDFImpl + 'static,
11996
{
120-
// TODO change the internal implementation to use the trait object
121-
let arc_fun = Arc::new(fun);
122-
let captured_self = arc_fun.clone();
123-
let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
124-
let return_type = captured_self.return_type(arg_types)?;
125-
Ok(Arc::new(return_type))
126-
});
127-
128-
let captured_self = arc_fun.clone();
129-
let func: ScalarFunctionImplementation =
130-
Arc::new(move |args| captured_self.invoke(args));
131-
13297
Self {
133-
name: arc_fun.name().to_string(),
134-
signature: arc_fun.signature().clone(),
135-
return_type: return_type.clone(),
136-
fun: func,
137-
aliases: arc_fun.aliases().to_vec(),
98+
inner: Arc::new(fun),
13899
}
139100
}
140101

141-
/// Adds additional names that can be used to invoke this function, in addition to `name`
142-
pub fn with_aliases(
143-
mut self,
144-
aliases: impl IntoIterator<Item = &'static str>,
145-
) -> Self {
146-
self.aliases
147-
.extend(aliases.into_iter().map(|s| s.to_string()));
148-
self
102+
/// Return the underlying [`ScalarUDFImpl`] trait object for this function
103+
pub fn inner(&self) -> Arc<dyn ScalarUDFImpl> {
104+
self.inner.clone()
105+
}
106+
107+
/// Adds additional names that can be used to invoke this function, in
108+
/// addition to `name`
109+
///
110+
/// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
111+
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
112+
Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases))
149113
}
150114

151115
/// Returns a [`Expr`] logical expression to call this UDF with specified
@@ -159,31 +123,46 @@ impl ScalarUDF {
159123
))
160124
}
161125

162-
/// Returns this function's name
126+
/// Returns this function's name.
127+
///
128+
/// See [`ScalarUDFImpl::name`] for more details.
163129
pub fn name(&self) -> &str {
164-
&self.name
130+
self.inner.name()
165131
}
166132

167-
/// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
133+
/// Returns the aliases for this function.
134+
///
135+
/// See [`ScalarUDF::with_aliases`] for more details
168136
pub fn aliases(&self) -> &[String] {
169-
&self.aliases
137+
self.inner.aliases()
170138
}
171139

172-
/// Returns this function's [`Signature`] (what input types are accepted)
140+
/// Returns this function's [`Signature`] (what input types are accepted).
141+
///
142+
/// See [`ScalarUDFImpl::signature`] for more details.
173143
pub fn signature(&self) -> &Signature {
174-
&self.signature
144+
self.inner.signature()
175145
}
176146

177-
/// The datatype this function returns given the input argument input types
147+
/// The datatype this function returns given the input argument input types.
148+
///
149+
/// See [`ScalarUDFImpl::return_type`] for more details.
178150
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
179-
// Old API returns an Arc of the datatype for some reason
180-
let res = (self.return_type)(args)?;
181-
Ok(res.as_ref().clone())
151+
self.inner.return_type(args)
152+
}
153+
154+
/// Invoke the function on `args`, returning the appropriate result.
155+
///
156+
/// See [`ScalarUDFImpl::invoke`] for more details.
157+
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
158+
self.inner.invoke(args)
182159
}
183160

184-
/// Return an [`Arc`] to the function implementation
161+
/// Returns a `ScalarFunctionImplementation` that can invoke the function
162+
/// during execution
185163
pub fn fun(&self) -> ScalarFunctionImplementation {
186-
self.fun.clone()
164+
let captured = self.inner.clone();
165+
Arc::new(move |args| captured.invoke(args))
187166
}
188167
}
189168

@@ -213,6 +192,7 @@ where
213192
/// # use datafusion_common::{DataFusionError, plan_err, Result};
214193
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
215194
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
195+
/// #[derive(Debug)]
216196
/// struct AddOne {
217197
/// signature: Signature
218198
/// };
@@ -246,7 +226,7 @@ where
246226
/// // Call the function `add_one(col)`
247227
/// let expr = add_one.call(vec![col("a")]);
248228
/// ```
249-
pub trait ScalarUDFImpl {
229+
pub trait ScalarUDFImpl: Debug + Send + Sync {
250230
/// Returns this object as an [`Any`] trait object
251231
fn as_any(&self) -> &dyn Any;
252232

@@ -292,3 +272,106 @@ pub trait ScalarUDFImpl {
292272
&[]
293273
}
294274
}
275+
276+
/// ScalarUDF that adds an alias to the underlying function. It is better to
277+
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
278+
#[derive(Debug)]
279+
struct AliasedScalarUDFImpl {
280+
inner: ScalarUDF,
281+
aliases: Vec<String>,
282+
}
283+
284+
impl AliasedScalarUDFImpl {
285+
pub fn new(
286+
inner: ScalarUDF,
287+
new_aliases: impl IntoIterator<Item = &'static str>,
288+
) -> Self {
289+
let mut aliases = inner.aliases().to_vec();
290+
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
291+
292+
Self { inner, aliases }
293+
}
294+
}
295+
296+
impl ScalarUDFImpl for AliasedScalarUDFImpl {
297+
fn as_any(&self) -> &dyn Any {
298+
self
299+
}
300+
fn name(&self) -> &str {
301+
self.inner.name()
302+
}
303+
304+
fn signature(&self) -> &Signature {
305+
self.inner.signature()
306+
}
307+
308+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
309+
self.inner.return_type(arg_types)
310+
}
311+
312+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
313+
self.inner.invoke(args)
314+
}
315+
316+
fn aliases(&self) -> &[String] {
317+
&self.aliases
318+
}
319+
}
320+
321+
/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers
322+
/// of the older API (see <https://github.com/apache/arrow-datafusion/pull/8578>
323+
/// for more details)
324+
struct ScalarUdfLegacyWrapper {
325+
/// The name of the function
326+
name: String,
327+
/// The signature (the types of arguments that are supported)
328+
signature: Signature,
329+
/// Function that returns the return type given the argument types
330+
return_type: ReturnTypeFunction,
331+
/// actual implementation
332+
///
333+
/// The fn param is the wrapped function but be aware that the function will
334+
/// be passed with the slice / vec of columnar values (either scalar or array)
335+
/// with the exception of zero param function, where a singular element vec
336+
/// will be passed. In that case the single element is a null array to indicate
337+
/// the batch's row count (so that the generative zero-argument function can know
338+
/// the result array size).
339+
fun: ScalarFunctionImplementation,
340+
}
341+
342+
impl Debug for ScalarUdfLegacyWrapper {
343+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
344+
f.debug_struct("ScalarUDF")
345+
.field("name", &self.name)
346+
.field("signature", &self.signature)
347+
.field("fun", &"<FUNC>")
348+
.finish()
349+
}
350+
}
351+
352+
impl ScalarUDFImpl for ScalarUdfLegacyWrapper {
353+
fn as_any(&self) -> &dyn Any {
354+
self
355+
}
356+
fn name(&self) -> &str {
357+
&self.name
358+
}
359+
360+
fn signature(&self) -> &Signature {
361+
&self.signature
362+
}
363+
364+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
365+
// Old API returns an Arc of the datatype for some reason
366+
let res = (self.return_type)(arg_types)?;
367+
Ok(res.as_ref().clone())
368+
}
369+
370+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
371+
(self.fun)(args)
372+
}
373+
374+
fn aliases(&self) -> &[String] {
375+
&[]
376+
}
377+
}

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,7 @@ mod test {
811811

812812
static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new();
813813

814+
#[derive(Debug, Clone, Default)]
814815
struct TestScalarUDF {}
815816
impl ScalarUDFImpl for TestScalarUDF {
816817
fn as_any(&self) -> &dyn Any {

datafusion/physical-expr/src/udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ pub fn create_physical_expr(
3636

3737
Ok(Arc::new(ScalarFunctionExpr::new(
3838
fun.name(),
39-
fun.fun().clone(),
39+
fun.fun(),
4040
input_phy_exprs.to_vec(),
4141
fun.return_type(&input_exprs_types)?,
4242
None,

0 commit comments

Comments
 (0)