Skip to content

Commit bb53649

Browse files
committed
Implement grouping function using grouping id
This patch adds a Analyzer rule to transform the grouping aggreation function into computation ontop of the grouping id that is used internally for grouping sets.
1 parent 577e4bb commit bb53649

File tree

5 files changed

+461
-1
lines changed

5 files changed

+461
-1
lines changed

datafusion/functions-aggregate/src/grouping.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ impl Grouping {
5959
/// Create a new GROUPING aggregate function.
6060
pub fn new() -> Self {
6161
Self {
62-
signature: Signature::any(1, Volatility::Immutable),
62+
signature: Signature::variadic_any(Volatility::Immutable),
6363
}
6464
}
6565
}

datafusion/optimizer/src/analyzer/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use datafusion_expr::{Expr, LogicalPlan};
3434
use crate::analyzer::count_wildcard_rule::CountWildcardRule;
3535
use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule;
3636
use crate::analyzer::inline_table_scan::InlineTableScan;
37+
use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction;
3738
use crate::analyzer::subquery::check_subquery_expr;
3839
use crate::analyzer::type_coercion::TypeCoercion;
3940
use crate::utils::log_plan;
@@ -44,6 +45,7 @@ pub mod count_wildcard_rule;
4445
pub mod expand_wildcard_rule;
4546
pub mod function_rewrite;
4647
pub mod inline_table_scan;
48+
pub mod resolve_grouping_function;
4749
pub mod subquery;
4850
pub mod type_coercion;
4951

@@ -96,6 +98,7 @@ impl Analyzer {
9698
// Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
9799
Arc::new(ExpandWildcardRule::new()),
98100
// [Expr::Wildcard] should be expanded before [TypeCoercion]
101+
Arc::new(ResolveGroupingFunction::new()),
99102
Arc::new(TypeCoercion::new()),
100103
Arc::new(CountWildcardRule::new()),
101104
];
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Analyzed rule to replace TableScan references
19+
//! such as DataFrames and Views and inlines the LogicalPlan.
20+
21+
use std::cmp::Ordering;
22+
use std::collections::HashMap;
23+
use std::sync::Arc;
24+
25+
use crate::analyzer::AnalyzerRule;
26+
27+
use arrow::datatypes::DataType;
28+
use datafusion_common::config::ConfigOptions;
29+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30+
use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result, ScalarValue};
31+
use datafusion_expr::expr::{AggregateFunction, Alias};
32+
use datafusion_expr::logical_plan::LogicalPlan;
33+
use datafusion_expr::utils::grouping_set_to_exprlist;
34+
use datafusion_expr::{
35+
bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate,
36+
Expr, Projection,
37+
};
38+
use itertools::Itertools;
39+
40+
/// Replaces grouping aggregation function with value derived from internal grouping id
41+
#[derive(Default, Debug)]
42+
pub struct ResolveGroupingFunction;
43+
44+
impl ResolveGroupingFunction {
45+
pub fn new() -> Self {
46+
Self {}
47+
}
48+
}
49+
50+
impl AnalyzerRule for ResolveGroupingFunction {
51+
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
52+
plan.transform_up(analyze_internal).data()
53+
}
54+
55+
fn name(&self) -> &str {
56+
"resolve_grouping_function"
57+
}
58+
}
59+
60+
fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
61+
Ok(grouping_set_to_exprlist(group_expr)?
62+
.into_iter()
63+
.rev()
64+
.enumerate()
65+
.map(|(idx, v)| (v, idx))
66+
.collect::<HashMap<_, _>>())
67+
}
68+
69+
fn replace_grouping_exprs(
70+
input: Arc<LogicalPlan>,
71+
schema: DFSchemaRef,
72+
group_expr: Vec<Expr>,
73+
aggr_expr: Vec<Expr>,
74+
) -> Result<LogicalPlan> {
75+
// Create HashMap from Expr to index in the grouping_id bitmap
76+
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
77+
let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
78+
let columns = schema.columns();
79+
let mut new_agg_expr = Vec::new();
80+
let mut projection_exprs = Vec::new();
81+
let grouping_id_len = if is_grouping_set { 1 } else { 0 };
82+
let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
83+
projection_exprs.extend(
84+
columns
85+
.iter()
86+
.take(group_expr_len)
87+
.map(|column| Expr::Column(column.clone())),
88+
);
89+
for (expr, column) in aggr_expr
90+
.into_iter()
91+
.zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
92+
{
93+
match expr {
94+
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
95+
let grouping_expr = grouping_function_on_id(
96+
function,
97+
&group_expr_to_bitmap_index,
98+
is_grouping_set,
99+
)?;
100+
projection_exprs.push(Expr::Alias(Alias::new(
101+
grouping_expr,
102+
column.relation,
103+
column.name,
104+
)));
105+
}
106+
_ => {
107+
projection_exprs.push(Expr::Column(column));
108+
new_agg_expr.push(expr);
109+
}
110+
}
111+
}
112+
// Recreate aggregate without grouping functions
113+
let new_aggregate =
114+
LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?);
115+
// Create projection with grouping functions calculations
116+
let projection = LogicalPlan::Projection(Projection::try_new(
117+
projection_exprs,
118+
new_aggregate.into(),
119+
)?);
120+
Ok(projection)
121+
}
122+
123+
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
124+
// rewrite any subqueries in the plan first
125+
let transformed_plan =
126+
plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?;
127+
128+
let transformed_plan = transformed_plan.transform_data(|plan| match plan {
129+
LogicalPlan::Aggregate(Aggregate {
130+
input,
131+
group_expr,
132+
aggr_expr,
133+
schema,
134+
..
135+
}) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes(
136+
replace_grouping_exprs(input, schema, group_expr, aggr_expr)?,
137+
)),
138+
_ => Ok(Transformed::no(plan)),
139+
})?;
140+
141+
Ok(transformed_plan)
142+
}
143+
144+
fn is_grouping_function(expr: &Expr) -> bool {
145+
// TODO: Do something better than name here should grouping be a built
146+
// in expression?
147+
matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping")
148+
}
149+
150+
fn contains_grouping_function(exprs: &[Expr]) -> bool {
151+
exprs.iter().any(is_grouping_function)
152+
}
153+
154+
fn validate_args(
155+
function: &AggregateFunction,
156+
group_by_expr: &HashMap<&Expr, usize>,
157+
) -> Result<()> {
158+
let expr_not_in_group_by = function
159+
.args
160+
.iter()
161+
.find(|expr| !group_by_expr.contains_key(expr));
162+
if let Some(expr) = expr_not_in_group_by {
163+
Err(DataFusionError::Plan(format!(
164+
"Argument {} to grouping function is not in grouping columns {}",
165+
expr,
166+
group_by_expr.keys().map(|e| e.to_string()).join(", ")
167+
)))
168+
} else {
169+
Ok(())
170+
}
171+
}
172+
173+
fn grouping_function_on_id(
174+
function: &AggregateFunction,
175+
group_by_expr: &HashMap<&Expr, usize>,
176+
is_grouping_set: bool,
177+
) -> Result<Expr> {
178+
validate_args(function, group_by_expr)?;
179+
let args = &function.args;
180+
181+
// Postgres allows grouping function for group by without grouping sets, the result is then
182+
// always 0
183+
if !is_grouping_set {
184+
return Ok(Expr::Literal(ScalarValue::from(0u32)));
185+
}
186+
187+
let group_by_expr_count = group_by_expr.len();
188+
let literal = |value: usize| {
189+
if group_by_expr_count < 8 {
190+
Expr::Literal(ScalarValue::from(value as u8))
191+
} else if group_by_expr_count < 16 {
192+
Expr::Literal(ScalarValue::from(value as u16))
193+
} else if group_by_expr_count < 32 {
194+
Expr::Literal(ScalarValue::from(value as u32))
195+
} else {
196+
Expr::Literal(ScalarValue::from(value as u64))
197+
}
198+
};
199+
200+
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
201+
// The grouping call is exactly our internal grouping id
202+
if args.len() == group_by_expr_count
203+
&& args
204+
.iter()
205+
.rev()
206+
.enumerate()
207+
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
208+
{
209+
return Ok(cast(grouping_id_column, DataType::UInt32));
210+
}
211+
212+
args.iter()
213+
.rev()
214+
.enumerate()
215+
.map(|(arg_idx, expr)| {
216+
group_by_expr.get(expr).map(|group_by_idx| {
217+
let group_by_bit =
218+
bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
219+
match group_by_idx.cmp(&arg_idx) {
220+
Ordering::Less => {
221+
bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
222+
}
223+
Ordering::Greater => {
224+
bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
225+
}
226+
Ordering::Equal => group_by_bit,
227+
}
228+
})
229+
})
230+
.collect::<Option<Vec<_>>>()
231+
.and_then(|bit_exprs| {
232+
bit_exprs
233+
.into_iter()
234+
.reduce(bitwise_or)
235+
.map(|expr| cast(expr, DataType::UInt32))
236+
})
237+
.ok_or_else(|| {
238+
DataFusionError::Internal(
239+
"Grouping sets should contains at least one element".to_string(),
240+
)
241+
})
242+
}

datafusion/sqllogictest/test_files/explain.slt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ initial_logical_plan
176176
02)--TableScan: simple_explain_test
177177
logical_plan after inline_table_scan SAME TEXT AS ABOVE
178178
logical_plan after expand_wildcard_rule SAME TEXT AS ABOVE
179+
logical_plan after resolve_grouping_function SAME TEXT AS ABOVE
179180
logical_plan after type_coercion SAME TEXT AS ABOVE
180181
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
181182
analyzed_logical_plan SAME TEXT AS ABOVE

0 commit comments

Comments
 (0)