Skip to content

Commit e8de52d

Browse files
committed
Fix bugs and tests; handle Decimal32/Decimal64 in more aggregate functions and in the schema
1 parent 8a4d717 commit e8de52d

File tree

6 files changed

+158
-17
lines changed

6 files changed

+158
-17
lines changed

datafusion/common/src/dfschema.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,14 @@ impl DFSchema {
798798
.zip(iter2)
799799
.all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2))
800800
}
801+
(
802+
DataType::Decimal32(_l_precision, _l_scale),
803+
DataType::Decimal32(_r_precision, _r_scale),
804+
) => true,
805+
(
806+
DataType::Decimal64(_l_precision, _l_scale),
807+
DataType::Decimal64(_r_precision, _r_scale),
808+
) => true,
801809
(
802810
DataType::Decimal128(_l_precision, _l_scale),
803811
DataType::Decimal128(_r_precision, _r_scale),
@@ -1056,6 +1064,12 @@ fn format_simple_data_type(data_type: &DataType) -> String {
10561064
DataType::Dictionary(_, value_type) => {
10571065
format_simple_data_type(value_type.as_ref())
10581066
}
1067+
DataType::Decimal32(precision, scale) => {
1068+
format!("decimal32({precision}, {scale})")
1069+
}
1070+
DataType::Decimal64(precision, scale) => {
1071+
format!("decimal64({precision}, {scale})")
1072+
}
10591073
DataType::Decimal128(precision, scale) => {
10601074
format!("decimal128({precision}, {scale})")
10611075
}
@@ -1794,6 +1808,27 @@ mod tests {
17941808
&DataType::Int16
17951809
));
17961810

1811+
// Succeeds even when decimal precision and scale differ
1812+
assert!(DFSchema::datatype_is_semantically_equal(
1813+
&DataType::Decimal32(1, 2),
1814+
&DataType::Decimal32(2, 1),
1815+
));
1816+
1817+
assert!(DFSchema::datatype_is_semantically_equal(
1818+
&DataType::Decimal64(1, 2),
1819+
&DataType::Decimal64(2, 1),
1820+
));
1821+
1822+
assert!(DFSchema::datatype_is_semantically_equal(
1823+
&DataType::Decimal128(1, 2),
1824+
&DataType::Decimal128(2, 1),
1825+
));
1826+
1827+
assert!(DFSchema::datatype_is_semantically_equal(
1828+
&DataType::Decimal256(1, 2),
1829+
&DataType::Decimal256(2, 1),
1830+
));
1831+
17971832
// Test lists
17981833

17991834
// Succeeds if both have the same element type, disregards names and nullability
@@ -2377,6 +2412,8 @@ mod tests {
23772412
),
23782413
false,
23792414
),
2415+
Field::new("decimal32", DataType::Decimal32(9, 4), true),
2416+
Field::new("decimal64", DataType::Decimal64(9, 4), true),
23802417
Field::new("decimal128", DataType::Decimal128(18, 4), true),
23812418
Field::new("decimal256", DataType::Decimal256(38, 10), false),
23822419
Field::new("date32", DataType::Date32, true),
@@ -2408,6 +2445,8 @@ mod tests {
24082445
|-- fixed_size_binary: fixed_size_binary (nullable = true)
24092446
|-- fixed_size_list: fixed size list (nullable = false)
24102447
| |-- item: int32 (nullable = true)
2448+
|-- decimal32: decimal32(9, 4) (nullable = true)
2449+
|-- decimal64: decimal64(9, 4) (nullable = true)
24112450
|-- decimal128: decimal128(18, 4) (nullable = true)
24122451
|-- decimal256: decimal256(38, 10) (nullable = false)
24132452
|-- date32: date32 (nullable = true)

datafusion/common/src/scalar/mod.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,12 @@ impl ScalarValue {
10871087
DataType::UInt16 => ScalarValue::UInt16(None),
10881088
DataType::UInt32 => ScalarValue::UInt32(None),
10891089
DataType::UInt64 => ScalarValue::UInt64(None),
1090+
DataType::Decimal32(precision, scale) => {
1091+
ScalarValue::Decimal32(None, *precision, *scale)
1092+
}
1093+
DataType::Decimal64(precision, scale) => {
1094+
ScalarValue::Decimal64(None, *precision, *scale)
1095+
}
10901096
DataType::Decimal128(precision, scale) => {
10911097
ScalarValue::Decimal128(None, *precision, *scale)
10921098
}
@@ -3179,6 +3185,24 @@ impl ScalarValue {
31793185
scale: i8,
31803186
) -> Result<ScalarValue> {
31813187
match array.data_type() {
3188+
DataType::Decimal32(_, _) => {
3189+
let array = as_decimal32_array(array)?;
3190+
if array.is_null(index) {
3191+
Ok(ScalarValue::Decimal32(None, precision, scale))
3192+
} else {
3193+
let value = array.value(index);
3194+
Ok(ScalarValue::Decimal32(Some(value), precision, scale))
3195+
}
3196+
}
3197+
DataType::Decimal64(_, _) => {
3198+
let array = as_decimal64_array(array)?;
3199+
if array.is_null(index) {
3200+
Ok(ScalarValue::Decimal64(None, precision, scale))
3201+
} else {
3202+
let value = array.value(index);
3203+
Ok(ScalarValue::Decimal64(Some(value), precision, scale))
3204+
}
3205+
}
31823206
DataType::Decimal128(_, _) => {
31833207
let array = as_decimal128_array(array)?;
31843208
if array.is_null(index) {
@@ -3197,7 +3221,9 @@ impl ScalarValue {
31973221
Ok(ScalarValue::Decimal256(Some(value), precision, scale))
31983222
}
31993223
}
3200-
_ => _internal_err!("Unsupported decimal type"),
3224+
other => {
3225+
unreachable!("Invalid type isn't decimal: {other:?}")
3226+
}
32013227
}
32023228
}
32033229

@@ -3311,6 +3337,16 @@ impl ScalarValue {
33113337

33123338
Ok(match array.data_type() {
33133339
DataType::Null => ScalarValue::Null,
3340+
DataType::Decimal32(precision, scale) => {
3341+
ScalarValue::get_decimal_value_from_array(
3342+
array, index, *precision, *scale,
3343+
)?
3344+
}
3345+
DataType::Decimal64(precision, scale) => {
3346+
ScalarValue::get_decimal_value_from_array(
3347+
array, index, *precision, *scale,
3348+
)?
3349+
}
33143350
DataType::Decimal128(precision, scale) => {
33153351
ScalarValue::get_decimal_value_from_array(
33163352
array, index, *precision, *scale,

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
2020
use arrow::array::{
2121
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array,
22-
Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray,
23-
DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray,
24-
FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int16Array,
25-
Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
26-
IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray,
27-
StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
28-
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
29-
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
30-
UInt64Array, UInt8Array,
22+
Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
23+
DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray,
24+
DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
25+
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
26+
IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
27+
LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
28+
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
29+
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
30+
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
3131
};
3232
use arrow::compute;
3333
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
@@ -144,6 +144,32 @@ macro_rules! min_max {
144144
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
145145
Ok(match ($VALUE, $DELTA) {
146146
(ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
147+
(
148+
lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss),
149+
rhs @ ScalarValue::Decimal32(rhsv, rhsp, rhss)
150+
) => {
151+
if lhsp.eq(rhsp) && lhss.eq(rhss) {
152+
typed_min_max!(lhsv, rhsv, Decimal32, $OP, lhsp, lhss)
153+
} else {
154+
return internal_err!(
155+
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
156+
(lhs, rhs)
157+
);
158+
}
159+
}
160+
(
161+
lhs @ ScalarValue::Decimal64(lhsv, lhsp, lhss),
162+
rhs @ ScalarValue::Decimal64(rhsv, rhsp, rhss)
163+
) => {
164+
if lhsp.eq(rhsp) && lhss.eq(rhss) {
165+
typed_min_max!(lhsv, rhsv, Decimal64, $OP, lhsp, lhss)
166+
} else {
167+
return internal_err!(
168+
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
169+
(lhs, rhs)
170+
);
171+
}
172+
}
147173
(
148174
lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
149175
rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
@@ -513,6 +539,26 @@ macro_rules! min_max_batch {
513539
($VALUES:expr, $OP:ident) => {{
514540
match $VALUES.data_type() {
515541
DataType::Null => ScalarValue::Null,
542+
DataType::Decimal32(precision, scale) => {
543+
typed_min_max_batch!(
544+
$VALUES,
545+
Decimal32Array,
546+
Decimal32,
547+
$OP,
548+
precision,
549+
scale
550+
)
551+
}
552+
DataType::Decimal64(precision, scale) => {
553+
typed_min_max_batch!(
554+
$VALUES,
555+
Decimal64Array,
556+
Decimal64,
557+
$OP,
558+
precision,
559+
scale
560+
)
561+
}
516562
DataType::Decimal128(precision, scale) => {
517563
typed_min_max_batch!(
518564
$VALUES,

datafusion/functions-aggregate/src/median.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ use arrow::{
3535

3636
use arrow::array::Array;
3737
use arrow::array::ArrowNativeTypeOp;
38-
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef};
38+
use arrow::datatypes::{
39+
ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef,
40+
};
3941

4042
use datafusion_common::{
4143
internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
@@ -166,6 +168,8 @@ impl AggregateUDFImpl for Median {
166168
DataType::Float16 => helper!(Float16Type, dt),
167169
DataType::Float32 => helper!(Float32Type, dt),
168170
DataType::Float64 => helper!(Float64Type, dt),
171+
DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
172+
DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
169173
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
170174
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
171175
_ => Err(DataFusionError::NotImplemented(format!(
@@ -205,6 +209,8 @@ impl AggregateUDFImpl for Median {
205209
DataType::Float16 => helper!(Float16Type, dt),
206210
DataType::Float32 => helper!(Float32Type, dt),
207211
DataType::Float64 => helper!(Float64Type, dt),
212+
DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
213+
DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
208214
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
209215
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
210216
_ => Err(DataFusionError::NotImplemented(format!(

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ mod min_max_struct;
2323

2424
use arrow::array::ArrayRef;
2525
use arrow::datatypes::{
26-
DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType,
27-
DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type,
28-
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
29-
UInt32Type, UInt64Type, UInt8Type,
26+
DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type,
27+
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
28+
DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type,
29+
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
3030
};
3131
use datafusion_common::stats::Precision;
3232
use datafusion_common::{exec_err, internal_err, ColumnStatistics, Result};
@@ -320,6 +320,12 @@ impl AggregateUDFImpl for Max {
320320
Duration(Nanosecond) => {
321321
primitive_max_accumulator!(data_type, i64, DurationNanosecondType)
322322
}
323+
Decimal32(_, _) => {
324+
primitive_max_accumulator!(data_type, i32, Decimal32Type)
325+
}
326+
Decimal64(_, _) => {
327+
primitive_max_accumulator!(data_type, i64, Decimal64Type)
328+
}
323329
Decimal128(_, _) => {
324330
primitive_max_accumulator!(data_type, i128, Decimal128Type)
325331
}
@@ -518,6 +524,8 @@ impl AggregateUDFImpl for Min {
518524
| Float16
519525
| Float32
520526
| Float64
527+
| Decimal32(_, _)
528+
| Decimal64(_, _)
521529
| Decimal128(_, _)
522530
| Decimal256(_, _)
523531
| Date32
@@ -599,6 +607,12 @@ impl AggregateUDFImpl for Min {
599607
Duration(Nanosecond) => {
600608
primitive_min_accumulator!(data_type, i64, DurationNanosecondType)
601609
}
610+
Decimal32(_, _) => {
611+
primitive_min_accumulator!(data_type, i32, Decimal32Type)
612+
}
613+
Decimal64(_, _) => {
614+
primitive_min_accumulator!(data_type, i64, Decimal64Type)
615+
}
602616
Decimal128(_, _) => {
603617
primitive_min_accumulator!(data_type, i128, Decimal128Type)
604618
}

datafusion/functions-aggregate/src/sum.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,13 @@ impl AggregateUDFImpl for Sum {
176176
// In Spark, the result type is DECIMAL(min(38, precision + 10), s); here the precision is capped at DECIMAL32_MAX_PRECISION instead of 38
177177
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
178178
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
179-
Ok(DataType::Decimal128(new_precision, *scale))
179+
Ok(DataType::Decimal32(new_precision, *scale))
180180
}
181181
DataType::Decimal64(precision, scale) => {
182182
// In Spark, the result type is DECIMAL(min(38, precision + 10), s); here the precision is capped at DECIMAL64_MAX_PRECISION instead of 38
183183
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
184184
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
185-
Ok(DataType::Decimal128(new_precision, *scale))
185+
Ok(DataType::Decimal64(new_precision, *scale))
186186
}
187187
DataType::Decimal128(precision, scale) => {
188188
// In Spark, the result type is DECIMAL(min(38, precision + 10), s)

0 commit comments

Comments
 (0)