Skip to content

Commit 658671e

Browse files
committed
refactor: sketch AggregateFunctionPlanner
1 parent 10f32a4 commit 658671e

File tree

10 files changed

+92
-106
lines changed

10 files changed

+92
-106
lines changed

datafusion/core/src/execution/session_state.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ impl SessionState {
242242
Arc::new(functions::datetime::planner::ExtractPlanner),
243243
#[cfg(feature = "unicode_expressions")]
244244
Arc::new(functions::unicode::planner::PositionPlanner),
245+
Arc::new(
246+
functions_aggregate::aggregate_function_planner::AggregateFunctionPlanner,
247+
),
245248
];
246249

247250
let mut new_self = SessionState {

datafusion/expr/src/planner.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use datafusion_common::{
2424
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
2525
Result, TableReference,
2626
};
27+
use sqlparser::ast::NullTreatment;
2728

2829
use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};
2930

@@ -107,7 +108,7 @@ pub trait UserDefinedSQLPlanner: Send + Sync {
107108

108109
/// Plan the array literal, returns OriginalArray if not possible
109110
///
110-
/// Returns origin expression arguments if not possible
111+
/// Returns original expression arguments if not possible
111112
fn plan_array_literal(
112113
&self,
113114
exprs: Vec<Expr>,
@@ -124,7 +125,7 @@ pub trait UserDefinedSQLPlanner: Send + Sync {
124125

125126
/// Plan the dictionary literal `{ key: value, ...}`
126127
///
127-
/// Returns origin expression arguments if not possible
128+
/// Returns original expression arguments if not possible
128129
fn plan_dictionary_literal(
129130
&self,
130131
expr: RawDictionaryExpr,
@@ -135,10 +136,20 @@ pub trait UserDefinedSQLPlanner: Send + Sync {
135136

136137
/// Plan an extract expression, e.g., `EXTRACT(month FROM foo)`
137138
///
138-
/// Returns origin expression arguments if not possible
139+
/// Returns original expression arguments if not possible
139140
fn plan_extract(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
140141
Ok(PlannerResult::Original(args))
141142
}
143+
144+
/// Plan an aggregate function, e.g., `SUM(foo)`
145+
///
146+
/// Returns original expression arguments if not possible
147+
fn plan_aggregate_function(
148+
&self,
149+
aggregate_function: RawAggregateFunction,
150+
) -> Result<PlannerResult<RawAggregateFunction>> {
151+
Ok(PlannerResult::Original(aggregate_function))
152+
}
142153
}
143154

144155
/// An operator with two arguments to plan
@@ -183,3 +194,14 @@ pub enum PlannerResult<T> {
183194
/// The raw expression could not be planned, and is returned unmodified
184195
Original(T),
185196
}
197+
198+
// An aggregate function to plan.
199+
#[derive(Debug, Clone)]
200+
pub struct RawAggregateFunction {
201+
pub udf: Arc<crate::AggregateUDF>,
202+
pub args: Vec<Expr>,
203+
pub distinct: bool,
204+
pub filter: Option<Box<Expr>>,
205+
pub order_by: Option<Vec<Expr>>,
206+
pub null_treatment: Option<NullTreatment>,
207+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use datafusion_expr::{
2+
expr, lit,
3+
planner::{PlannerResult, RawAggregateFunction, UserDefinedSQLPlanner},
4+
Expr,
5+
};
6+
7+
pub struct AggregateFunctionPlanner;
8+
9+
impl UserDefinedSQLPlanner for AggregateFunctionPlanner {
10+
fn plan_aggregate_function(
11+
&self,
12+
aggregate_function: RawAggregateFunction,
13+
) -> datafusion_common::Result<PlannerResult<RawAggregateFunction>> {
14+
let RawAggregateFunction {
15+
udf,
16+
args,
17+
distinct,
18+
filter,
19+
order_by,
20+
null_treatment,
21+
} = aggregate_function.clone();
22+
23+
if udf.name() == "count" && args.is_empty() {
24+
return Ok(PlannerResult::Planned(Expr::AggregateFunction(
25+
expr::AggregateFunction::new_udf(
26+
udf.clone(),
27+
vec![lit(1).alias("")],
28+
distinct,
29+
filter.clone(),
30+
order_by.clone(),
31+
null_treatment.clone(),
32+
),
33+
)));
34+
}
35+
36+
Ok(PlannerResult::Original(aggregate_function.clone()))
37+
}
38+
}

datafusion/functions-aggregate/src/count.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use datafusion_expr::{
4444
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
4545
EmitTo, GroupsAccumulator, Signature, Volatility,
4646
};
47-
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
47+
use datafusion_expr::{Expr, ReversedUDAF};
4848
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
4949
use datafusion_physical_expr_common::{
5050
aggregate::count_distinct::{
@@ -95,10 +95,7 @@ impl Default for Count {
9595
impl Count {
9696
pub fn new() -> Self {
9797
Self {
98-
signature: Signature::one_of(
99-
vec![TypeSignature::VariadicAny, TypeSignature::Any(0)],
100-
Volatility::Immutable,
101-
),
98+
signature: Signature::variadic_any(Volatility::Immutable),
10299
}
103100
}
104101
}

datafusion/functions-aggregate/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ use datafusion_expr::AggregateUDF;
8484
use log::debug;
8585
use std::sync::Arc;
8686

87+
pub mod aggregate_function_planner;
88+
8789
/// Fluent-style API for creating `Expr`s
8890
pub mod expr_fn {
8991
pub use super::approx_distinct;

datafusion/optimizer/src/analyzer/count_empty_rule.rs

Lines changed: 0 additions & 91 deletions
This file was deleted.

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub struct CountWildcardRule {}
3535

3636
impl CountWildcardRule {
3737
pub fn new() -> Self {
38-
Self {}
38+
CountWildcardRule {}
3939
}
4040
}
4141

datafusion/optimizer/src/analyzer/mod.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ use datafusion_expr::expr::InSubquery;
2929
use datafusion_expr::expr_rewriter::FunctionRewrite;
3030
use datafusion_expr::{Expr, LogicalPlan};
3131

32-
use crate::analyzer::count_empty_rule::CountEmptyRule;
3332
use crate::analyzer::count_wildcard_rule::CountWildcardRule;
3433
use crate::analyzer::inline_table_scan::InlineTableScan;
3534
use crate::analyzer::subquery::check_subquery_expr;
@@ -38,7 +37,6 @@ use crate::utils::log_plan;
3837

3938
use self::function_rewrite::ApplyFunctionRewrites;
4039

41-
pub mod count_empty_rule;
4240
pub mod count_wildcard_rule;
4341
pub mod function_rewrite;
4442
pub mod inline_table_scan;
@@ -93,7 +91,6 @@ impl Analyzer {
9391
Arc::new(InlineTableScan::new()),
9492
Arc::new(TypeCoercion::new()),
9593
Arc::new(CountWildcardRule::new()),
96-
Arc::new(CountEmptyRule::new()),
9794
];
9895
Self::with_rules(rules)
9996
}

datafusion/sql/src/expr/function.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use datafusion_common::{
2121
internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
2222
Dependency, Result,
2323
};
24+
use datafusion_expr::planner::{PlannerResult, RawAggregateFunction};
2425
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
2526
use datafusion_expr::{
2627
expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
@@ -349,13 +350,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
349350
.map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context))
350351
.transpose()?
351352
.map(Box::new);
352-
return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
353-
fm,
353+
354+
let raw_aggregate_function = RawAggregateFunction {
355+
udf: fm,
354356
args,
355357
distinct,
356358
filter,
357359
order_by,
358360
null_treatment,
361+
};
362+
363+
for planner in self.planners.iter() {
364+
if let PlannerResult::Planned(aggregate_function) =
365+
planner.plan_aggregate_function(raw_aggregate_function.clone())?
366+
{
367+
return Ok(aggregate_function);
368+
}
369+
}
370+
371+
return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
372+
raw_aggregate_function.udf,
373+
raw_aggregate_function.args,
374+
distinct,
375+
raw_aggregate_function.filter,
376+
raw_aggregate_function.order_by,
377+
null_treatment,
359378
)));
360379
}
361380

datafusion/sqllogictest/test_files/explain.slt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ logical_plan after apply_function_rewrites SAME TEXT AS ABOVE
183183
logical_plan after inline_table_scan SAME TEXT AS ABOVE
184184
logical_plan after type_coercion SAME TEXT AS ABOVE
185185
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
186-
logical_plan after count_empty_rule SAME TEXT AS ABOVE
187186
analyzed_logical_plan SAME TEXT AS ABOVE
188187
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
189188
logical_plan after simplify_expressions SAME TEXT AS ABOVE

0 commit comments

Comments
 (0)