Skip to content

Commit 4801e40

Browse files
committed
feat: add float64 nan value counts support
1 parent 89ffcdc commit 4801e40

File tree

1 file changed

+107
-1
lines changed

1 file changed

+107
-1
lines changed

crates/iceberg/src/writer/file_writer/parquet_writer.rs

Lines changed: 107 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ use std::collections::HashMap;
2222
use std::sync::atomic::AtomicI64;
2323
use std::sync::Arc;
2424

25-
use arrow_array::Float32Array;
25+
use arrow_array::{Float32Array, Float64Array};
2626
use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
2727
use bytes::Bytes;
2828
use futures::future::BoxFuture;
@@ -418,6 +418,14 @@ impl FileWriter for ParquetWriter {
418418
.filter(|value| value.map_or(false, |v| v.is_nan()))
419419
.count() as u64
420420
}
421+
DataType::Float64 => {
422+
let float_array = col.as_any().downcast_ref::<Float64Array>().unwrap();
423+
424+
float_array
425+
.iter()
426+
.filter(|value| value.map_or(false, |v| v.is_nan()))
427+
.count() as u64
428+
}
421429
_ => 0,
422430
};
423431

@@ -682,6 +690,7 @@ mod tests {
682690
assert_eq!(visitor.name_to_id, expect);
683691
}
684692

693+
// TODO(feniljain): Remove nan value count test from here
685694
#[tokio::test]
686695
async fn test_parquet_writer() -> Result<()> {
687696
let temp_dir = TempDir::new().unwrap();
@@ -774,6 +783,103 @@ mod tests {
774783
Ok(())
775784
}
776785

786+
#[tokio::test]
787+
async fn test_parquet_writer_for_nan_value_counts() -> Result<()> {
788+
let temp_dir = TempDir::new().unwrap();
789+
let file_io = FileIOBuilder::new_fs_io().build().unwrap();
790+
let location_gen =
791+
MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string());
792+
let file_name_gen =
793+
DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);
794+
795+
// prepare data
796+
let schema = {
797+
let fields =
798+
vec![
799+
// TODO(feniljain):
800+
// Types:
801+
// [X] Primitive
802+
// [ ] Struct
803+
// [ ] List
804+
// [ ] Map
805+
arrow_schema::Field::new("col", arrow_schema::DataType::Float32, true)
806+
.with_metadata(HashMap::from([(
807+
PARQUET_FIELD_ID_META_KEY.to_string(),
808+
"0".to_string(),
809+
)])),
810+
arrow_schema::Field::new("col1", arrow_schema::DataType::Float64, true)
811+
.with_metadata(HashMap::from([(
812+
PARQUET_FIELD_ID_META_KEY.to_string(),
813+
"1".to_string(),
814+
)])),
815+
];
816+
Arc::new(arrow_schema::Schema::new(fields))
817+
};
818+
819+
let float_32_col = Arc::new(Float32Array::from_iter_values_with_nulls(
820+
[1.0_f32, f32::NAN, 2.0, 2.0].into_iter(),
821+
None,
822+
)) as ArrayRef;
823+
824+
let float_64_col = Arc::new(Float64Array::from_iter_values_with_nulls(
825+
[1.0_f64, f64::NAN, 2.0, 2.0].into_iter(),
826+
None,
827+
)) as ArrayRef;
828+
829+
let to_write =
830+
RecordBatch::try_new(schema.clone(), vec![float_32_col, float_64_col]).unwrap();
831+
832+
// write data
833+
let mut pw = ParquetWriterBuilder::new(
834+
WriterProperties::builder().build(),
835+
Arc::new(to_write.schema().as_ref().try_into().unwrap()),
836+
file_io.clone(),
837+
location_gen,
838+
file_name_gen,
839+
)
840+
.build()
841+
.await?;
842+
843+
pw.write(&to_write).await?;
844+
let res = pw.close().await?;
845+
assert_eq!(res.len(), 1);
846+
let data_file = res
847+
.into_iter()
848+
.next()
849+
.unwrap()
850+
// Set dummy content and partition fields so that the build succeeds.
851+
.content(crate::spec::DataContentType::Data)
852+
.partition(Struct::empty())
853+
.build()
854+
.unwrap();
855+
856+
// check data file
857+
assert_eq!(data_file.record_count(), 4);
858+
assert_eq!(*data_file.value_counts(), HashMap::from([(0, 4), (1, 4)]));
859+
assert_eq!(
860+
*data_file.lower_bounds(),
861+
HashMap::from([(0, Datum::float(1.0)), (1, Datum::double(1.0))])
862+
);
863+
assert_eq!(
864+
*data_file.upper_bounds(),
865+
HashMap::from([(0, Datum::float(2.0)), (1, Datum::double(2.0))])
866+
);
867+
assert_eq!(
868+
*data_file.null_value_counts(),
869+
HashMap::from([(0, 0), (1, 0)])
870+
);
871+
assert_eq!(
872+
*data_file.nan_value_counts(),
873+
HashMap::from([(0, 1), (1, 1)])
874+
);
875+
876+
// check the written file
877+
let expect_batch = concat_batches(&schema, vec![&to_write]).unwrap();
878+
check_parquet_data_file(&file_io, &data_file, &expect_batch).await;
879+
880+
Ok(())
881+
}
882+
777883
#[tokio::test]
778884
async fn test_parquet_writer_with_complex_schema() -> Result<()> {
779885
let temp_dir = TempDir::new().unwrap();

0 commit comments

Comments
 (0)