Skip to content

Commit c8e5996

Browse files
mertak-synnada, metesynnada, mustafasrepo, Mert Akkaya, and ozankabak
authored
Remove redundant Aggregate when DISTINCT & GROUP BY are in the same query (#11781)
* Delete docs.yaml
* initialize eliminate_aggregate.rs rule
* remove redundant prints
* Add multiple group by expression handling.
* rename eliminate_aggregate.rs as eliminate_distinct.rs; implement as rewrite function
* remove logic for distinct on, since group by statement must exist in projection
* format code
* add eliminate_distinct rule to tests
* simplify function; add additional tests for not-removing cases
* fix child issue
* format
* fix docs
* remove eliminate_distinct rule and make it a part of replace_distinct_aggregate
* Update datafusion/common/src/functional_dependencies.rs (Co-authored-by: Mehmet Ozan Kabak <[email protected]>)
* add comment and fix variable call
* fix test cases as optimized plan
* format code
* simplify comments (Co-authored-by: Mehmet Ozan Kabak <[email protected]>)
* do not replace redundant distincts with aggregate

---------

Co-authored-by: metesynnada <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
Co-authored-by: Mert Akkaya <[email protected]>
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
1 parent a4d41d6 commit c8e5996

File tree

4 files changed

+119
-30
lines changed

4 files changed

+119
-30
lines changed

datafusion/common/src/functional_dependencies.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -524,22 +524,31 @@ pub fn aggregate_functional_dependencies(
524524
}
525525
}
526526

527-
// If we have a single GROUP BY key, we can guarantee uniqueness after
527+
// When we have a GROUP BY key, we can guarantee uniqueness after
528528
// aggregation:
529-
if group_by_expr_names.len() == 1 {
530-
// If `source_indices` contain 0, delete this functional dependency
531-
// as it will be added anyway with mode `Dependency::Single`:
532-
aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0));
533-
// Add a new functional dependency associated with the whole table:
534-
aggregate_func_dependencies.push(
535-
// Use nullable property of the group by expression
536-
FunctionalDependence::new(
537-
vec![0],
538-
target_indices,
539-
aggr_fields[0].is_nullable(),
540-
)
541-
.with_mode(Dependency::Single),
542-
);
529+
if !group_by_expr_names.is_empty() {
530+
let count = group_by_expr_names.len();
531+
let source_indices = (0..count).collect::<Vec<_>>();
532+
let nullable = source_indices
533+
.iter()
534+
.any(|idx| aggr_fields[*idx].is_nullable());
535+
// If GROUP BY expressions do not already act as a determinant:
536+
if !aggregate_func_dependencies.iter().any(|item| {
537+
// If `item.source_indices` is a subset of GROUP BY expressions, we shouldn't add
538+
// them since `item.source_indices` defines this relation already.
539+
540+
// The following simple comparison is working well because
541+
// GROUP BY expressions come here as a prefix.
542+
item.source_indices.iter().all(|idx| idx < &count)
543+
}) {
544+
// Add a new functional dependency associated with the whole table:
545+
// Use nullable property of the GROUP BY expression:
546+
aggregate_func_dependencies.push(
547+
// Use nullable property of the GROUP BY expression:
548+
FunctionalDependence::new(source_indices, target_indices, nullable)
549+
.with_mode(Dependency::Single),
550+
);
551+
}
543552
}
544553
FunctionalDependencies::new(aggregate_func_dependencies)
545554
}

datafusion/optimizer/src/replace_distinct_aggregate.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,21 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
7777
match plan {
7878
LogicalPlan::Distinct(Distinct::All(input)) => {
7979
let group_expr = expand_wildcard(input.schema(), &input, None)?;
80+
81+
let field_count = input.schema().fields().len();
82+
for dep in input.schema().functional_dependencies().iter() {
83+
// If distinct is exactly the same with a previous GROUP BY, we can
84+
// simply remove it:
85+
if dep.source_indices[..field_count]
86+
.iter()
87+
.enumerate()
88+
.all(|(idx, f_idx)| idx == *f_idx)
89+
{
90+
return Ok(Transformed::yes(input.as_ref().clone()));
91+
}
92+
}
93+
94+
// Replace with aggregation:
8095
let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new(
8196
input,
8297
group_expr,
@@ -165,3 +180,78 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
165180
Some(BottomUp)
166181
}
167182
}
183+
184+
#[cfg(test)]
185+
mod tests {
186+
use std::sync::Arc;
187+
188+
use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
189+
use crate::test::*;
190+
191+
use datafusion_common::Result;
192+
use datafusion_expr::{
193+
col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
194+
};
195+
use datafusion_functions_aggregate::sum::sum;
196+
197+
fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
198+
assert_optimized_plan_eq(
199+
Arc::new(ReplaceDistinctWithAggregate::new()),
200+
plan.clone(),
201+
expected,
202+
)
203+
}
204+
205+
#[test]
206+
fn eliminate_redundant_distinct_simple() -> Result<()> {
207+
let table_scan = test_table_scan().unwrap();
208+
let plan = LogicalPlanBuilder::from(table_scan)
209+
.aggregate(vec![col("c")], Vec::<Expr>::new())?
210+
.project(vec![col("c")])?
211+
.distinct()?
212+
.build()?;
213+
214+
let expected = "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test";
215+
assert_optimized_plan_equal(&plan, expected)
216+
}
217+
218+
#[test]
219+
fn eliminate_redundant_distinct_pair() -> Result<()> {
220+
let table_scan = test_table_scan().unwrap();
221+
let plan = LogicalPlanBuilder::from(table_scan)
222+
.aggregate(vec![col("a"), col("b")], Vec::<Expr>::new())?
223+
.project(vec![col("a"), col("b")])?
224+
.distinct()?
225+
.build()?;
226+
227+
let expected =
228+
"Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test";
229+
assert_optimized_plan_equal(&plan, expected)
230+
}
231+
232+
#[test]
233+
fn do_not_eliminate_distinct() -> Result<()> {
234+
let table_scan = test_table_scan().unwrap();
235+
let plan = LogicalPlanBuilder::from(table_scan)
236+
.project(vec![col("a"), col("b")])?
237+
.distinct()?
238+
.build()?;
239+
240+
let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test";
241+
assert_optimized_plan_equal(&plan, expected)
242+
}
243+
244+
#[test]
245+
fn do_not_eliminate_distinct_with_aggr() -> Result<()> {
246+
let table_scan = test_table_scan().unwrap();
247+
let plan = LogicalPlanBuilder::from(table_scan)
248+
.aggregate(vec![col("a"), col("b"), col("c")], vec![sum(col("c"))])?
249+
.project(vec![col("a"), col("b")])?
250+
.distinct()?
251+
.build()?;
252+
253+
let expected =
254+
"Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test";
255+
assert_optimized_plan_equal(&plan, expected)
256+
}
257+
}

datafusion/optimizer/src/single_distinct_to_groupby.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ use hashbrown::HashSet;
3939
/// single distinct to group by optimizer rule
4040
/// ```text
4141
/// Before:
42-
/// SELECT a, count(DINSTINCT b), sum(c)
42+
/// SELECT a, count(DISTINCT b), sum(c)
4343
/// FROM t
4444
/// GROUP BY a
4545
///

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4536,19 +4536,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
45364536
logical_plan
45374537
01)Limit: skip=0, fetch=5
45384538
02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
4539-
03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
4540-
04)------TableScan: aggregate_test_100 projection=[c3]
4539+
03)----TableScan: aggregate_test_100 projection=[c3]
45414540
physical_plan
45424541
01)GlobalLimitExec: skip=0, fetch=5
45434542
02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5]
45444543
03)----CoalescePartitionsExec
45454544
04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5]
45464545
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
4547-
06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5]
4548-
07)------------CoalescePartitionsExec
4549-
08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5]
4550-
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
4551-
10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
4546+
06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
45524547

45534548
query I
45544549
SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
@@ -4699,19 +4694,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5;
46994694
logical_plan
47004695
01)Limit: skip=0, fetch=5
47014696
02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
4702-
03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]]
4703-
04)------TableScan: aggregate_test_100 projection=[c3]
4697+
03)----TableScan: aggregate_test_100 projection=[c3]
47044698
physical_plan
47054699
01)GlobalLimitExec: skip=0, fetch=5
47064700
02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[]
47074701
03)----CoalescePartitionsExec
47084702
04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[]
47094703
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
4710-
06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[]
4711-
07)------------CoalescePartitionsExec
4712-
08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[]
4713-
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
4714-
10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
4704+
06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true
47154705

47164706
statement ok
47174707
set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true;

0 commit comments

Comments
 (0)