1818use crate :: { VariantArray , VariantArrayBuilder } ;
1919use arrow:: array:: { Array , AsArray } ;
2020use arrow:: datatypes:: {
21- Float32Type , Float64Type , Int16Type , Int32Type , Int64Type , Int8Type , UInt16Type , UInt32Type ,
22- UInt64Type , UInt8Type ,
21+ Float16Type , Float32Type , Float64Type , Int16Type , Int32Type , Int64Type , Int8Type , UInt16Type ,
22+ UInt32Type , UInt64Type , UInt8Type ,
2323} ;
2424use arrow_schema:: { ArrowError , DataType } ;
25+ use half:: f16;
2526use parquet_variant:: Variant ;
2627
2728/// Convert the input array of a specific primitive type to a `VariantArray`
@@ -39,6 +40,22 @@ macro_rules! primitive_conversion {
3940 } } ;
4041}
4142
43+ /// Convert the input array to a `VariantArray` row by row,
44+ /// transforming each element with `cast_fn`
45+ macro_rules! cast_conversion {
46+ ( $t: ty, $cast_fn: expr, $input: expr, $builder: expr) => { {
47+ let array = $input. as_primitive:: <$t>( ) ;
48+ for i in 0 ..array. len( ) {
49+ if array. is_null( i) {
50+ $builder. append_null( ) ;
51+ continue ;
52+ }
53+ let cast_value = $cast_fn( array. value( i) ) ;
54+ $builder. append_variant( Variant :: from( cast_value) ) ;
55+ }
56+ } } ;
57+ }
58+
4259/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you
4360/// need to convert a specific data type
4461///
@@ -92,6 +109,9 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
92109 DataType :: UInt64 => {
93110 primitive_conversion ! ( UInt64Type , input, builder) ;
94111 }
112+ DataType :: Float16 => {
113+ cast_conversion ! ( Float16Type , |v: f16| -> f32 { v. into( ) } , input, builder) ;
114+ }
95115 DataType :: Float32 => {
96116 primitive_conversion ! ( Float32Type , input, builder) ;
97117 }
@@ -115,8 +135,8 @@ pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
115135mod tests {
116136 use super :: * ;
117137 use arrow:: array:: {
118- ArrayRef , Float32Array , Float64Array , Int16Array , Int32Array , Int64Array , Int8Array ,
119- UInt16Array , UInt32Array , UInt64Array , UInt8Array ,
138+ ArrayRef , Float16Array , Float32Array , Float64Array , Int16Array , Int32Array , Int64Array ,
139+ Int8Array , UInt16Array , UInt32Array , UInt64Array , UInt8Array ,
120140 } ;
121141 use parquet_variant:: { Variant , VariantDecimal16 } ;
122142 use std:: sync:: Arc ;
@@ -284,6 +304,28 @@ mod tests {
284304 )
285305 }
286306
307+ #[ test]
308+ fn test_cast_to_variant_float16 ( ) {
309+ run_test (
310+ Arc :: new ( Float16Array :: from ( vec ! [
311+ Some ( f16:: MIN ) ,
312+ None ,
313+ Some ( f16:: from_f32( -1.5 ) ) ,
314+ Some ( f16:: from_f32( 0.0 ) ) ,
315+ Some ( f16:: from_f32( 1.5 ) ) ,
316+ Some ( f16:: MAX ) ,
317+ ] ) ) ,
318+ vec ! [
319+ Some ( Variant :: Float ( f16:: MIN . into( ) ) ) ,
320+ None ,
321+ Some ( Variant :: Float ( -1.5 ) ) ,
322+ Some ( Variant :: Float ( 0.0 ) ) ,
323+ Some ( Variant :: Float ( 1.5 ) ) ,
324+ Some ( Variant :: Float ( f16:: MAX . into( ) ) ) ,
325+ ] ,
326+ )
327+ }
328+
287329 #[ test]
288330 fn test_cast_to_variant_float32 ( ) {
289331 run_test (
0 commit comments