Skip to content

Commit 1334667

Browse files
committed
Implement ScalarUDF in terms of ScalarUDFImpl trait
1 parent d2b3d1c commit 1334667

File tree

6 files changed

+170
-75
lines changed

6 files changed

+170
-75
lines changed

datafusion-examples/examples/advanced_udf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use std::sync::Arc;
4040
/// the power of the second argument `a^b`.
4141
///
4242
/// To do so, we must implement the `ScalarUDFImpl` trait.
43+
#[derive(Debug, Clone)]
4344
struct PowUdf {
4445
signature: Signature,
4546
aliases: Vec<String>,

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,6 +1948,7 @@ mod test {
19481948
);
19491949

19501950
// UDF
1951+
#[derive(Debug)]
19511952
struct TestScalarUDF {
19521953
signature: Signature,
19531954
}

datafusion/expr/src/expr_fn.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF};
3232
use arrow::datatypes::DataType;
3333
use datafusion_common::{Column, Result};
3434
use std::any::Any;
35+
use std::fmt::Debug;
3536
use std::ops::Not;
3637
use std::sync::Arc;
3738

@@ -983,6 +984,16 @@ pub struct SimpleScalarUDF {
983984
fun: ScalarFunctionImplementation,
984985
}
985986

987+
impl Debug for SimpleScalarUDF {
988+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
989+
f.debug_struct("ScalarUDF")
990+
.field("name", &self.name)
991+
.field("signature", &self.signature)
992+
.field("fun", &"<FUNC>")
993+
.finish()
994+
}
995+
}
996+
986997
impl SimpleScalarUDF {
987998
/// Create a new `SimpleScalarUDF` from a name, input types, return type and
988999
/// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility

datafusion/expr/src/udf.rs

Lines changed: 155 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -35,57 +35,35 @@ use std::sync::Arc;
3535
/// functions you supply, such as the name, type signature, return type, and actual
3636
/// implementation.
3737
///
38-
///
3938
/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
4039
///
4140
/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
4241
///
42+
/// # API Note
43+
///
44+
/// This is a separate struct from `ScalarUDFImpl` to maintain backwards
45+
/// compatibility with the older API.
46+
///
4347
/// [`create_udf`]: crate::expr_fn::create_udf
4448
/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
4549
/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
46-
#[derive(Clone)]
50+
#[derive(Debug, Clone)]
4751
pub struct ScalarUDF {
48-
/// The name of the function
49-
name: String,
50-
/// The signature (the types of arguments that are supported)
51-
signature: Signature,
52-
/// Function that returns the return type given the argument types
53-
return_type: ReturnTypeFunction,
54-
/// actual implementation
55-
///
56-
/// The fn param is the wrapped function but be aware that the function will
57-
/// be passed with the slice / vec of columnar values (either scalar or array)
58-
/// with the exception of zero param function, where a singular element vec
59-
/// will be passed. In that case the single element is a null array to indicate
60-
/// the batch's row count (so that the generative zero-argument function can know
61-
/// the result array size).
62-
fun: ScalarFunctionImplementation,
63-
/// Optional aliases for the function. This list should NOT include the value of `name` as well
64-
aliases: Vec<String>,
65-
}
66-
67-
impl Debug for ScalarUDF {
68-
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
69-
f.debug_struct("ScalarUDF")
70-
.field("name", &self.name)
71-
.field("signature", &self.signature)
72-
.field("fun", &"<FUNC>")
73-
.finish()
74-
}
52+
inner: Arc<dyn ScalarUDFImpl>,
7553
}
7654

7755
impl PartialEq for ScalarUDF {
7856
fn eq(&self, other: &Self) -> bool {
79-
self.name == other.name && self.signature == other.signature
57+
self.name() == other.name() && self.signature() == other.signature()
8058
}
8159
}
8260

8361
impl Eq for ScalarUDF {}
8462

8563
impl std::hash::Hash for ScalarUDF {
8664
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
87-
self.name.hash(state);
88-
self.signature.hash(state);
65+
self.name().hash(state);
66+
self.signature().hash(state);
8967
}
9068
}
9169

@@ -101,51 +79,37 @@ impl ScalarUDF {
10179
return_type: &ReturnTypeFunction,
10280
fun: &ScalarFunctionImplementation,
10381
) -> Self {
104-
Self {
82+
Self::new_from_impl(ScalarUdfLegacyWrapper {
10583
name: name.to_owned(),
10684
signature: signature.clone(),
10785
return_type: return_type.clone(),
10886
fun: fun.clone(),
109-
aliases: vec![],
110-
}
87+
})
11188
}
11289

11390
/// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] trait object
11491
///
11592
/// Note this is the same as using the `From` impl (`ScalarUDF::from`)
11693
pub fn new_from_impl<F>(fun: F) -> ScalarUDF
11794
where
118-
F: ScalarUDFImpl + Send + Sync + 'static,
95+
F: ScalarUDFImpl + 'static,
11996
{
120-
// TODO change the internal implementation to use the trait object
121-
let arc_fun = Arc::new(fun);
122-
let captured_self = arc_fun.clone();
123-
let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
124-
let return_type = captured_self.return_type(arg_types)?;
125-
Ok(Arc::new(return_type))
126-
});
127-
128-
let captured_self = arc_fun.clone();
129-
let func: ScalarFunctionImplementation =
130-
Arc::new(move |args| captured_self.invoke(args));
131-
13297
Self {
133-
name: arc_fun.name().to_string(),
134-
signature: arc_fun.signature().clone(),
135-
return_type: return_type.clone(),
136-
fun: func,
137-
aliases: arc_fun.aliases().to_vec(),
98+
inner: Arc::new(fun),
13899
}
139100
}
140101

141-
/// Adds additional names that can be used to invoke this function, in addition to `name`
142-
pub fn with_aliases(
143-
mut self,
144-
aliases: impl IntoIterator<Item = &'static str>,
145-
) -> Self {
146-
self.aliases
147-
.extend(aliases.into_iter().map(|s| s.to_string()));
148-
self
102+
/// Return the underlying [`ScalarUDFImpl`] trait object for this function
103+
pub fn inner(&self) -> Arc<dyn ScalarUDFImpl> {
104+
self.inner.clone()
105+
}
106+
107+
/// Adds additional names that can be used to invoke this function, in
108+
/// addition to `name`
109+
///
110+
/// If you implement [`ScalarUDFImpl`] directly, return the aliases from its `aliases` method instead.
111+
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
112+
Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases))
149113
}
150114

151115
/// Returns a [`Expr`] logical expression to call this UDF with specified
@@ -159,31 +123,46 @@ impl ScalarUDF {
159123
))
160124
}
161125

162-
/// Returns this function's name
126+
/// Returns this function's name.
127+
///
128+
/// See [`ScalarUDFImpl::name`] for more details.
163129
pub fn name(&self) -> &str {
164-
&self.name
130+
self.inner.name()
165131
}
166132

167-
/// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
133+
/// Returns the aliases for this function.
134+
///
135+
/// See [`ScalarUDF::with_aliases`] for more details.
168136
pub fn aliases(&self) -> &[String] {
169-
&self.aliases
137+
self.inner.aliases()
170138
}
171139

172-
/// Returns this function's [`Signature`] (what input types are accepted)
140+
/// Returns this function's [`Signature`] (what input types are accepted).
141+
///
142+
/// See [`ScalarUDFImpl::signature`] for more details.
173143
pub fn signature(&self) -> &Signature {
174-
&self.signature
144+
self.inner.signature()
175145
}
176146

177-
/// The datatype this function returns given the input argument input types
147+
/// The datatype this function returns given the input argument types.
148+
///
149+
/// See [`ScalarUDFImpl::return_type`] for more details.
178150
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
179-
// Old API returns an Arc of the datatype for some reason
180-
let res = (self.return_type)(args)?;
181-
Ok(res.as_ref().clone())
151+
self.inner.return_type(args)
152+
}
153+
154+
/// Invoke the function on `args`, returning the appropriate result.
155+
///
156+
/// See [`ScalarUDFImpl::invoke`] for more details.
157+
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
158+
self.inner.invoke(args)
182159
}
183160

184-
/// Return an [`Arc`] to the function implementation
161+
/// Returns a `ScalarFunctionImplementation` that can invoke the function
162+
/// during execution
185163
pub fn fun(&self) -> ScalarFunctionImplementation {
186-
self.fun.clone()
164+
let captured = self.inner.clone();
165+
Arc::new(move |args| captured.invoke(args))
187166
}
188167
}
189168

@@ -246,7 +225,7 @@ where
246225
/// // Call the function `add_one(col)`
247226
/// let expr = add_one.call(vec![col("a")]);
248227
/// ```
249-
pub trait ScalarUDFImpl {
228+
pub trait ScalarUDFImpl: Debug + Send + Sync {
250229
/// Returns this object as an [`Any`] trait object
251230
fn as_any(&self) -> &dyn Any;
252231

@@ -292,3 +271,105 @@ pub trait ScalarUDFImpl {
292271
&[]
293272
}
294273
}
274+
275+
/// [`ScalarUDFImpl`] that adds aliases to an underlying [`ScalarUDF`]. It is better to
276+
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
277+
#[derive(Debug)]
278+
struct AliasedScalarUDFImpl {
279+
inner: ScalarUDF,
280+
aliases: Vec<String>,
281+
}
282+
283+
impl AliasedScalarUDFImpl {
284+
pub fn new(
285+
inner: ScalarUDF,
286+
new_aliases: impl IntoIterator<Item = &'static str>,
287+
) -> Self {
288+
let mut aliases = inner.aliases().to_vec();
289+
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
290+
291+
Self { inner, aliases }
292+
}
293+
}
294+
295+
impl ScalarUDFImpl for AliasedScalarUDFImpl {
296+
fn as_any(&self) -> &dyn Any {
297+
self
298+
}
299+
fn name(&self) -> &str {
300+
self.inner.name()
301+
}
302+
303+
fn signature(&self) -> &Signature {
304+
self.inner.signature()
305+
}
306+
307+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
308+
self.inner.return_type(arg_types)
309+
}
310+
311+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
312+
self.inner.invoke(args)
313+
}
314+
315+
fn aliases(&self) -> &[String] {
316+
&self.aliases
317+
}
318+
}
319+
320+
/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers of the older API
321+
/// (see <https://github.com/apache/arrow-datafusion/pull/8578>)
322+
struct ScalarUdfLegacyWrapper {
323+
/// The name of the function
324+
name: String,
325+
/// The signature (the types of arguments that are supported)
326+
signature: Signature,
327+
/// Function that returns the return type given the argument types
328+
return_type: ReturnTypeFunction,
329+
/// actual implementation
330+
///
331+
/// The fn param is the wrapped function but be aware that the function will
332+
/// be passed with the slice / vec of columnar values (either scalar or array)
333+
/// with the exception of zero param function, where a singular element vec
334+
/// will be passed. In that case the single element is a null array to indicate
335+
/// the batch's row count (so that the generative zero-argument function can know
336+
/// the result array size).
337+
fun: ScalarFunctionImplementation,
338+
}
339+
340+
impl Debug for ScalarUdfLegacyWrapper {
341+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
342+
f.debug_struct("ScalarUDF")
343+
.field("name", &self.name)
344+
.field("signature", &self.signature)
345+
.field("fun", &"<FUNC>")
346+
.finish()
347+
}
348+
}
349+
350+
impl ScalarUDFImpl for ScalarUdfLegacyWrapper {
351+
fn as_any(&self) -> &dyn Any {
352+
self
353+
}
354+
fn name(&self) -> &str {
355+
&self.name
356+
}
357+
358+
fn signature(&self) -> &Signature {
359+
&self.signature
360+
}
361+
362+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
363+
// Old API returns an Arc of the datatype for some reason
364+
let res = (self.return_type)(arg_types)?;
365+
Ok(res.as_ref().clone())
366+
}
367+
368+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
369+
(self.fun)(args)
370+
}
371+
372+
fn aliases(&self) -> &[String] {
373+
&[]
374+
}
375+
}

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,7 @@ mod test {
811811

812812
static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new();
813813

814+
#[derive(Debug, Clone, Default)]
814815
struct TestScalarUDF {}
815816
impl ScalarUDFImpl for TestScalarUDF {
816817
fn as_any(&self) -> &dyn Any {

datafusion/physical-expr/src/udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ pub fn create_physical_expr(
3636

3737
Ok(Arc::new(ScalarFunctionExpr::new(
3838
fun.name(),
39-
fun.fun().clone(),
39+
fun.fun(),
4040
input_phy_exprs.to_vec(),
4141
fun.return_type(&input_exprs_types)?,
4242
None,

0 commit comments

Comments
 (0)