Skip to content

Commit c419465

Browse files
goldmedal authored and findepi committed
Convert ApproxPercentileCont and ApproxPercentileContWithWeight to UDAF (apache#10917)
* pass logical expr of arguments for udaf
* implement approx_percentile_cont udaf
* register udaf
* remove ApproxPercentileCont
* convert with_weight to udaf and remove original
* fix conflict
* fix compile check
* fix doc and testing
* evaluate args through physical plan
* public use Literal
* fix tests
* rollback the experimental tests
* remove unused import
* rename args and inline code
* remove unnecessary partial eq trait
* fix error message
1 parent 7ce24ef commit c419465

File tree

39 files changed

+443
-714
lines changed

39 files changed

+443
-714
lines changed

datafusion/core/src/physical_optimizer/aggregate_statistics.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -390,6 +390,7 @@ pub(crate) mod tests {
390390
&[self.column()],
391391
&[],
392392
&[],
393+
&[],
393394
schema,
394395
self.column_name(),
395396
false,

datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -315,6 +315,7 @@ mod tests {
315315
&[expr],
316316
&[],
317317
&[],
318+
&[],
318319
schema,
319320
name,
320321
false,
@@ -404,6 +405,7 @@ mod tests {
404405
&[col("b", &schema)?],
405406
&[],
406407
&[],
408+
&[],
407409
&schema,
408410
"Sum(b)",
409411
false,

datafusion/core/src/physical_optimizer/test_utils.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -245,6 +245,7 @@ pub fn bounded_window_exec(
245245
"count".to_owned(),
246246
&[col(col_name, &schema).unwrap()],
247247
&[],
248+
&[],
248249
&sort_exprs,
249250
Arc::new(WindowFrame::new(Some(false))),
250251
schema.as_ref(),

datafusion/core/src/physical_planner.rs

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1766,7 +1766,8 @@ pub fn create_window_expr_with_name(
17661766
window_frame,
17671767
null_treatment,
17681768
}) => {
1769-
let args = create_physical_exprs(args, logical_schema, execution_props)?;
1769+
let physical_args =
1770+
create_physical_exprs(args, logical_schema, execution_props)?;
17701771
let partition_by =
17711772
create_physical_exprs(partition_by, logical_schema, execution_props)?;
17721773
let order_by =
@@ -1780,13 +1781,13 @@ pub fn create_window_expr_with_name(
17801781
}
17811782

17821783
let window_frame = Arc::new(window_frame.clone());
1783-
let ignore_nulls = null_treatment
1784-
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
1784+
let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
17851785
== NullTreatment::IgnoreNulls;
17861786
windows::create_window_expr(
17871787
fun,
17881788
name,
1789-
&args,
1789+
&physical_args,
1790+
args,
17901791
&partition_by,
17911792
&order_by,
17921793
window_frame,
@@ -1837,7 +1838,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18371838
order_by,
18381839
null_treatment,
18391840
}) => {
1840-
let args =
1841+
let physical_args =
18411842
create_physical_exprs(args, logical_input_schema, execution_props)?;
18421843
let filter = match filter {
18431844
Some(e) => Some(create_physical_expr(
@@ -1867,7 +1868,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18671868
let agg_expr = aggregates::create_aggregate_expr(
18681869
fun,
18691870
*distinct,
1870-
&args,
1871+
&physical_args,
18711872
&ordering_reqs,
18721873
physical_input_schema,
18731874
name,
@@ -1889,7 +1890,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18891890
physical_sort_exprs.clone().unwrap_or(vec![]);
18901891
let agg_expr = udaf::create_aggregate_expr(
18911892
fun,
1892-
&args,
1893+
&physical_args,
1894+
args,
18931895
&sort_exprs,
18941896
&ordering_reqs,
18951897
physical_input_schema,

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ use datafusion::assert_batches_eq;
3333
use datafusion_common::{DFSchema, ScalarValue};
3434
use datafusion_expr::expr::Alias;
3535
use datafusion_expr::ExprSchemable;
36-
use datafusion_functions_aggregate::expr_fn::approx_median;
36+
use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont};
3737

3838
fn test_schema() -> SchemaRef {
3939
Arc::new(Schema::new(vec![
@@ -363,7 +363,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
363363

364364
let expected = [
365365
"+---------------------------------------------+",
366-
"| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |",
366+
"| approx_percentile_cont(test.b,Float64(0.5)) |",
367367
"+---------------------------------------------+",
368368
"| 10 |",
369369
"+---------------------------------------------+",
@@ -384,7 +384,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
384384
let df = create_test_table().await?;
385385
let expected = [
386386
"+--------------------------------------+",
387-
"| APPROX_PERCENTILE_CONT(test.b,arg_2) |",
387+
"| approx_percentile_cont(test.b,arg_2) |",
388388
"+--------------------------------------+",
389389
"| 10 |",
390390
"+--------------------------------------+",

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,7 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str
108108
&[col("d", &schema).unwrap()],
109109
&[],
110110
&[],
111+
&[],
111112
&schema,
112113
"sum1",
113114
false,

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
252252

253253
let partitionby_exprs = vec![];
254254
let orderby_exprs = vec![];
255+
let logical_exprs = vec![];
255256
// Window frame starts with "UNBOUNDED PRECEDING":
256257
let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
257258

@@ -283,6 +284,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
283284
&window_fn,
284285
fn_name.to_string(),
285286
&args,
287+
&logical_exprs,
286288
&partitionby_exprs,
287289
&orderby_exprs,
288290
Arc::new(window_frame),
@@ -699,6 +701,7 @@ async fn run_window_test(
699701
&window_fn,
700702
fn_name.clone(),
701703
&args,
704+
&[],
702705
&partitionby_exprs,
703706
&orderby_exprs,
704707
Arc::new(window_frame.clone()),
@@ -717,6 +720,7 @@ async fn run_window_test(
717720
&window_fn,
718721
fn_name,
719722
&args,
723+
&[],
720724
&partitionby_exprs,
721725
&orderby_exprs,
722726
Arc::new(window_frame.clone()),

datafusion/expr/src/aggregate_function.rs

Lines changed: 1 addition & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ use std::sync::Arc;
2121
use std::{fmt, str::FromStr};
2222

2323
use crate::utils;
24-
use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility};
24+
use crate::{type_coercion::aggregates::*, Signature, Volatility};
2525

2626
use arrow::datatypes::{DataType, Field};
2727
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
@@ -45,10 +45,6 @@ pub enum AggregateFunction {
4545
NthValue,
4646
/// Correlation
4747
Correlation,
48-
/// Approximate continuous percentile function
49-
ApproxPercentileCont,
50-
/// Approximate continuous percentile function with weight
51-
ApproxPercentileContWithWeight,
5248
/// Grouping
5349
Grouping,
5450
/// Bit And
@@ -75,8 +71,6 @@ impl AggregateFunction {
7571
ArrayAgg => "ARRAY_AGG",
7672
NthValue => "NTH_VALUE",
7773
Correlation => "CORR",
78-
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
79-
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
8074
Grouping => "GROUPING",
8175
BitAnd => "BIT_AND",
8276
BitOr => "BIT_OR",
@@ -113,11 +107,6 @@ impl FromStr for AggregateFunction {
113107
"string_agg" => AggregateFunction::StringAgg,
114108
// statistical
115109
"corr" => AggregateFunction::Correlation,
116-
// approximate
117-
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
118-
"approx_percentile_cont_with_weight" => {
119-
AggregateFunction::ApproxPercentileContWithWeight
120-
}
121110
// other
122111
"grouping" => AggregateFunction::Grouping,
123112
_ => {
@@ -170,10 +159,6 @@ impl AggregateFunction {
170159
coerced_data_types[0].clone(),
171160
true,
172161
)))),
173-
AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
174-
AggregateFunction::ApproxPercentileContWithWeight => {
175-
Ok(coerced_data_types[0].clone())
176-
}
177162
AggregateFunction::Grouping => Ok(DataType::Int32),
178163
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
179164
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
@@ -230,39 +215,6 @@ impl AggregateFunction {
230215
AggregateFunction::Correlation => {
231216
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
232217
}
233-
AggregateFunction::ApproxPercentileCont => {
234-
let mut variants =
235-
Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
236-
// Accept any numeric value paired with a float64 percentile
237-
for num in NUMERICS {
238-
variants
239-
.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
240-
// Additionally accept an integer number of centroids for T-Digest
241-
for int in INTEGERS {
242-
variants.push(TypeSignature::Exact(vec![
243-
num.clone(),
244-
DataType::Float64,
245-
int.clone(),
246-
]))
247-
}
248-
}
249-
250-
Signature::one_of(variants, Volatility::Immutable)
251-
}
252-
AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
253-
// Accept any numeric value paired with a float64 percentile
254-
NUMERICS
255-
.iter()
256-
.map(|t| {
257-
TypeSignature::Exact(vec![
258-
t.clone(),
259-
t.clone(),
260-
DataType::Float64,
261-
])
262-
})
263-
.collect(),
264-
Volatility::Immutable,
265-
),
266218
AggregateFunction::StringAgg => {
267219
Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
268220
}

datafusion/expr/src/expr_fn.rs

Lines changed: 0 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -242,34 +242,6 @@ pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
242242
Expr::InList(InList::new(Box::new(expr), list, negated))
243243
}
244244

245-
/// Calculate an approximation of the specified `percentile` for `expr`.
246-
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
247-
Expr::AggregateFunction(AggregateFunction::new(
248-
aggregate_function::AggregateFunction::ApproxPercentileCont,
249-
vec![expr, percentile],
250-
false,
251-
None,
252-
None,
253-
None,
254-
))
255-
}
256-
257-
/// Calculate an approximation of the specified `percentile` for `expr` and `weight_expr`.
258-
pub fn approx_percentile_cont_with_weight(
259-
expr: Expr,
260-
weight_expr: Expr,
261-
percentile: Expr,
262-
) -> Expr {
263-
Expr::AggregateFunction(AggregateFunction::new(
264-
aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
265-
vec![expr, weight_expr, percentile],
266-
false,
267-
None,
268-
None,
269-
None,
270-
))
271-
}
272-
273245
/// Create an EXISTS subquery expression
274246
pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
275247
let outer_ref_columns = subquery.all_out_ref_exprs();

datafusion/expr/src/function.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,8 @@ pub struct AccumulatorArgs<'a> {
8383
/// The input type of the aggregate function.
8484
pub input_type: &'a DataType,
8585

86-
/// The number of arguments the aggregate function takes.
87-
pub args_num: usize,
86+
/// The logical expression of arguments the aggregate function takes.
87+
pub input_exprs: &'a [Expr],
8888
}
8989

9090
/// [`StateFieldsArgs`] contains information about the fields that an

0 commit comments

Comments
 (0)