Skip to content

Commit cc60278

Browse files
Emil Ejbyfeldt and alamb authored
Move Regr_* functions to use UDAF (#10898)
* Move Regr_* functions to use UDAF

  Closes #10883 and is part of #8708

* Format and regen

* tweak error check

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent b627ca3 commit cc60278

File tree

15 files changed

+135
-316
lines changed

15 files changed

+135
-316
lines changed

datafusion/expr/src/aggregate_function.rs

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,6 @@ pub enum AggregateFunction {
4545
NthValue,
4646
/// Correlation
4747
Correlation,
48-
/// Slope from linear regression
49-
RegrSlope,
50-
/// Intercept from linear regression
51-
RegrIntercept,
52-
/// Number of input rows in which both expressions are not null
53-
RegrCount,
54-
/// R-squared value from linear regression
55-
RegrR2,
56-
/// Average of the independent variable
57-
RegrAvgx,
58-
/// Average of the dependent variable
59-
RegrAvgy,
60-
/// Sum of squares of the independent variable
61-
RegrSXX,
62-
/// Sum of squares of the dependent variable
63-
RegrSYY,
64-
/// Sum of products of pairs of numbers
65-
RegrSXY,
6648
/// Approximate continuous percentile function
6749
ApproxPercentileCont,
6850
/// Approximate continuous percentile function with weight
@@ -93,15 +75,6 @@ impl AggregateFunction {
9375
ArrayAgg => "ARRAY_AGG",
9476
NthValue => "NTH_VALUE",
9577
Correlation => "CORR",
96-
RegrSlope => "REGR_SLOPE",
97-
RegrIntercept => "REGR_INTERCEPT",
98-
RegrCount => "REGR_COUNT",
99-
RegrR2 => "REGR_R2",
100-
RegrAvgx => "REGR_AVGX",
101-
RegrAvgy => "REGR_AVGY",
102-
RegrSXX => "REGR_SXX",
103-
RegrSYY => "REGR_SYY",
104-
RegrSXY => "REGR_SXY",
10578
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
10679
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
10780
Grouping => "GROUPING",
@@ -140,15 +113,6 @@ impl FromStr for AggregateFunction {
140113
"string_agg" => AggregateFunction::StringAgg,
141114
// statistical
142115
"corr" => AggregateFunction::Correlation,
143-
"regr_slope" => AggregateFunction::RegrSlope,
144-
"regr_intercept" => AggregateFunction::RegrIntercept,
145-
"regr_count" => AggregateFunction::RegrCount,
146-
"regr_r2" => AggregateFunction::RegrR2,
147-
"regr_avgx" => AggregateFunction::RegrAvgx,
148-
"regr_avgy" => AggregateFunction::RegrAvgy,
149-
"regr_sxx" => AggregateFunction::RegrSXX,
150-
"regr_syy" => AggregateFunction::RegrSYY,
151-
"regr_sxy" => AggregateFunction::RegrSXY,
152116
// approximate
153117
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
154118
"approx_percentile_cont_with_weight" => {
@@ -200,15 +164,6 @@ impl AggregateFunction {
200164
AggregateFunction::Correlation => {
201165
correlation_return_type(&coerced_data_types[0])
202166
}
203-
AggregateFunction::RegrSlope
204-
| AggregateFunction::RegrIntercept
205-
| AggregateFunction::RegrCount
206-
| AggregateFunction::RegrR2
207-
| AggregateFunction::RegrAvgx
208-
| AggregateFunction::RegrAvgy
209-
| AggregateFunction::RegrSXX
210-
| AggregateFunction::RegrSYY
211-
| AggregateFunction::RegrSXY => Ok(DataType::Float64),
212167
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
213168
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
214169
"item",
@@ -272,16 +227,7 @@ impl AggregateFunction {
272227
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
273228
}
274229
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
275-
AggregateFunction::Correlation
276-
| AggregateFunction::RegrSlope
277-
| AggregateFunction::RegrIntercept
278-
| AggregateFunction::RegrCount
279-
| AggregateFunction::RegrR2
280-
| AggregateFunction::RegrAvgx
281-
| AggregateFunction::RegrAvgy
282-
| AggregateFunction::RegrSXX
283-
| AggregateFunction::RegrSYY
284-
| AggregateFunction::RegrSXY => {
230+
AggregateFunction::Correlation => {
285231
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
286232
}
287233
AggregateFunction::ApproxPercentileCont => {

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -158,27 +158,6 @@ pub fn coerce_types(
158158
}
159159
Ok(vec![Float64, Float64])
160160
}
161-
AggregateFunction::RegrSlope
162-
| AggregateFunction::RegrIntercept
163-
| AggregateFunction::RegrCount
164-
| AggregateFunction::RegrR2
165-
| AggregateFunction::RegrAvgx
166-
| AggregateFunction::RegrAvgy
167-
| AggregateFunction::RegrSXX
168-
| AggregateFunction::RegrSYY
169-
| AggregateFunction::RegrSXY => {
170-
let valid_types = [NUMERICS.to_vec(), vec![Null]].concat();
171-
let input_types_valid = // number of input already checked before
172-
valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
173-
if !input_types_valid {
174-
return plan_err!(
175-
"The function {:?} does not support inputs of type {:?}.",
176-
agg_fun,
177-
input_types[0]
178-
);
179-
}
180-
Ok(vec![Float64, Float64])
181-
}
182161
AggregateFunction::ApproxPercentileCont => {
183162
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
184163
return plan_err!(

datafusion/functions-aggregate/src/lib.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ pub mod covariance;
6161
pub mod first_last;
6262
pub mod hyperloglog;
6363
pub mod median;
64+
pub mod regr;
6465
pub mod stddev;
6566
pub mod sum;
6667
pub mod variance;
@@ -85,6 +86,15 @@ pub mod expr_fn {
8586
pub use super::first_last::first_value;
8687
pub use super::first_last::last_value;
8788
pub use super::median::median;
89+
pub use super::regr::regr_avgx;
90+
pub use super::regr::regr_avgy;
91+
pub use super::regr::regr_count;
92+
pub use super::regr::regr_intercept;
93+
pub use super::regr::regr_r2;
94+
pub use super::regr::regr_slope;
95+
pub use super::regr::regr_sxx;
96+
pub use super::regr::regr_sxy;
97+
pub use super::regr::regr_syy;
8898
pub use super::stddev::stddev;
8999
pub use super::stddev::stddev_pop;
90100
pub use super::sum::sum;
@@ -102,6 +112,15 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
102112
covariance::covar_pop_udaf(),
103113
median::median_udaf(),
104114
count::count_udaf(),
115+
regr::regr_slope_udaf(),
116+
regr::regr_intercept_udaf(),
117+
regr::regr_count_udaf(),
118+
regr::regr_r2_udaf(),
119+
regr::regr_avgx_udaf(),
120+
regr::regr_avgy_udaf(),
121+
regr::regr_sxx_udaf(),
122+
regr::regr_syy_udaf(),
123+
regr::regr_sxy_udaf(),
105124
variance::var_samp_udaf(),
106125
variance::var_pop_udaf(),
107126
stddev::stddev_udaf(),

datafusion/functions-aggregate/src/macros.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
// specific language governing permissions and limitations
3333
// under the License.
3434

35-
macro_rules! make_udaf_expr_and_func {
36-
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
35+
macro_rules! make_udaf_expr {
36+
($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
3737
// "fluent expr_fn" style function
3838
#[doc = $DOC]
3939
pub fn $EXPR_FN(
@@ -48,7 +48,12 @@ macro_rules! make_udaf_expr_and_func {
4848
None,
4949
))
5050
}
51+
};
52+
}
5153

54+
macro_rules! make_udaf_expr_and_func {
55+
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
56+
make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN);
5257
create_func!($UDAF, $AGGREGATE_UDF_FN);
5358
};
5459
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
@@ -73,6 +78,9 @@ macro_rules! make_udaf_expr_and_func {
7378

7479
macro_rules! create_func {
7580
($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
81+
create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default());
82+
};
83+
($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => {
7684
paste::paste! {
7785
/// Singleton instance of [$UDAF], ensures the UDAF is only created once
7886
/// named STATIC_$(UDAF). For example `STATIC_FirstValue`
@@ -86,7 +94,7 @@ macro_rules! create_func {
8694
pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<datafusion_expr::AggregateUDF> {
8795
[< STATIC_ $UDAF >]
8896
.get_or_init(|| {
89-
std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default()))
97+
std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE))
9098
})
9199
.clone()
92100
}

0 commit comments

Comments
 (0)