Skip to content

Commit 8a4d717

Browse files
committed
Support Decimal32/64 types
1 parent 13208e6 commit 8a4d717

File tree

12 files changed

+661
-87
lines changed

12 files changed

+661
-87
lines changed

datafusion/common/src/cast.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
2323
use crate::{downcast_value, Result};
2424
use arrow::array::{
25-
BinaryViewArray, DurationMicrosecondArray, DurationMillisecondArray,
26-
DurationNanosecondArray, DurationSecondArray, Float16Array, Int16Array, Int8Array,
27-
LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array,
25+
BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray,
26+
DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array,
27+
Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray,
28+
UInt16Array,
2829
};
2930
use arrow::{
3031
array::{
@@ -97,6 +98,16 @@ pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> {
9798
Ok(downcast_value!(array, UInt64Array))
9899
}
99100

101+
// Downcast Array to Decimal32Array
102+
pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> {
103+
Ok(downcast_value!(array, Decimal32Array))
104+
}
105+
106+
// Downcast Array to Decimal64Array
107+
pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> {
108+
Ok(downcast_value!(array, Decimal64Array))
109+
}
110+
100111
// Downcast Array to Decimal128Array
101112
pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> {
102113
Ok(downcast_value!(array, Decimal128Array))

datafusion/common/src/scalar/mod.rs

Lines changed: 338 additions & 36 deletions
Large diffs are not rendered by default.

datafusion/core/tests/fuzz_cases/record_batch_generator.rs

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@ use std::sync::Arc;
2020
use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch};
2121
use arrow::datatypes::{
2222
ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type,
23-
Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
24-
DurationNanosecondType, DurationSecondType, Field, Float32Type, Float64Type,
25-
Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
26-
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema,
27-
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
28-
TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
23+
Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType,
24+
DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field,
25+
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
26+
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
27+
Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
28+
Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
2929
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
3030
UInt8Type,
3131
};
3232
use arrow_schema::{
3333
DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION,
34-
DECIMAL256_MAX_SCALE,
34+
DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
35+
DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
3536
};
3637
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
3738
use rand::{rng, rngs::StdRng, Rng, SeedableRng};
@@ -104,6 +105,20 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec<ColumnDescr> {
104105
"duration_nanosecond",
105106
DataType::Duration(TimeUnit::Nanosecond),
106107
),
108+
ColumnDescr::new("decimal32", {
109+
let precision: u8 = rng.random_range(1..=DECIMAL32_MAX_PRECISION);
110+
let scale: i8 = rng.random_range(
111+
i8::MIN..=std::cmp::min(precision as i8, DECIMAL32_MAX_SCALE),
112+
);
113+
DataType::Decimal32(precision, scale)
114+
}),
115+
ColumnDescr::new("decimal64", {
116+
let precision: u8 = rng.random_range(1..=DECIMAL64_MAX_PRECISION);
117+
let scale: i8 = rng.random_range(
118+
i8::MIN..=std::cmp::min(precision as i8, DECIMAL64_MAX_SCALE),
119+
);
120+
DataType::Decimal64(precision, scale)
121+
}),
107122
ColumnDescr::new("decimal128", {
108123
let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION);
109124
let scale: i8 = rng.random_range(
@@ -682,6 +697,32 @@ impl RecordBatchGenerator {
682697
_ => unreachable!(),
683698
}
684699
}
700+
DataType::Decimal32(precision, scale) => {
701+
generate_decimal_array!(
702+
self,
703+
num_rows,
704+
max_num_distinct,
705+
null_pct,
706+
batch_gen_rng,
707+
array_gen_rng,
708+
precision,
709+
scale,
710+
Decimal32Type
711+
)
712+
}
713+
DataType::Decimal64(precision, scale) => {
714+
generate_decimal_array!(
715+
self,
716+
num_rows,
717+
max_num_distinct,
718+
null_pct,
719+
batch_gen_rng,
720+
array_gen_rng,
721+
precision,
722+
scale,
723+
Decimal64Type
724+
)
725+
}
685726
DataType::Decimal128(precision, scale) => {
686727
generate_decimal_array!(
687728
self,

datafusion/expr-common/src/type_coercion/aggregates.rs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
use crate::signature::TypeSignature;
1919
use arrow::datatypes::{
2020
DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
21-
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
21+
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
22+
DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
2223
};
2324

2425
use datafusion_common::{internal_err, plan_err, Result};
@@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
150151
DataType::Int64 => Ok(DataType::Int64),
151152
DataType::UInt64 => Ok(DataType::UInt64),
152153
DataType::Float64 => Ok(DataType::Float64),
154+
DataType::Decimal32(precision, scale) => {
155+
// In Spark, the result type is DECIMAL(min(38,precision+10), s); for Decimal32 the precision is capped at DECIMAL32_MAX_PRECISION instead of 38.
156+
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
157+
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
158+
Ok(DataType::Decimal32(new_precision, *scale))
159+
}
160+
DataType::Decimal64(precision, scale) => {
161+
// In Spark, the result type is DECIMAL(min(38,precision+10), s); for Decimal64 the precision is capped at DECIMAL64_MAX_PRECISION instead of 38.
162+
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
163+
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
164+
Ok(DataType::Decimal64(new_precision, *scale))
165+
}
153166
DataType::Decimal128(precision, scale) => {
154167
// In the spark, the result type is DECIMAL(min(38,precision+10), s)
155168
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -196,6 +209,20 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
196209
/// Function return type of an average
197210
pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> {
198211
match arg_type {
212+
DataType::Decimal32(precision, scale) => {
213+
// In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)); for Decimal32 the caps are DECIMAL32_MAX_PRECISION and DECIMAL32_MAX_SCALE instead of 38.
214+
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
215+
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
216+
let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
217+
Ok(DataType::Decimal32(new_precision, new_scale))
218+
}
219+
DataType::Decimal64(precision, scale) => {
220+
// In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)); for Decimal64 the caps are DECIMAL64_MAX_PRECISION and DECIMAL64_MAX_SCALE instead of 38.
221+
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
222+
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
223+
let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
224+
Ok(DataType::Decimal64(new_precision, new_scale))
225+
}
199226
DataType::Decimal128(precision, scale) => {
200227
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
201228
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
@@ -222,6 +249,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
222249
/// Internal sum type of an average
223250
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
224251
match arg_type {
252+
DataType::Decimal32(precision, scale) => {
253+
// In Spark, the sum type of avg is DECIMAL(min(38,precision+10), s); for Decimal32 the precision is capped at DECIMAL32_MAX_PRECISION instead of 38.
254+
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
255+
Ok(DataType::Decimal32(new_precision, *scale))
256+
}
257+
DataType::Decimal64(precision, scale) => {
258+
// In Spark, the sum type of avg is DECIMAL(min(38,precision+10), s); for Decimal64 the precision is capped at DECIMAL64_MAX_PRECISION instead of 38.
259+
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
260+
Ok(DataType::Decimal64(new_precision, *scale))
261+
}
225262
DataType::Decimal128(precision, scale) => {
226263
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
227264
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
@@ -249,7 +286,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
249286
_ => matches!(
250287
arg_type,
251288
arg_type if NUMERICS.contains(arg_type)
252-
|| matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
289+
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
253290
),
254291
}
255292
}
@@ -262,7 +299,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
262299
_ => matches!(
263300
arg_type,
264301
arg_type if NUMERICS.contains(arg_type)
265-
|| matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _))
302+
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
266303
),
267304
}
268305
}
@@ -297,6 +334,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
297334
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
298335
fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
299336
match &data_type {
337+
DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
338+
DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
300339
DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
301340
DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
302341
d if d.is_numeric() => Ok(DataType::Float64),

datafusion/expr/src/type_coercion/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
5151
| DataType::Float16
5252
| DataType::Float32
5353
| DataType::Float64
54+
| DataType::Decimal32(_, _)
55+
| DataType::Decimal64(_, _)
5456
| DataType::Decimal128(_, _)
5557
| DataType::Decimal256(_, _),
5658
)
@@ -89,5 +91,11 @@ pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool {
8991

9092
/// Determine whether the given data type `dt` is a `Decimal`.
9193
pub fn is_decimal(dt: &DataType) -> bool {
92-
matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
94+
matches!(
95+
dt,
96+
DataType::Decimal32(_, _)
97+
| DataType::Decimal64(_, _)
98+
| DataType::Decimal128(_, _)
99+
| DataType::Decimal256(_, _)
100+
)
93101
}

datafusion/functions-aggregate/src/average.rs

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ use arrow::array::{
2424

2525
use arrow::compute::sum;
2626
use arrow::datatypes::{
27-
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
28-
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
29-
DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
30-
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
27+
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
28+
Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType,
29+
DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit,
30+
UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
31+
DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
3132
};
3233
use datafusion_common::{
3334
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
@@ -127,6 +128,22 @@ impl AggregateUDFImpl for Avg {
127128
// Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation
128129
(Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())),
129130

131+
(
132+
Decimal32(_, scale),
133+
Decimal32(target_precision, target_scale),
134+
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
135+
*scale,
136+
*target_precision,
137+
*target_scale,
138+
))),
139+
(
140+
Decimal64(_, scale),
141+
Decimal64(target_precision, target_scale),
142+
) => Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
143+
*scale,
144+
*target_precision,
145+
*target_scale,
146+
))),
130147
(
131148
Decimal128(_, scale),
132149
Decimal128(target_precision, target_scale),
@@ -154,6 +171,28 @@ impl AggregateUDFImpl for Avg {
154171
} else {
155172
match (&data_type, acc_args.return_type()) {
156173
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
174+
(
175+
Decimal32(sum_precision, sum_scale),
176+
Decimal32(target_precision, target_scale),
177+
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
178+
sum: None,
179+
count: 0,
180+
sum_scale: *sum_scale,
181+
sum_precision: *sum_precision,
182+
target_precision: *target_precision,
183+
target_scale: *target_scale,
184+
})),
185+
(
186+
Decimal64(sum_precision, sum_scale),
187+
Decimal64(target_precision, target_scale),
188+
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
189+
sum: None,
190+
count: 0,
191+
sum_scale: *sum_scale,
192+
sum_precision: *sum_precision,
193+
target_precision: *target_precision,
194+
target_scale: *target_scale,
195+
})),
157196
(
158197
Decimal128(sum_precision, sum_scale),
159198
Decimal128(target_precision, target_scale),
@@ -199,6 +238,12 @@ impl AggregateUDFImpl for Avg {
199238
// Decimal accumulator actually uses a different precision during accumulation,
200239
// see DecimalDistinctAvgAccumulator::with_decimal_params
201240
let dt = match args.input_fields[0].data_type() {
241+
DataType::Decimal32(_, scale) => {
242+
DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
243+
}
244+
DataType::Decimal64(_, scale) => {
245+
DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
246+
}
202247
DataType::Decimal128(_, scale) => {
203248
DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
204249
}
@@ -237,7 +282,11 @@ impl AggregateUDFImpl for Avg {
237282
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
238283
matches!(
239284
args.return_field.data_type(),
240-
DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
285+
DataType::Float64
286+
| DataType::Decimal32(_, _)
287+
| DataType::Decimal64(_, _)
288+
| DataType::Decimal128(_, _)
289+
| DataType::Duration(_)
241290
) && !args.is_distinct
242291
}
243292

@@ -257,6 +306,44 @@ impl AggregateUDFImpl for Avg {
257306
|sum: f64, count: u64| Ok(sum / count as f64),
258307
)))
259308
}
309+
(
310+
Decimal32(_sum_precision, sum_scale),
311+
Decimal32(target_precision, target_scale),
312+
) => {
313+
let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
314+
*sum_scale,
315+
*target_precision,
316+
*target_scale,
317+
)?;
318+
319+
let avg_fn =
320+
move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
321+
322+
Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
323+
&data_type,
324+
args.return_field.data_type(),
325+
avg_fn,
326+
)))
327+
}
328+
(
329+
Decimal64(_sum_precision, sum_scale),
330+
Decimal64(target_precision, target_scale),
331+
) => {
332+
let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
333+
*sum_scale,
334+
*target_precision,
335+
*target_scale,
336+
)?;
337+
338+
let avg_fn =
339+
move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
340+
341+
Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
342+
&data_type,
343+
args.return_field.data_type(),
344+
avg_fn,
345+
)))
346+
}
260347
(
261348
Decimal128(_sum_precision, sum_scale),
262349
Decimal128(target_precision, target_scale),

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ use arrow::array::{
3030
use arrow::buffer::{BooleanBuffer, NullBuffer};
3131
use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
3232
use arrow::datatypes::{
33-
DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef,
34-
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
35-
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
36-
TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
37-
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
38-
UInt8Type,
33+
DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type,
34+
Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type,
35+
Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType,
36+
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
37+
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
38+
UInt32Type, UInt64Type, UInt8Type,
3939
};
4040
use datafusion_common::cast::as_boolean_array;
4141
use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
@@ -234,6 +234,8 @@ impl AggregateUDFImpl for FirstValue {
234234
DataType::Float32 => create_accumulator::<Float32Type>(args),
235235
DataType::Float64 => create_accumulator::<Float64Type>(args),
236236

237+
DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(args),
238+
DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(args),
237239
DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
238240
DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
239241

0 commit comments

Comments
 (0)