Skip to content

Commit 554cafa

Browse files
Implement DataType::Float16 => Variant::Float (#8073)
# Which issue does this PR close? - Closes #8057 # Rationale for this change Adds Float16 conversion to the `cast_to_variant` kernel. # What changes are included in this PR? - a macro that simplifies converting array types that require a cast - conversion of `DataType::Float16` => `Variant::Float` # Are these changes tested? Yes, additional unit tests have been added. # Are there any user-facing changes? Yes, this adds a new type conversion to the kernel.
1 parent 0710ecc commit 554cafa

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

parquet-variant-compute/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ rust-version = { workspace = true }
3333
[dependencies]
3434
arrow = { workspace = true }
3535
arrow-schema = { workspace = true }
36+
half = { version = "2.1", default-features = false }
3637
parquet-variant = { workspace = true }
3738
parquet-variant-json = { workspace = true }
3839

@@ -49,4 +50,3 @@ arrow = { workspace = true, features = ["test_utils"] }
4950
[[bench]]
5051
name = "variant_kernels"
5152
harness = false
52-

parquet-variant-compute/src/cast_to_variant.rs

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
use crate::{VariantArray, VariantArrayBuilder};
1919
use arrow::array::{Array, AsArray};
2020
use arrow::datatypes::{
21-
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
22-
UInt64Type, UInt8Type,
21+
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
22+
UInt32Type, UInt64Type, UInt8Type,
2323
};
2424
use arrow_schema::{ArrowError, DataType};
25+
use half::f16;
2526
use parquet_variant::Variant;
2627

2728
/// Convert the input array of a specific primitive type to a `VariantArray`
@@ -39,6 +40,22 @@ macro_rules! primitive_conversion {
3940
}};
4041
}
4142

43+
/// Appends every element of a primitive array to a variant builder, passing
/// each non-null value through a caller-supplied conversion first.
///
/// Arguments, in order:
/// * `$t` - the primitive type handed to `as_primitive::<$t>()` on `$input`
/// * `$cast_fn` - callable applied to each non-null element; its result must
///   be accepted by `Variant::from`
/// * `$input` - the array being converted
/// * `$builder` - receives `append_null()` for null slots and
///   `append_variant(..)` for converted values, in input order
macro_rules! cast_conversion {
    ($t:ty, $cast_fn:expr, $input:expr, $builder:expr) => {{
        // The iterator yields `None` for null slots, so validity and value
        // are handled by a single match instead of separate index lookups.
        for element in $input.as_primitive::<$t>().iter() {
            match element {
                Some(raw) => {
                    $builder.append_variant(Variant::from($cast_fn(raw)));
                }
                None => {
                    $builder.append_null();
                }
            }
        }
    }};
}
58+
4259
/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you
4360
/// need to convert a specific data type
4461
///
@@ -92,6 +109,9 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
92109
DataType::UInt64 => {
93110
primitive_conversion!(UInt64Type, input, builder);
94111
}
112+
DataType::Float16 => {
113+
cast_conversion!(Float16Type, |v: f16| -> f32 { v.into() }, input, builder);
114+
}
95115
DataType::Float32 => {
96116
primitive_conversion!(Float32Type, input, builder);
97117
}
@@ -115,8 +135,8 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
115135
mod tests {
116136
use super::*;
117137
use arrow::array::{
118-
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
119-
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
138+
ArrayRef, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
139+
Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
120140
};
121141
use parquet_variant::{Variant, VariantDecimal16};
122142
use std::sync::Arc;
@@ -284,6 +304,28 @@ mod tests {
284304
)
285305
}
286306

307+
/// Checks the `Float16` conversion: each non-null `f16` input is expected as
/// a `Variant::Float` holding the same value widened to `f32`, and the null
/// slot is expected as `None`.
#[test]
fn test_cast_to_variant_float16() {
    // Fixture spans both finite extremes, a null, zero, and a signed pair.
    let half_values = Arc::new(Float16Array::from(vec![
        Some(f16::MIN),
        None,
        Some(f16::from_f32(-1.5)),
        Some(f16::from_f32(0.0)),
        Some(f16::from_f32(1.5)),
        Some(f16::MAX),
    ]));

    let expected_variants = vec![
        Some(Variant::Float(f16::MIN.into())),
        None,
        Some(Variant::Float(-1.5)),
        Some(Variant::Float(0.0)),
        Some(Variant::Float(1.5)),
        Some(Variant::Float(f16::MAX.into())),
    ];

    run_test(half_values, expected_variants)
}
328+
287329
#[test]
288330
fn test_cast_to_variant_float32() {
289331
run_test(

0 commit comments

Comments
 (0)