Skip to content

Commit c689980

Browse files
committed
refactor api
Signed-off-by: jayzhan211 <[email protected]>
1 parent f25c1df commit c689980

File tree

7 files changed

+78
-33
lines changed

7 files changed

+78
-33
lines changed

datafusion-examples/examples/udaf_expr.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use datafusion::{
2424
};
2525

2626
use datafusion_common::Result;
27-
use datafusion_expr::col;
27+
use datafusion_expr::{col, AggregateUDFExprBuilder};
2828

2929
#[tokio::main]
3030
async fn main() -> Result<()> {
@@ -33,11 +33,10 @@ async fn main() -> Result<()> {
3333
let mut state = SessionState::new_with_config_rt(config, ctx.runtime_env());
3434
let _ = register_all(&mut state);
3535

36-
let first_value_udaf = state.aggregate_functions().get("FIRST_VALUE").unwrap();
36+
let first_value_udaf = state.aggregate_functions().get("first_value").unwrap();
3737
let first_value_builder = first_value_udaf
3838
.call(vec![col("a")])
39-
.order_by(vec![col("b")])
40-
.build();
39+
.order_by(vec![col("b")]);
4140

4241
let first_value_fn = first_value(col("a"), Some(vec![col("b")]));
4342
assert_eq!(first_value_builder, first_value_fn);

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ pub use signature::{
8181
ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
8282
};
8383
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
84-
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
84+
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, AggregateUDFExprBuilder};
8585
pub use udf::{ScalarUDF, ScalarUDFImpl};
8686
pub use udwf::{WindowUDF, WindowUDFImpl};
8787
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

datafusion/expr/src/udaf.rs

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::{Accumulator, Expr};
2828
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
2929
use arrow::datatypes::{DataType, Field};
3030
use datafusion_common::{exec_err, not_impl_err, Result};
31+
use sqlparser::ast::NullTreatment;
3132
use std::any::Any;
3233
use std::fmt::{self, Debug, Formatter};
3334
use std::sync::Arc;
@@ -139,8 +140,15 @@ impl AggregateUDF {
139140
///
140141
/// This utility allows using the UDAF without requiring access to
141142
/// the registry, such as with the DataFrame API.
142-
pub fn call(&self, args: Vec<Expr>) -> AggregateFunction {
143-
AggregateFunction::new_udf(Arc::new(self.clone()), args, false, None, None, None)
143+
pub fn call(&self, args: Vec<Expr>) -> Expr {
144+
Expr::AggregateFunction(AggregateFunction::new_udf(
145+
Arc::new(self.clone()),
146+
args,
147+
false,
148+
None,
149+
None,
150+
None,
151+
))
144152
}
145153

146154
/// Returns this function's name
@@ -599,3 +607,49 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
599607
(self.accumulator)(acc_args)
600608
}
601609
}
610+
611+
/// Extension trait for setting the optional arguments of an aggregate
/// function call directly on an expression.
///
/// Each method consumes `self` and returns the updated `Expr`, so calls can
/// be chained. In the implementation for `Expr` in this file, only
/// `Expr::AggregateFunction` is modified; any other variant is returned
/// unchanged without reporting an error.
pub trait AggregateUDFExprBuilder {
612+
/// Sets the `ORDER BY` expressions of the aggregate call.
fn order_by(self, order_by: Vec<Expr>) -> Expr;
613+
/// Sets the filter expression of the aggregate call.
fn filter(self, filter: Box<Expr>) -> Expr;
614+
/// Sets the null treatment of the aggregate call.
fn null_treatment(self, null_treatment: NullTreatment) -> Expr;
615+
/// Marks the aggregate call as `DISTINCT`.
fn distinct(self) -> Expr;
616+
}
617+
618+
// Builder-style setters for `Expr`. Every method rewrites the matching field
// of an `Expr::AggregateFunction` and re-wraps it; any other `Expr` variant
// is passed through untouched (no error is raised).
impl AggregateUDFExprBuilder for Expr {
619+
/// Replaces the aggregate's `order_by` with `Some(order_by)`.
fn order_by(self, order_by: Vec<Expr>) -> Expr {
620+
match self {
621+
Expr::AggregateFunction(mut udaf) => {
622+
udaf.order_by = Some(order_by);
623+
Expr::AggregateFunction(udaf)
624+
}
625+
// Not an aggregate call: the argument is dropped and `self` is returned as-is.
_ => self,
626+
}
627+
}
628+
/// Replaces the aggregate's `filter` with `Some(filter)`.
fn filter(self, filter: Box<Expr>) -> Expr {
629+
match self {
630+
Expr::AggregateFunction(mut udaf) => {
631+
udaf.filter = Some(filter);
632+
Expr::AggregateFunction(udaf)
633+
}
634+
// Not an aggregate call: the argument is dropped and `self` is returned as-is.
_ => self,
635+
}
636+
}
637+
/// Replaces the aggregate's `null_treatment` with `Some(null_treatment)`.
fn null_treatment(self, null_treatment: NullTreatment) -> Expr {
638+
match self {
639+
Expr::AggregateFunction(mut udaf) => {
640+
udaf.null_treatment = Some(null_treatment);
641+
Expr::AggregateFunction(udaf)
642+
}
643+
// Not an aggregate call: the argument is dropped and `self` is returned as-is.
_ => self,
644+
}
645+
}
646+
/// Sets the aggregate's `distinct` flag to `true`.
fn distinct(self) -> Expr {
647+
match self {
648+
Expr::AggregateFunction(mut udaf) => {
649+
udaf.distinct = true;
650+
Expr::AggregateFunction(udaf)
651+
}
652+
// Not an aggregate call: `self` is returned as-is.
_ => self,
653+
}
654+
}
655+
}

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
2828
use datafusion_common::{
2929
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
3030
};
31-
use datafusion_expr::expr::AggregateFunction;
3231
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3332
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
3433
use datafusion_expr::{
35-
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature, TypeSignature,
36-
Volatility,
34+
Accumulator, AggregateUDFExprBuilder, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, TypeSignature, Volatility
3735
};
3836
use datafusion_physical_expr_common::aggregate::utils::get_sort_options;
3937
use datafusion_physical_expr_common::sort_expr::{
@@ -44,14 +42,11 @@ create_func!(FirstValue, first_value_udaf);
4442

4543
/// Returns the first value in a group of values.
4644
pub fn first_value(expression: Expr, order_by: Option<Vec<Expr>>) -> Expr {
47-
Expr::AggregateFunction(AggregateFunction::new_udf(
48-
first_value_udaf(),
49-
vec![expression],
50-
false,
51-
None,
52-
order_by,
53-
None,
54-
))
45+
if let Some(order_by) = order_by {
46+
first_value_udaf().call(vec![expression]).order_by(order_by)
47+
} else {
48+
first_value_udaf().call(vec![expression])
49+
}
5550
}
5651

5752
pub struct FirstValue {

datafusion/optimizer/src/replace_distinct_aggregate.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule};
2121

2222
use datafusion_common::tree_node::Transformed;
2323
use datafusion_common::{internal_err, Column, Result};
24-
use datafusion_expr::expr::AggregateFunction;
2524
use datafusion_expr::expr_rewriter::normalize_cols;
2625
use datafusion_expr::utils::expand_wildcard;
27-
use datafusion_expr::{col, LogicalPlanBuilder};
26+
use datafusion_expr::{col, AggregateUDFExprBuilder, LogicalPlanBuilder};
2827
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
2928

3029
/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
@@ -95,17 +94,14 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
9594
let expr_cnt = on_expr.len();
9695

9796
// Construct the aggregation expression to be used to fetch the selected expressions.
98-
let first_value_udaf =
97+
let first_value_udaf: std::sync::Arc<datafusion_expr::AggregateUDF> =
9998
config.function_registry().unwrap().udaf("first_value")?;
10099
let aggr_expr = select_expr.into_iter().map(|e| {
101-
Expr::AggregateFunction(AggregateFunction::new_udf(
102-
first_value_udaf.clone(),
103-
vec![e],
104-
false,
105-
None,
106-
sort_expr.clone(),
107-
None,
108-
))
100+
if let Some(order_by) = &sort_expr {
101+
first_value_udaf.call(vec![e]).order_by(order_by.clone())
102+
} else {
103+
first_value_udaf.call(vec![e])
104+
}
109105
});
110106

111107
let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,7 @@ async fn roundtrip_expr_api() -> Result<()> {
647647
lit(1),
648648
),
649649
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)),
650+
first_value(lit(1), None),
650651
first_value(lit(1), Some(vec![lit(2)])),
651652
covar_samp(lit(1.5), lit(2.2)),
652653
covar_pop(lit(1.5), lit(2.2)),

docs/source/user-guide/expressions.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,13 @@ select log(-1), log(0), sqrt(-1);
306306

307307
## Aggregate Function Builder
308308

309-
Another builder expression that ends with `build()`; it is useful if the function has multiple optional arguments.
309+
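Import the `AggregateUDFExprBuilder` trait to set an aggregate function's optional arguments (`order_by`, `filter`, `null_treatment`, `distinct`) directly on the `Expr` returned by `call`.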
310310

311311
See datafusion-examples/examples/udaf_expr.rs for example usage.
312312

313-
| Syntax | Equivalent to |
314-
| -------------------------------------------------------------- | ----------------------------------- |
315-
| first_value_udaf.call(vec![expr]).order_by(vec![expr]).build() | first_value(expr, Some(vec![expr])) |
313+
| Syntax | Equivalent to |
314+
| ------------------------------------------------------ | ----------------------------------- |
315+
| first_value_udaf.call(vec![expr]).order_by(vec![expr]) | first_value(expr, Some(vec![expr])) |
316316

317317
## Subquery Expressions
318318

0 commit comments

Comments
 (0)