Skip to content

Commit 8802f63

Browse files
authored
feat: Implement grouping function using grouping id (#12704)
* Implement grouping function using grouping id This patch adds a Analyzer rule to transform the grouping aggreation function into computation ontop of the grouping id that is used internally for grouping sets. * PR comments
1 parent 8d46fc1 commit 8802f63

File tree

5 files changed

+466
-1
lines changed

5 files changed

+466
-1
lines changed

datafusion/functions-aggregate/src/grouping.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl Grouping {
6363
/// Create a new GROUPING aggregate function.
6464
pub fn new() -> Self {
6565
Self {
66-
signature: Signature::any(1, Volatility::Immutable),
66+
signature: Signature::variadic_any(Volatility::Immutable),
6767
}
6868
}
6969
}

datafusion/optimizer/src/analyzer/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use datafusion_expr::{Expr, LogicalPlan};
3434
use crate::analyzer::count_wildcard_rule::CountWildcardRule;
3535
use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule;
3636
use crate::analyzer::inline_table_scan::InlineTableScan;
37+
use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction;
3738
use crate::analyzer::subquery::check_subquery_expr;
3839
use crate::analyzer::type_coercion::TypeCoercion;
3940
use crate::utils::log_plan;
@@ -44,6 +45,7 @@ pub mod count_wildcard_rule;
4445
pub mod expand_wildcard_rule;
4546
pub mod function_rewrite;
4647
pub mod inline_table_scan;
48+
pub mod resolve_grouping_function;
4749
pub mod subquery;
4850
pub mod type_coercion;
4951

@@ -96,6 +98,7 @@ impl Analyzer {
9698
// Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
9799
Arc::new(ExpandWildcardRule::new()),
98100
// [Expr::Wildcard] should be expanded before [TypeCoercion]
101+
Arc::new(ResolveGroupingFunction::new()),
99102
Arc::new(TypeCoercion::new()),
100103
Arc::new(CountWildcardRule::new()),
101104
];
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Analyzed rule to replace TableScan references
19+
//! such as DataFrames and Views and inlines the LogicalPlan.
20+
21+
use std::cmp::Ordering;
22+
use std::collections::HashMap;
23+
use std::sync::Arc;
24+
25+
use crate::analyzer::AnalyzerRule;
26+
27+
use arrow::datatypes::DataType;
28+
use datafusion_common::config::ConfigOptions;
29+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30+
use datafusion_common::{
31+
internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue,
32+
};
33+
use datafusion_expr::expr::{AggregateFunction, Alias};
34+
use datafusion_expr::logical_plan::LogicalPlan;
35+
use datafusion_expr::utils::grouping_set_to_exprlist;
36+
use datafusion_expr::{
37+
bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate,
38+
Expr, Projection,
39+
};
40+
use itertools::Itertools;
41+
42+
/// Replaces grouping aggregation function with value derived from internal grouping id
43+
#[derive(Default, Debug)]
44+
pub struct ResolveGroupingFunction;
45+
46+
impl ResolveGroupingFunction {
47+
pub fn new() -> Self {
48+
Self {}
49+
}
50+
}
51+
52+
impl AnalyzerRule for ResolveGroupingFunction {
53+
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
54+
plan.transform_up(analyze_internal).data()
55+
}
56+
57+
fn name(&self) -> &str {
58+
"resolve_grouping_function"
59+
}
60+
}
61+
62+
/// Create a map from grouping expr to index in the internal grouping id.
63+
///
64+
/// For more details on how the grouping id bitmap works the documentation for
65+
/// [[Aggregate::INTERNAL_GROUPING_ID]]
66+
fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
67+
Ok(grouping_set_to_exprlist(group_expr)?
68+
.into_iter()
69+
.rev()
70+
.enumerate()
71+
.map(|(idx, v)| (v, idx))
72+
.collect::<HashMap<_, _>>())
73+
}
74+
75+
fn replace_grouping_exprs(
76+
input: Arc<LogicalPlan>,
77+
schema: DFSchemaRef,
78+
group_expr: Vec<Expr>,
79+
aggr_expr: Vec<Expr>,
80+
) -> Result<LogicalPlan> {
81+
// Create HashMap from Expr to index in the grouping_id bitmap
82+
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
83+
let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
84+
let columns = schema.columns();
85+
let mut new_agg_expr = Vec::new();
86+
let mut projection_exprs = Vec::new();
87+
let grouping_id_len = if is_grouping_set { 1 } else { 0 };
88+
let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
89+
projection_exprs.extend(
90+
columns
91+
.iter()
92+
.take(group_expr_len)
93+
.map(|column| Expr::Column(column.clone())),
94+
);
95+
for (expr, column) in aggr_expr
96+
.into_iter()
97+
.zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
98+
{
99+
match expr {
100+
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
101+
let grouping_expr = grouping_function_on_id(
102+
function,
103+
&group_expr_to_bitmap_index,
104+
is_grouping_set,
105+
)?;
106+
projection_exprs.push(Expr::Alias(Alias::new(
107+
grouping_expr,
108+
column.relation,
109+
column.name,
110+
)));
111+
}
112+
_ => {
113+
projection_exprs.push(Expr::Column(column));
114+
new_agg_expr.push(expr);
115+
}
116+
}
117+
}
118+
// Recreate aggregate without grouping functions
119+
let new_aggregate =
120+
LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?);
121+
// Create projection with grouping functions calculations
122+
let projection = LogicalPlan::Projection(Projection::try_new(
123+
projection_exprs,
124+
new_aggregate.into(),
125+
)?);
126+
Ok(projection)
127+
}
128+
129+
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
130+
// rewrite any subqueries in the plan first
131+
let transformed_plan =
132+
plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?;
133+
134+
let transformed_plan = transformed_plan.transform_data(|plan| match plan {
135+
LogicalPlan::Aggregate(Aggregate {
136+
input,
137+
group_expr,
138+
aggr_expr,
139+
schema,
140+
..
141+
}) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes(
142+
replace_grouping_exprs(input, schema, group_expr, aggr_expr)?,
143+
)),
144+
_ => Ok(Transformed::no(plan)),
145+
})?;
146+
147+
Ok(transformed_plan)
148+
}
149+
150+
fn is_grouping_function(expr: &Expr) -> bool {
151+
// TODO: Do something better than name here should grouping be a built
152+
// in expression?
153+
matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping")
154+
}
155+
156+
fn contains_grouping_function(exprs: &[Expr]) -> bool {
157+
exprs.iter().any(is_grouping_function)
158+
}
159+
160+
/// Validate that the arguments to the grouping function are in the group by clause.
161+
fn validate_args(
162+
function: &AggregateFunction,
163+
group_by_expr: &HashMap<&Expr, usize>,
164+
) -> Result<()> {
165+
let expr_not_in_group_by = function
166+
.args
167+
.iter()
168+
.find(|expr| !group_by_expr.contains_key(expr));
169+
if let Some(expr) = expr_not_in_group_by {
170+
plan_err!(
171+
"Argument {} to grouping function is not in grouping columns {}",
172+
expr,
173+
group_by_expr.keys().map(|e| e.to_string()).join(", ")
174+
)
175+
} else {
176+
Ok(())
177+
}
178+
}
179+
180+
fn grouping_function_on_id(
181+
function: &AggregateFunction,
182+
group_by_expr: &HashMap<&Expr, usize>,
183+
is_grouping_set: bool,
184+
) -> Result<Expr> {
185+
validate_args(function, group_by_expr)?;
186+
let args = &function.args;
187+
188+
// Postgres allows grouping function for group by without grouping sets, the result is then
189+
// always 0
190+
if !is_grouping_set {
191+
return Ok(Expr::Literal(ScalarValue::from(0i32)));
192+
}
193+
194+
let group_by_expr_count = group_by_expr.len();
195+
let literal = |value: usize| {
196+
if group_by_expr_count < 8 {
197+
Expr::Literal(ScalarValue::from(value as u8))
198+
} else if group_by_expr_count < 16 {
199+
Expr::Literal(ScalarValue::from(value as u16))
200+
} else if group_by_expr_count < 32 {
201+
Expr::Literal(ScalarValue::from(value as u32))
202+
} else {
203+
Expr::Literal(ScalarValue::from(value as u64))
204+
}
205+
};
206+
207+
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
208+
// The grouping call is exactly our internal grouping id
209+
if args.len() == group_by_expr_count
210+
&& args
211+
.iter()
212+
.rev()
213+
.enumerate()
214+
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
215+
{
216+
return Ok(cast(grouping_id_column, DataType::Int32));
217+
}
218+
219+
args.iter()
220+
.rev()
221+
.enumerate()
222+
.map(|(arg_idx, expr)| {
223+
group_by_expr.get(expr).map(|group_by_idx| {
224+
let group_by_bit =
225+
bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
226+
match group_by_idx.cmp(&arg_idx) {
227+
Ordering::Less => {
228+
bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
229+
}
230+
Ordering::Greater => {
231+
bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
232+
}
233+
Ordering::Equal => group_by_bit,
234+
}
235+
})
236+
})
237+
.collect::<Option<Vec<_>>>()
238+
.and_then(|bit_exprs| {
239+
bit_exprs
240+
.into_iter()
241+
.reduce(bitwise_or)
242+
.map(|expr| cast(expr, DataType::Int32))
243+
})
244+
.ok_or_else(|| {
245+
internal_datafusion_err!("Grouping sets should contains at least one element")
246+
})
247+
}

datafusion/sqllogictest/test_files/explain.slt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ initial_logical_plan
176176
02)--TableScan: simple_explain_test
177177
logical_plan after inline_table_scan SAME TEXT AS ABOVE
178178
logical_plan after expand_wildcard_rule SAME TEXT AS ABOVE
179+
logical_plan after resolve_grouping_function SAME TEXT AS ABOVE
179180
logical_plan after type_coercion SAME TEXT AS ABOVE
180181
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
181182
analyzed_logical_plan SAME TEXT AS ABOVE

0 commit comments

Comments
 (0)