Skip to content

Commit 80baff7

Browse files
Jefffrey, qazxcdswe123, alamb, blaginin
authored and committed
avg(distinct) support for decimal types (apache#17560)
* chore: mv `DistinctSumAccumulator` to common
* feat: add avg distinct support for float64 type
* chore: fmt
* refactor: update import for DataType in Float64DistinctAvgAccumulator and remove unused sum_distinct module
* feat: add avg distinct support for float64 type
* feat: add avg distinct support for decimal
* feat: more test for avg distinct in rust api
* Remove DataFrame API tests for avg(distinct)
* Remove proto test
* Fix merge errors
* Refactoring
* Minor cleanup
* Decimal slt tests for avg(distinct)
* Fix state_fields for decimal distinct avg

---------

Co-authored-by: YuNing Chen <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Dmitrii Blaginin <[email protected]>
1 parent d5065d6 commit 80baff7

File tree

4 files changed

+310
-39
lines changed

4 files changed

+310
-39
lines changed

datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod decimal;
1819
mod numeric;
1920

21+
pub use decimal::DecimalDistinctAvgAccumulator;
2022
pub use numeric::Float64DistinctAvgAccumulator;
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::{
19+
array::{ArrayRef, ArrowNumericType},
20+
datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType},
21+
};
22+
use datafusion_common::{Result, ScalarValue};
23+
use datafusion_expr_common::accumulator::Accumulator;
24+
use std::fmt::Debug;
25+
use std::mem::size_of_val;
26+
27+
use crate::aggregate::sum_distinct::DistinctSumAccumulator;
28+
use crate::utils::DecimalAverager;
29+
30+
/// Generic implementation of `AVG DISTINCT` for Decimal types.
31+
/// Handles both Decimal128Type and Decimal256Type.
32+
#[derive(Debug)]
33+
pub struct DecimalDistinctAvgAccumulator<T: DecimalType + Debug> {
34+
sum_accumulator: DistinctSumAccumulator<T>,
35+
sum_scale: i8,
36+
target_precision: u8,
37+
target_scale: i8,
38+
}
39+
40+
impl<T: DecimalType + Debug> DecimalDistinctAvgAccumulator<T> {
41+
pub fn with_decimal_params(
42+
sum_scale: i8,
43+
target_precision: u8,
44+
target_scale: i8,
45+
) -> Self {
46+
let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale);
47+
48+
Self {
49+
sum_accumulator: DistinctSumAccumulator::new(&data_type),
50+
sum_scale,
51+
target_precision,
52+
target_scale,
53+
}
54+
}
55+
}
56+
57+
impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
58+
for DecimalDistinctAvgAccumulator<T>
59+
{
60+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
61+
self.sum_accumulator.state()
62+
}
63+
64+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
65+
self.sum_accumulator.update_batch(values)
66+
}
67+
68+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
69+
self.sum_accumulator.merge_batch(states)
70+
}
71+
72+
fn evaluate(&mut self) -> Result<ScalarValue> {
73+
if self.sum_accumulator.distinct_count() == 0 {
74+
return ScalarValue::new_primitive::<T>(
75+
None,
76+
&T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
77+
);
78+
}
79+
80+
let sum_scalar = self.sum_accumulator.evaluate()?;
81+
82+
match sum_scalar {
83+
ScalarValue::Decimal128(Some(sum), _, _) => {
84+
let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
85+
self.sum_scale,
86+
self.target_precision,
87+
self.target_scale,
88+
)?;
89+
let avg = decimal_averager
90+
.avg(sum, self.sum_accumulator.distinct_count() as i128)?;
91+
Ok(ScalarValue::Decimal128(
92+
Some(avg),
93+
self.target_precision,
94+
self.target_scale,
95+
))
96+
}
97+
ScalarValue::Decimal256(Some(sum), _, _) => {
98+
let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
99+
self.sum_scale,
100+
self.target_precision,
101+
self.target_scale,
102+
)?;
103+
// `distinct_count` returns `u64`, but `avg` expects `i256`
104+
// first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow
105+
let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128;
106+
let count: i256 = i256::from_i128(distinct_cnt);
107+
let avg = decimal_averager.avg(sum, count)?;
108+
Ok(ScalarValue::Decimal256(
109+
Some(avg),
110+
self.target_precision,
111+
self.target_scale,
112+
))
113+
}
114+
115+
_ => unreachable!("Unsupported decimal type: {:?}", sum_scalar),
116+
}
117+
}
118+
119+
fn size(&self) -> usize {
120+
let fixed_size = size_of_val(self);
121+
122+
// Account for the size of the sum_accumulator with its contained values
123+
fixed_size + self.sum_accumulator.size()
124+
}
125+
}
126+
127+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Decimal128Array, Decimal256Array};
    use std::sync::Arc;

    /// Distinct inputs 100, 125, 175, 200 and 300 (200 is repeated, NULLs are
    /// present) must average to 180 at the target scale of 8.
    #[test]
    fn test_decimal128_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (10_u8, 4_i8);

        // Raw values carry four fractional digits.
        let mut values: Vec<Option<i128>> = [
            100_0000_i128,
            125_0000,
            175_0000,
            200_0000,
            200_0000,
            300_0000,
        ]
        .into_iter()
        .map(Some)
        .collect();
        values.extend([None, None]);

        let input = Decimal128Array::from(values)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
                input_scale,
                14,
                8,
            );
        acc.update_batch(&[Arc::new(input)])?;

        let expected = ScalarValue::Decimal128(Some(180_00000000), 14, 8);
        assert_eq!(acc.evaluate()?, expected);

        Ok(())
    }

    /// Same scenario as the Decimal128 test, on Decimal256 with input scale 2
    /// and target scale 6.
    #[test]
    fn test_decimal256_distinct_avg_accumulator() -> Result<()> {
        let (input_precision, input_scale) = (50_u8, 2_i8);

        // Raw values carry two fractional digits.
        let mut values: Vec<Option<i256>> =
            [10_000_i128, 12_500, 17_500, 20_000, 20_000, 30_000]
                .into_iter()
                .map(|raw| Some(i256::from_i128(raw)))
                .collect();
        values.extend([None, None]);

        let input = Decimal256Array::from(values)
            .with_precision_and_scale(input_precision, input_scale)?;

        let mut acc =
            DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
                input_scale,
                54,
                6,
            );
        acc.update_batch(&[Arc::new(input)])?;

        let expected =
            ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6);
        assert_eq!(acc.evaluate()?, expected);

        Ok(())
    }
}

datafusion/functions-aggregate/src/average.rs

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use arrow::datatypes::{
2727
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
2828
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
2929
DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
30+
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
3031
};
3132
use datafusion_common::{
3233
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
@@ -40,7 +41,9 @@ use datafusion_expr::{
4041
ReversedUDAF, Signature,
4142
};
4243

43-
use datafusion_functions_aggregate_common::aggregate::avg_distinct::Float64DistinctAvgAccumulator;
44+
use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
45+
DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
46+
};
4447
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
4548
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
4649
filtered_null_mask, set_nulls,
@@ -120,13 +123,36 @@ impl AggregateUDFImpl for Avg {
120123

121124
// instantiate specialized accumulator based on the type
122125
if acc_args.is_distinct {
123-
match &data_type {
126+
match (&data_type, acc_args.return_type()) {
124127
// Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
125-
Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
126-
_ => exec_err!("AVG(DISTINCT) for {} not supported", data_type),
128+
(Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
129+
130+
(
131+
Decimal128(_, scale),
132+
Decimal128(target_precision, target_scale),
133+
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal128Type>::with_decimal_params(
134+
*scale,
135+
*target_precision,
136+
*target_scale,
137+
))),
138+
139+
(
140+
Decimal256(_, scale),
141+
Decimal256(target_precision, target_scale),
142+
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal256Type>::with_decimal_params(
143+
*scale,
144+
*target_precision,
145+
*target_scale,
146+
))),
147+
148+
(dt, return_type) => exec_err!(
149+
"AVG(DISTINCT) for ({} --> {}) not supported",
150+
dt,
151+
return_type
152+
),
127153
}
128154
} else {
129-
match (&data_type, acc_args.return_field.data_type()) {
155+
match (&data_type, acc_args.return_type()) {
130156
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
131157
(
132158
Decimal128(sum_precision, sum_scale),
@@ -161,22 +187,31 @@ impl AggregateUDFImpl for Avg {
161187
}))
162188
}
163189

164-
_ => exec_err!(
165-
"AvgAccumulator for ({} --> {})",
166-
&data_type,
167-
acc_args.return_field.data_type()
168-
),
190+
(dt, return_type) => {
191+
exec_err!("AvgAccumulator for ({} --> {})", dt, return_type)
192+
}
169193
}
170194
}
171195
}
172196

173197
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
174198
if args.is_distinct {
175-
// Copied from datafusion_functions_aggregate::sum::Sum::state_fields
199+
// Decimal accumulator actually uses a different precision during accumulation,
200+
// see DecimalDistinctAvgAccumulator::with_decimal_params
201+
let dt = match args.input_fields[0].data_type() {
202+
DataType::Decimal128(_, scale) => {
203+
DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
204+
}
205+
DataType::Decimal256(_, scale) => {
206+
DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale)
207+
}
208+
_ => args.return_type().clone(),
209+
};
210+
// Similar to datafusion_functions_aggregate::sum::Sum::state_fields
176211
// since the accumulator uses DistinctSumAccumulator internally.
177212
Ok(vec![Field::new_list(
178213
format_state_name(args.name, "avg distinct"),
179-
Field::new_list_field(args.return_type().clone(), true),
214+
Field::new_list_field(dt, true),
180215
false,
181216
)
182217
.into()])

0 commit comments

Comments
 (0)