Skip to content

Commit 3ca80c1

Browse files
tshauck authored and alamb committed
feat: support COUNT() (apache#11229)
* feat: add count empty rewrite
* feat: make count support zero args
* docs: add apache license
* tests: make count() valid
* tests: more tests
* refactor: sketch `AggregateFunctionPlanner`
* refactor: cleanup `AggregateFunctionPlanner`
* feat: add back rule
* Revert "feat: add back rule"
  This reverts commit 2c4fc0a.
* Revert "refactor: cleanup `AggregateFunctionPlanner`"
  This reverts commit 4550dbd.
* Revert "refactor: sketch `AggregateFunctionPlanner`"
  This reverts commit 658671e.
* Apply suggestions from code review
  Co-authored-by: Andrew Lamb <[email protected]>
* refactor: PR feedback
* style: fix indent

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent e174011 commit 3ca80c1

File tree

4 files changed

+114
-9
lines changed

4 files changed

+114
-9
lines changed

datafusion/functions-aggregate/src/count.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use datafusion_expr::{
4444
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
4545
EmitTo, GroupsAccumulator, Signature, Volatility,
4646
};
47-
use datafusion_expr::{Expr, ReversedUDAF};
47+
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
4848
use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
4949
use datafusion_physical_expr_common::{
5050
aggregate::count_distinct::{
@@ -95,7 +95,11 @@ impl Default for Count {
9595
impl Count {
9696
pub fn new() -> Self {
9797
Self {
98-
signature: Signature::variadic_any(Volatility::Immutable),
98+
signature: Signature::one_of(
99+
// TypeSignature::Any(0) is required to handle `Count()` with no args
100+
vec![TypeSignature::VariadicAny, TypeSignature::Any(0)],
101+
Volatility::Immutable,
102+
),
99103
}
100104
}
101105
}

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub struct CountWildcardRule {}
3535

3636
impl CountWildcardRule {
3737
pub fn new() -> Self {
38-
CountWildcardRule {}
38+
Self {}
3939
}
4040
}
4141

@@ -59,14 +59,14 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
5959
func_def: AggregateFunctionDefinition::UDF(udf),
6060
args,
6161
..
62-
} if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0]))
62+
} if udf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
6363
}
6464

6565
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
6666
let args = &window_function.args;
6767
matches!(window_function.fun,
6868
WindowFunctionDefinition::AggregateUDF(ref udaf)
69-
if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0]))
69+
if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
7070
}
7171

7272
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
statement ok
19+
CREATE TABLE t1 (a INTEGER, b INTEGER, c INTEGER);
20+
21+
statement ok
22+
INSERT INTO t1 VALUES
23+
(1, 2, 3),
24+
(1, 5, 6),
25+
(2, 3, 5);
26+
27+
statement ok
28+
CREATE TABLE t2 (a INTEGER, b INTEGER, c INTEGER);
29+
30+
query TT
31+
EXPLAIN SELECT COUNT() FROM (SELECT 1 AS a, 2 AS b) AS t;
32+
----
33+
logical_plan
34+
01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]]
35+
02)--SubqueryAlias: t
36+
03)----EmptyRelation
37+
physical_plan
38+
01)ProjectionExec: expr=[1 as count()]
39+
02)--PlaceholderRowExec
40+
41+
query TT
42+
EXPLAIN SELECT t1.a, COUNT() FROM t1 GROUP BY t1.a;
43+
----
44+
logical_plan
45+
01)Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]]
46+
02)--TableScan: t1 projection=[a]
47+
physical_plan
48+
01)AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()]
49+
02)--CoalesceBatchesExec: target_batch_size=8192
50+
03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
51+
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
52+
05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()]
53+
06)----------MemoryExec: partitions=1, partition_sizes=[1]
54+
55+
query TT
56+
EXPLAIN SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 0;
57+
----
58+
logical_plan
59+
01)Projection: t1.a, count() AS cnt
60+
02)--Filter: count() > Int64(0)
61+
03)----Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]]
62+
04)------TableScan: t1 projection=[a]
63+
physical_plan
64+
01)ProjectionExec: expr=[a@0 as a, count()@1 as cnt]
65+
02)--CoalesceBatchesExec: target_batch_size=8192
66+
03)----FilterExec: count()@1 > 0
67+
04)------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()]
68+
05)--------CoalesceBatchesExec: target_batch_size=8192
69+
06)----------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
70+
07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
71+
08)--------------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()]
72+
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
73+
74+
query II
75+
SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 1;
76+
----
77+
1 2
78+
79+
query TT
80+
EXPLAIN SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1;
81+
----
82+
logical_plan
83+
01)Projection: t1.a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count_a
84+
02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
85+
03)----TableScan: t1 projection=[a]
86+
physical_plan
87+
01)ProjectionExec: expr=[a@0 as a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a]
88+
02)--WindowAggExec: wdw=[count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
89+
03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]
90+
04)------CoalesceBatchesExec: target_batch_size=8192
91+
05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1
92+
06)----------MemoryExec: partitions=1, partition_sizes=[1]
93+
94+
query II
95+
SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1 ORDER BY a;
96+
----
97+
1 2
98+
1 2
99+
2 1
100+
101+
statement ok
102+
DROP TABLE t1;
103+
104+
statement ok
105+
DROP TABLE t2;

datafusion/sqllogictest/test_files/errors.slt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,6 @@ SELECT power(1, 2, 3);
103103
# Wrong window/aggregate function signature
104104
#
105105

106-
# AggregateFunction with wrong number of arguments
107-
query error
108-
select count();
109-
110106
# AggregateFunction with wrong number of arguments
111107
query error
112108
select avg(c1, c12) from aggregate_test_100;

0 commit comments

Comments
 (0)