diff --git a/Cargo.lock b/Cargo.lock index 14d7cf78e..e9a5842bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2950,6 +2950,7 @@ dependencies = [ "arrow-ipc", "arrow-json", "arrow-schema", + "arrow-select", "async-trait", "autometrics", "base64 0.21.0", diff --git a/server/Cargo.toml b/server/Cargo.toml index 569cda3d4..ce02794a4 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -19,6 +19,7 @@ arrow-schema = { version = "36.0.0", features = ["serde"] } arrow-array = { version = "36.0.0" } arrow-json = "36.0.0" arrow-ipc = "36.0.0" +arrow-select = "36.0.0" async-trait = "0.1" base64 = "0.21" bytes = "1.4" diff --git a/server/src/alerts/mod.rs b/server/src/alerts/mod.rs index c0c1c65a3..8dc7d53b9 100644 --- a/server/src/alerts/mod.rs +++ b/server/src/alerts/mod.rs @@ -30,6 +30,7 @@ pub mod rule; pub mod target; use crate::metrics::ALERTS_STATES; +use crate::utils::arrow::get_field; use crate::utils::uid; use crate::CONFIG; use crate::{storage, utils}; @@ -135,7 +136,7 @@ impl Message { // checks if message (with a column name) is valid (i.e. the column name is present in the schema) pub fn valid(&self, schema: &Schema, column: Option<&str>) -> bool { if let Some(col) = column { - return schema.field_with_name(col).is_ok(); + return get_field(schema, col).is_some(); } true } diff --git a/server/src/event.rs b/server/src/event.rs index dcc061c7a..152134d0e 100644 --- a/server/src/event.rs +++ b/server/src/event.rs @@ -22,9 +22,8 @@ mod writer; use arrow_array::RecordBatch; use arrow_schema::{Field, Schema}; +use itertools::Itertools; -use std::collections::HashMap; -use std::ops::DerefMut; use std::sync::Arc; use crate::metadata; @@ -42,6 +41,7 @@ pub struct Event { pub rb: RecordBatch, pub origin_format: &'static str, pub origin_size: u64, + pub is_first_event: bool, } // Events holds the schema related to a each event for a single log stream @@ -50,7 +50,7 @@ impl Event { let key = get_schema_key(&self.rb.schema().fields); let num_rows = self.rb.num_rows() as u64; - if self.is_first_event(metadata::STREAM_INFO.schema(&self.stream_name)?.as_ref()) { + if self.is_first_event { commit_schema(&self.stream_name, self.rb.schema())?; } @@ -73,25 +73,6 @@ impl Event { Ok(()) } - fn is_first_event(&self, stream_schema: &Schema) -> bool { - let mut stream_fields = stream_schema.fields().iter(); - let event_schema = self.rb.schema(); - let event_fields = event_schema.fields(); - - for field in event_fields { - loop { - let Some(stream_field) = stream_fields.next() else { return true }; - if stream_field.name() == field.name() { - break; - } else { - continue; - } - } - } - - false - } - // event process all events after the 1st event. Concatenates record batches // and puts them in memory store for each event. 
fn process_event( @@ -104,10 +85,10 @@ impl Event { } } -pub fn get_schema_key(fields: &Vec) -> String { +pub fn get_schema_key(fields: &[Field]) -> String { // Fields must be sorted let mut hasher = xxhash_rust::xxh3::Xxh3::new(); - for field in fields { + for field in fields.iter().sorted_by_key(|v| v.name()) { hasher.update(field.name().as_bytes()) } let hash = hasher.digest(); @@ -117,36 +98,17 @@ pub fn get_schema_key(fields: &Vec) -> String { pub fn commit_schema(stream_name: &str, schema: Arc) -> Result<(), EventError> { let mut stream_metadata = metadata::STREAM_INFO.write().expect("lock poisoned"); - let mut schema = Schema::try_merge(vec![ - schema.as_ref().clone(), - stream_metadata.get_unchecked(stream_name).as_ref().clone(), - ]) - .unwrap(); - schema.fields.sort_by(|a, b| a.name().cmp(b.name())); - - stream_metadata.set_unchecked(stream_name, Arc::new(schema)); + let map = &mut stream_metadata + .get_mut(stream_name) + .expect("map has entry for this stream name") + .schema; + let current_schema = Schema::new(map.values().cloned().collect()); + let schema = Schema::try_merge(vec![current_schema, schema.as_ref().clone()])?; + map.clear(); + map.extend(schema.fields.into_iter().map(|f| (f.name().clone(), f))); Ok(()) } -trait UncheckedOp: DerefMut> { - fn get_unchecked(&self, stream_name: &str) -> Arc { - let schema = &self - .get(stream_name) - .expect("map has entry for this stream name") - .schema; - - Arc::clone(schema) - } - - fn set_unchecked(&mut self, stream_name: &str, schema: Arc) { - self.get_mut(stream_name) - .expect("map has entry for this stream name") - .schema = schema - } -} - -impl>> UncheckedOp for T {} - pub mod error { use arrow_schema::ArrowError; @@ -167,57 +129,3 @@ pub mod error { ObjectStorage(#[from] ObjectStorageError), } } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow_array::RecordBatch; - use arrow_schema::{DataType, Field, Schema}; - - use super::Event; - - fn test_rb(fields: Vec) -> RecordBatch { - RecordBatch::new_empty(Arc::new(Schema::new(fields))) - } - - fn test_event(fields: Vec) -> Event { - Event { - stream_name: "".to_string(), - rb: test_rb(fields), - origin_format: "none", - origin_size: 0, - } - } - - #[test] - fn new_field_is_new_event() { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - ]); - - let new_event = test_event(vec![ - Field::new("a", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ]); - - assert!(new_event.is_first_event(&schema)); - } - - #[test] - fn same_field_not_is_new_event() { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ]); - - let new_event = test_event(vec![ - Field::new("a", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ]); - - assert!(!new_event.is_first_event(&schema)); - } -} diff --git a/server/src/event/format.rs b/server/src/event/format.rs index b0554d4f9..e337f991e 100644 --- a/server/src/event/format.rs +++ b/server/src/event/format.rs @@ -17,14 +17,14 @@ * */ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use anyhow::{anyhow, Error as AnyError}; use arrow_array::{RecordBatch, StringArray, TimestampMillisecondArray}; use arrow_schema::{DataType, Field, Schema, TimeUnit}; use chrono::Utc; -use crate::utils; +use crate::utils::{self, arrow::get_field}; use super::{DEFAULT_METADATA_KEY, DEFAULT_TAGS_KEY, DEFAULT_TIMESTAMP_KEY}; @@ -33,56 
+33,60 @@ pub mod json;
 
 type Tags = String;
 type Metadata = String;
 
+// Global trait for event formats
+// This trait is implemented by all the event formats
 pub trait EventFormat: Sized {
     type Data;
 
-    fn to_data(self, schema: &Schema) -> Result<(Self::Data, Schema, Tags, Metadata), AnyError>;
+    fn to_data(
+        self,
+        schema: &HashMap<String, Field>,
+    ) -> Result<(Self::Data, Schema, bool, Tags, Metadata), AnyError>;
     fn decode(data: Self::Data, schema: Arc<Schema>) -> Result<RecordBatch, AnyError>;
-    fn into_recordbatch(self, schema: &Schema) -> Result<RecordBatch, AnyError> {
-        let (data, mut schema, tags, metadata) = self.to_data(schema)?;
-
-        match tags_index(&schema) {
-            Ok(_) => return Err(anyhow!("field {} is a reserved field", DEFAULT_TAGS_KEY)),
-            Err(index) => {
-                schema
-                    .fields
-                    .insert(index, Field::new(DEFAULT_TAGS_KEY, DataType::Utf8, true));
-            }
+    fn into_recordbatch(
+        self,
+        schema: &HashMap<String, Field>,
+    ) -> Result<(RecordBatch, bool), AnyError> {
+        let (data, mut schema, is_first, tags, metadata) = self.to_data(schema)?;
+
+        if get_field(&schema, DEFAULT_TAGS_KEY).is_some() {
+            return Err(anyhow!("field {} is a reserved field", DEFAULT_TAGS_KEY));
         };
-        match metadata_index(&schema) {
-            Ok(_) => {
-                return Err(anyhow!(
-                    "field {} is a reserved field",
-                    DEFAULT_METADATA_KEY
-                ))
-            }
-            Err(index) => {
-                schema.fields.insert(
-                    index,
-                    Field::new(DEFAULT_METADATA_KEY, DataType::Utf8, true),
-                );
-            }
+        if get_field(&schema, DEFAULT_METADATA_KEY).is_some() {
+            return Err(anyhow!(
+                "field {} is a reserved field",
+                DEFAULT_METADATA_KEY
+            ));
         };
-        match timestamp_index(&schema) {
-            Ok(_) => {
-                return Err(anyhow!(
-                    "field {} is a reserved field",
-                    DEFAULT_TIMESTAMP_KEY
-                ))
-            }
-            Err(index) => {
-                schema.fields.insert(
-                    index,
-                    Field::new(
-                        DEFAULT_TIMESTAMP_KEY,
-                        DataType::Timestamp(TimeUnit::Millisecond, None),
-                        true,
-                    ),
-                );
-            }
+        if get_field(&schema, DEFAULT_TIMESTAMP_KEY).is_some() {
+            return Err(anyhow!(
+                "field {} is a reserved field",
+                DEFAULT_TIMESTAMP_KEY
+            ));
         };
+        // add the p_timestamp field to the event schema at the 0th index
+        schema.fields.insert(
+            0,
+            Field::new(
+                DEFAULT_TIMESTAMP_KEY,
+                DataType::Timestamp(TimeUnit::Millisecond, None),
+                true,
+            ),
+        );
+
+        // p_tags and p_metadata are added to the end of the schema
+        let tags_index = schema.fields.len();
+        let metadata_index = tags_index + 1;
+        schema
+            .fields
+            .push(Field::new(DEFAULT_TAGS_KEY, DataType::Utf8, true));
+        schema
+            .fields
+            .push(Field::new(DEFAULT_METADATA_KEY, DataType::Utf8, true));
+
+        // prepare the record batch and new fields to be added
         let schema_ref = Arc::new(schema);
         let rb = Self::decode(data, Arc::clone(&schema_ref))?;
         let tags_arr = StringArray::from_iter_values(std::iter::repeat(&tags).take(rb.num_rows()));
@@ -90,14 +94,11 @@ pub trait EventFormat: Sized {
             StringArray::from_iter_values(std::iter::repeat(&metadata).take(rb.num_rows()));
         let timestamp_array = get_timestamp_array(rb.num_rows());
 
+        // modify the record batch to add fields at their respective indexes
         let rb = utils::arrow::replace_columns(
             Arc::clone(&schema_ref),
             rb,
-            &[
-                timestamp_index(&schema_ref).expect("timestamp field exists"),
-                tags_index(&schema_ref).expect("tags field exists"),
-                metadata_index(&schema_ref).expect("metadata field exists"),
-            ],
+            &[0, tags_index, metadata_index],
             &[
                 Arc::new(timestamp_array),
                 Arc::new(tags_arr),
@@ -105,28 +106,10 @@ pub trait EventFormat: Sized {
             ],
         );
 
-        Ok(rb)
+        Ok((rb, is_first))
     }
 }
 
-fn tags_index(schema: &Schema) -> Result<usize, usize> {
-    schema
-        .fields
-        .binary_search_by_key(&DEFAULT_TAGS_KEY, |field| field.name())
-}
-
-fn metadata_index(schema: &Schema) -> Result<usize, usize> {
-    schema
-        .fields
-        .binary_search_by_key(&DEFAULT_METADATA_KEY, |field| field.name())
-}
-
-fn timestamp_index(schema: &Schema) -> Result<usize, usize> {
-    schema
-        .fields
-        .binary_search_by_key(&DEFAULT_TIMESTAMP_KEY, |field| field.name())
-}
-
 fn get_timestamp_array(size: usize) -> TimestampMillisecondArray {
     let time = Utc::now();
     TimestampMillisecondArray::from_value(time.timestamp_millis(), size)
diff --git a/server/src/event/format/json.rs b/server/src/event/format/json.rs
index af9c1c874..26b45a0ad 100644
--- a/server/src/event/format/json.rs
+++ b/server/src/event/format/json.rs
@@ -25,10 +25,10 @@ use arrow_json::reader::{infer_json_schema_from_iterator, Decoder, DecoderOptions};
 use arrow_schema::{DataType, Field, Schema};
 use datafusion::arrow::util::bit_util::round_upto_multiple_of_64;
 use serde_json::Value;
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
 use super::EventFormat;
-use crate::utils::json::flatten_json_body;
+use crate::utils::{arrow::get_field, json::flatten_json_body};
 
 pub struct Event {
     pub data: Value,
@@ -39,36 +39,41 @@ impl EventFormat for Event {
     type Data = Vec<Value>;
 
+    // convert the incoming json to a vector of json values
+    // also extract the arrow schema, tags and metadata from the incoming json
     fn to_data(
         self,
-        schema: &Schema,
-    ) -> Result<(Self::Data, Schema, String, String), anyhow::Error> {
+        schema: &HashMap<String, Field>,
+    ) -> Result<(Self::Data, Schema, bool, String, String), anyhow::Error> {
         let data = flatten_json_body(self.data)?;
         let stream_schema = schema;
 
+        // incoming event may be a single json or a json array
+        // but Data (type defined above) is a vector of json values
+        // hence we need to convert the incoming event to a vector of json values
         let value_arr = match data {
             Value::Array(arr) => arr,
             value @ Value::Object(_) => vec![value],
             _ => unreachable!("flatten would have failed beforehand"),
         };
 
+        // collect all the keys from all the json objects in the request body
         let fields =
             collect_keys(value_arr.iter()).expect("fields can be collected from array of objects");
 
-        let schema = match derive_sub_schema(stream_schema.clone(), fields) {
+        let mut is_first = false;
+        let schema = match derive_arrow_schema(stream_schema, fields) {
             Ok(schema) => schema,
             Err(_) => match infer_json_schema_from_iterator(value_arr.iter().map(Ok)) {
-                Ok(mut infer_schema) => {
-                    infer_schema
-                        .fields
-                        .sort_by(|field1, field2| Ord::cmp(field1.name(), field2.name()));
-
-                    if let Err(err) =
-                        Schema::try_merge(vec![stream_schema.clone(), infer_schema.clone()])
-                    {
+                Ok(infer_schema) => {
+                    if let Err(err) = Schema::try_merge(vec![
+                        Schema::new(stream_schema.values().cloned().collect()),
+                        infer_schema.clone(),
+                    ]) {
                         return Err(anyhow!("Could not merge schema of this event with that of the existing stream. {:?}", err));
                     }
+                    is_first = true;
                     infer_schema
                 }
                 Err(err) => {
@@ -89,9 +94,10 @@ impl EventFormat for Event {
             ));
         }
 
-        Ok((value_arr, schema, self.tags, self.metadata))
+        Ok((value_arr, schema, is_first, self.tags, self.metadata))
     }
 
+    // Convert the Data type (defined above) to arrow record batch
     fn decode(data: Self::Data, schema: Arc<Schema>) -> Result<RecordBatch, anyhow::Error> {
         let array_capacity = round_upto_multiple_of_64(data.len());
         let value_iter: &mut (dyn Iterator<Item = Value>) = &mut data.into_iter();
@@ -108,50 +114,27 @@ impl EventFormat for Event {
     }
 }
 
-// invariants for this to work.
-// All fields in existing schema and fields in event are sorted my name lexographically
-fn derive_sub_schema(schema: arrow_schema::Schema, fields: Vec<&str>) -> Result<Schema, ()> {
-    let fields = derive_subset(schema.fields, fields)?;
-    Ok(Schema::new(fields))
-}
-
-fn derive_subset(superset: Vec<Field>, subset: Vec<&str>) -> Result<Vec<Field>, ()> {
-    let mut superset_idx = 0;
-    let mut subset_idx = 0;
-    let mut subset_schema = Vec::with_capacity(subset.len());
-
-    while superset_idx < superset.len() && subset_idx < subset.len() {
-        let field = superset[superset_idx].clone();
-        let key = subset[subset_idx];
-        if field.name() == key {
-            subset_schema.push(field);
-            superset_idx += 1;
-            subset_idx += 1;
-        } else if field.name().as_str() < key {
-            superset_idx += 1;
-        } else {
-            return Err(());
-        }
+// Returns arrow schema with the fields that are present in the request body
+// This schema is an input to convert the request body to arrow record batch
+fn derive_arrow_schema(schema: &HashMap<String, Field>, fields: Vec<&str>) -> Result<Schema, ()> {
+    let mut res = Vec::with_capacity(fields.len());
+    let fields = fields.into_iter().map(|field_name| schema.get(field_name));
+    for field in fields {
+        let Some(field) = field else { return Err(()) };
+        res.push(field.clone())
     }
-
-    // error if subset is not exhausted
-    if subset_idx < subset.len() {
-        return Err(());
-    }
-
-    Ok(subset_schema)
+    Ok(Schema::new(res))
 }
 
-// Must be in sorted order
 fn collect_keys<'a>(values: impl Iterator<Item = &'a Value>) -> Result<Vec<&'a str>, ()> {
-    let mut sorted_keys = Vec::new();
+    let mut keys = Vec::new();
     for value in values {
         if let Some(obj) = value.as_object() {
             for key in obj.keys() {
-                match sorted_keys.binary_search(&key.as_str()) {
+                match keys.binary_search(&key.as_str()) {
                     Ok(_) => (),
                     Err(pos) => {
-                        sorted_keys.insert(pos, key.as_str());
+                        keys.insert(pos, key.as_str());
                     }
                 }
             }
@@ -159,7 +142,7 @@ fn collect_keys<'a>(values: impl Iterator<Item = &'a Value>) -> Result<Vec<&'a str>, ()> {
         }
     }
-    Ok(sorted_keys)
+    Ok(keys)
 }
 
 fn fields_mismatch(schema: &Schema, body: &Value) -> bool {
@@ -167,8 +150,7 @@ fn fields_mismatch(schema: &Schema, body: &Value) -> bool {
         if val.is_null() {
             continue;
         }
-
-        let Ok(field) = schema.field_with_name(name) else { return true };
+        let Some(field) = get_field(schema, name) else { return true };
         if !valid_type(field.data_type(), val) {
             return true;
         }
@@ -187,6 +169,9 @@ fn valid_type(data_type: &DataType, value: &Value) -> bool {
         let data_type = field.data_type();
         if let Value::Array(arr) = value {
             for elem in arr {
+                if elem.is_null() {
+                    continue;
+                }
                 if !valid_type(data_type, elem) {
                     return false;
                 }
@@ -202,6 +187,9 @@ fn valid_type(data_type: &DataType, value: &Value) -> bool {
             .map(|idx| &fields[idx]);
 
         if let Some(field) = field {
+            if value.is_null() {
+                continue;
+            }
             if !valid_type(field.data_type(), value) {
                 return false;
             }
diff --git a/server/src/event/writer.rs b/server/src/event/writer.rs
index 5dca2bbc9..caf5a9246 100644
--- a/server/src/event/writer.rs
+++ b/server/src/event/writer.rs
@@ -19,7 +19,6 @@
 
 mod file_writer;
 mod mem_writer;
-mod mutable;
 
 use std::{
     collections::HashMap,
@@ -56,11 +55,11 @@ impl StreamWriter {
     ) -> Result<(), StreamWriterError> {
         match self {
             StreamWriter::Mem(mem) => {
-                mem.push(rb);
+                mem.push(schema_key, rb);
             }
             StreamWriter::Disk(disk, mem) => {
                 disk.push(stream_name, schema_key, &rb)?;
-                mem.push(rb);
+                mem.push(schema_key, rb);
             }
         }
         Ok(())
diff --git a/server/src/event/writer/mem_writer.rs b/server/src/event/writer/mem_writer.rs
index 7dc7d4203..4bd159f46 100644
--- a/server/src/event/writer/mem_writer.rs
+++ b/server/src/event/writer/mem_writer.rs
@@ -17,38 +17,44 @@
  *
  */
 
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
-use arrow_array::RecordBatch;
+use arrow_array::{RecordBatch, TimestampMillisecondArray};
+use arrow_schema::Schema;
+use arrow_select::concat::concat_batches;
+use itertools::kmerge_by;
 
 use crate::utils::arrow::adapt_batch;
 
-use super::mutable::MutableColumns;
-
 #[derive(Default)]
 pub struct MemWriter<const N: usize> {
     read_buffer: Vec<RecordBatch>,
-    mutable_buffer: MutableColumns,
+    mutable_buffer: HashMap<String, Vec<RecordBatch>>,
 }
 
 impl<const N: usize> MemWriter<N> {
-    pub fn push(&mut self, rb: RecordBatch) {
+    pub fn push(&mut self, schema_key: &str, rb: RecordBatch) {
         if self.mutable_buffer.len() + rb.num_rows() > N {
             // init new mutable columns with schema of current
-            let schema = self.mutable_buffer.current_schema();
-            let mut new_mutable_buffer = MutableColumns::default();
-            new_mutable_buffer.push(RecordBatch::new_empty(Arc::new(schema)));
+            let schema = self.current_mutable_schema();
             // replace new mutable buffer with current one as that is full
-            let mutable_buffer = std::mem::replace(&mut self.mutable_buffer, new_mutable_buffer);
-            let filled_rb = mutable_buffer.into_recordbatch();
-            self.read_buffer.push(filled_rb);
+            let mutable_buffer = std::mem::take(&mut self.mutable_buffer);
+            let batches = mutable_buffer.values().collect();
+            self.read_buffer.push(merge_rb(batches, Arc::new(schema)));
+        }
+
+        if let Some(buf) = self.mutable_buffer.get_mut(schema_key) {
+            buf.push(rb);
+        } else {
+            self.mutable_buffer.insert(schema_key.to_owned(), vec![rb]);
         }
-        self.mutable_buffer.push(rb)
     }
 
     pub fn recordbatch_cloned(&self) -> Vec<RecordBatch> {
         let mut read_buffer = self.read_buffer.clone();
-        let rb = self.mutable_buffer.recordbatch_cloned();
+        let schema = self.current_mutable_schema();
+        let batches = self.mutable_buffer.values().collect();
+        let rb = merge_rb(batches, Arc::new(schema));
         let schema = rb.schema();
         if rb.num_rows() > 0 {
             read_buffer.push(rb)
@@ -61,8 +67,10 @@ impl<const N: usize> MemWriter<N> {
     }
 
     pub fn finalize(self) -> Vec<RecordBatch> {
+        let schema = self.current_mutable_schema();
         let mut read_buffer = self.read_buffer;
-        let rb = self.mutable_buffer.into_recordbatch();
+        let batches = self.mutable_buffer.values().collect();
+        let rb = merge_rb(batches, Arc::new(schema));
         let schema = rb.schema();
         if rb.num_rows() > 0 {
             read_buffer.push(rb)
@@ -72,4 +80,39 @@ impl<const N: usize> MemWriter<N> {
             .map(|rb| adapt_batch(&schema, rb))
             .collect()
     }
+
+    fn current_mutable_schema(&self) -> Schema {
+        Schema::try_merge(
+            self.mutable_buffer
+                .values()
+                .flat_map(|rb| rb.first())
+                .map(|rb| rb.schema().as_ref().clone()),
+        )
+        .unwrap()
+    }
+}
+
+fn merge_rb(rb: Vec<&Vec<RecordBatch>>, schema: Arc<Schema>) -> RecordBatch {
+    let sorted_rb: Vec<RecordBatch> = kmerge_by(rb, |a: &&RecordBatch, b: &&RecordBatch| {
+        let a: &TimestampMillisecondArray = a
+            .column(0)
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .unwrap();
+
+        let b: &TimestampMillisecondArray = b
+            .column(0)
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .unwrap();
+
+        a.value(0) < b.value(0)
+    })
+    .map(|batch| adapt_batch(&schema, batch.clone()))
+    .collect();
+
+    // Invariants for this to work:
+    // every batch shares the same schema (adapt_batch above ensures this)
+    // and column datatypes match
+    concat_batches(&schema, sorted_rb.iter()).unwrap()
+}
diff --git a/server/src/event/writer/mutable.rs b/server/src/event/writer/mutable.rs
deleted file mode 100644
index 7fc1b9fcc..000000000
--- a/server/src/event/writer/mutable.rs
+++ /dev/null
@@ -1,1030 +0,0 @@
-/*
- * Parseable Server (C) 2022 - 2023 Parseable, Inc.
- * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of the - * License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * - * - */ - -use std::{cmp::Ordering, sync::Arc}; - -use arrow_array::{ - builder::{ - BooleanBuilder, Float64Builder, Int64Builder, ListBuilder, StringBuilder, - TimestampMillisecondBuilder, UInt64Builder, - }, - new_null_array, Array, BooleanArray, Float64Array, Int64Array, ListArray, RecordBatch, - StringArray, TimestampMillisecondArray, UInt64Array, -}; -use arrow_schema::{DataType, Field, Schema, TimeUnit}; -use itertools::Itertools; - -macro_rules! nested_list { - ($t:expr) => { - DataType::List(Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", $t, true))), - true, - ))) - }; -} - -macro_rules! unit_list { - ($t:expr) => { - DataType::List(Box::new(Field::new("item", $t, true))) - }; -} - -#[derive(Debug)] -pub enum NestedListBuilder { - Utf8(ListBuilder>), - Boolean(ListBuilder>), - Int64(ListBuilder>), - UInt64(ListBuilder>), - Timestamp(ListBuilder>), - Float64(ListBuilder>), -} - -impl NestedListBuilder { - pub fn new(data_type: &DataType) -> Self { - match data_type { - DataType::Boolean => NestedListBuilder::Boolean(ListBuilder::new(ListBuilder::new( - BooleanBuilder::new(), - ))), - DataType::Int64 => { - NestedListBuilder::Int64(ListBuilder::new(ListBuilder::new(Int64Builder::new()))) - } - DataType::UInt64 => { - NestedListBuilder::UInt64(ListBuilder::new(ListBuilder::new(UInt64Builder::new()))) - } - DataType::Float64 => NestedListBuilder::Float64(ListBuilder::new(ListBuilder::new( - Float64Builder::new(), - ))), - DataType::Timestamp(_, _) => NestedListBuilder::Timestamp(ListBuilder::new( - ListBuilder::new(TimestampMillisecondBuilder::new()), - )), - DataType::Utf8 => { - NestedListBuilder::Utf8(ListBuilder::new(ListBuilder::new(StringBuilder::new()))) - } - _ => unreachable!(), - } - } - - pub fn data_type(&self) -> DataType { - match self { - NestedListBuilder::Utf8(_) => nested_list!(DataType::Utf8), - NestedListBuilder::Boolean(_) => nested_list!(DataType::Boolean), - NestedListBuilder::Int64(_) => nested_list!(DataType::Int64), - NestedListBuilder::UInt64(_) => nested_list!(DataType::UInt64), - NestedListBuilder::Timestamp(_) => { - nested_list!(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - NestedListBuilder::Float64(_) => nested_list!(DataType::Float64), - } - } -} - -#[derive(Debug)] -pub enum UnitListBuilder { - Utf8(ListBuilder), - Boolean(ListBuilder), - Int64(ListBuilder), - UInt64(ListBuilder), - Timestamp(ListBuilder), - Float64(ListBuilder), - List(NestedListBuilder), -} - -impl UnitListBuilder { - pub fn new(data_type: &DataType) -> Self { - match data_type { - DataType::Boolean => UnitListBuilder::Boolean(ListBuilder::new(BooleanBuilder::new())), - DataType::Int64 => UnitListBuilder::Int64(ListBuilder::new(Int64Builder::new())), - DataType::UInt64 => UnitListBuilder::UInt64(ListBuilder::new(UInt64Builder::new())), - DataType::Float64 => 
UnitListBuilder::Float64(ListBuilder::new(Float64Builder::new())), - DataType::Timestamp(_, _) => { - UnitListBuilder::Timestamp(ListBuilder::new(TimestampMillisecondBuilder::new())) - } - DataType::Utf8 => UnitListBuilder::Utf8(ListBuilder::new(StringBuilder::new())), - DataType::List(field) => { - UnitListBuilder::List(NestedListBuilder::new(field.data_type())) - } - _ => unreachable!(), - } - } - - pub fn data_type(&self) -> DataType { - match self { - UnitListBuilder::Utf8(_) => unit_list!(DataType::Utf8), - UnitListBuilder::Boolean(_) => unit_list!(DataType::Boolean), - UnitListBuilder::Int64(_) => unit_list!(DataType::Int64), - UnitListBuilder::UInt64(_) => unit_list!(DataType::UInt64), - UnitListBuilder::Timestamp(_) => { - unit_list!(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - UnitListBuilder::Float64(_) => unit_list!(DataType::Float64), - UnitListBuilder::List(inner) => inner.data_type(), - } - } -} - -#[derive(Debug)] -pub enum MutableColumnArray { - Utf8(StringBuilder), - Boolean(BooleanBuilder), - Int64(Int64Builder), - UInt64(UInt64Builder), - Timestamp(TimestampMillisecondBuilder), - Float64(Float64Builder), - List(UnitListBuilder), -} - -impl MutableColumnArray { - pub fn new(data_type: &DataType) -> Self { - match data_type { - DataType::Boolean => MutableColumnArray::Boolean(BooleanBuilder::new()), - DataType::Int64 => MutableColumnArray::Int64(Int64Builder::new()), - DataType::UInt64 => MutableColumnArray::UInt64(UInt64Builder::new()), - DataType::Float64 => MutableColumnArray::Float64(Float64Builder::new()), - DataType::Timestamp(_, _) => { - MutableColumnArray::Timestamp(TimestampMillisecondBuilder::new()) - } - DataType::Utf8 => MutableColumnArray::Utf8(StringBuilder::new()), - DataType::List(field) => { - MutableColumnArray::List(UnitListBuilder::new(field.data_type())) - } - _ => unreachable!(), - } - } - - pub fn data_type(&self) -> DataType { - match self { - MutableColumnArray::Utf8(_) => DataType::Utf8, - MutableColumnArray::Boolean(_) => DataType::Boolean, - MutableColumnArray::Int64(_) => DataType::Int64, - MutableColumnArray::UInt64(_) => DataType::UInt64, - MutableColumnArray::Timestamp(_) => DataType::Timestamp(TimeUnit::Millisecond, None), - MutableColumnArray::Float64(_) => DataType::Float64, - MutableColumnArray::List(inner) => inner.data_type(), - } - } - - fn push_nulls(&mut self, n: usize) { - match self { - MutableColumnArray::Utf8(col) => (0..n).for_each(|_| col.append_null()), - MutableColumnArray::Boolean(col) => col.append_nulls(n), - MutableColumnArray::Int64(col) => col.append_nulls(n), - MutableColumnArray::UInt64(col) => col.append_nulls(n), - MutableColumnArray::Timestamp(col) => col.append_nulls(n), - MutableColumnArray::Float64(col) => col.append_nulls(n), - MutableColumnArray::List(col) => match col { - UnitListBuilder::Utf8(col) => (0..n).for_each(|_| col.append(false)), - UnitListBuilder::Boolean(col) => (0..n).for_each(|_| col.append(false)), - UnitListBuilder::Int64(col) => (0..n).for_each(|_| col.append(false)), - UnitListBuilder::UInt64(col) => (0..n).for_each(|_| col.append(false)), - UnitListBuilder::Timestamp(col) => (0..n).for_each(|_| col.append(false)), - UnitListBuilder::Float64(col) => (0..n).for_each(|_| col.append(false)), - UnitListBuilder::List(col) => match col { - NestedListBuilder::Utf8(col) => (0..n).for_each(|_| col.append(false)), - NestedListBuilder::Boolean(col) => (0..n).for_each(|_| col.append(false)), - NestedListBuilder::Int64(col) => (0..n).for_each(|_| col.append(false)), - 
NestedListBuilder::UInt64(col) => (0..n).for_each(|_| col.append(false)), - NestedListBuilder::Timestamp(col) => (0..n).for_each(|_| col.append(false)), - NestedListBuilder::Float64(col) => (0..n).for_each(|_| col.append(false)), - }, - }, - } - } - - fn cloned_array(&self) -> Arc { - match self { - MutableColumnArray::Utf8(col) => Arc::new(col.finish_cloned()), - MutableColumnArray::Boolean(col) => Arc::new(col.finish_cloned()), - MutableColumnArray::Int64(col) => Arc::new(col.finish_cloned()), - MutableColumnArray::UInt64(col) => Arc::new(col.finish_cloned()), - MutableColumnArray::Timestamp(col) => Arc::new(col.finish_cloned()), - MutableColumnArray::Float64(col) => Arc::new(col.finish_cloned()), - MutableColumnArray::List(col) => match col { - UnitListBuilder::Utf8(col) => Arc::new(col.finish_cloned()), - UnitListBuilder::Boolean(col) => Arc::new(col.finish_cloned()), - UnitListBuilder::Int64(col) => Arc::new(col.finish_cloned()), - UnitListBuilder::UInt64(col) => Arc::new(col.finish_cloned()), - UnitListBuilder::Timestamp(col) => Arc::new(col.finish_cloned()), - UnitListBuilder::Float64(col) => Arc::new(col.finish_cloned()), - UnitListBuilder::List(col) => match col { - NestedListBuilder::Utf8(col) => Arc::new(col.finish_cloned()), - NestedListBuilder::Boolean(col) => Arc::new(col.finish_cloned()), - NestedListBuilder::Int64(col) => Arc::new(col.finish_cloned()), - NestedListBuilder::UInt64(col) => Arc::new(col.finish_cloned()), - NestedListBuilder::Timestamp(col) => Arc::new(col.finish_cloned()), - NestedListBuilder::Float64(col) => Arc::new(col.finish_cloned()), - }, - }, - } - } - - fn into_array(mut self) -> Arc { - match &mut self { - MutableColumnArray::Utf8(col) => Arc::new(col.finish()), - MutableColumnArray::Boolean(col) => Arc::new(col.finish()), - MutableColumnArray::Int64(col) => Arc::new(col.finish()), - MutableColumnArray::UInt64(col) => Arc::new(col.finish()), - MutableColumnArray::Timestamp(col) => Arc::new(col.finish()), - MutableColumnArray::Float64(col) => Arc::new(col.finish()), - MutableColumnArray::List(col) => match col { - UnitListBuilder::Utf8(col) => Arc::new(col.finish()), - UnitListBuilder::Boolean(col) => Arc::new(col.finish()), - UnitListBuilder::Int64(col) => Arc::new(col.finish()), - UnitListBuilder::UInt64(col) => Arc::new(col.finish()), - UnitListBuilder::Timestamp(col) => Arc::new(col.finish()), - UnitListBuilder::Float64(col) => Arc::new(col.finish()), - UnitListBuilder::List(col) => match col { - NestedListBuilder::Utf8(col) => Arc::new(col.finish()), - NestedListBuilder::Boolean(col) => Arc::new(col.finish()), - NestedListBuilder::Int64(col) => Arc::new(col.finish()), - NestedListBuilder::UInt64(col) => Arc::new(col.finish()), - NestedListBuilder::Timestamp(col) => Arc::new(col.finish()), - NestedListBuilder::Float64(col) => Arc::new(col.finish()), - }, - }, - } - } -} - -#[derive(Debug)] -pub struct MutableColumn { - name: String, - column: MutableColumnArray, -} - -impl MutableColumn { - pub fn new(name: String, column: MutableColumnArray) -> Self { - Self { name, column } - } - - pub fn name(&self) -> &str { - &self.name - } - - pub fn feild(&self) -> Field { - Field::new(&self.name, self.column.data_type(), true) - } -} - -#[derive(Debug, Default)] -pub struct MutableColumns { - columns: Vec, - len: usize, -} - -impl MutableColumns { - pub fn push(&mut self, rb: RecordBatch) { - let num_rows = rb.num_rows(); - let schema = rb.schema(); - let rb = schema.fields().iter().zip(rb.columns().iter()); - - // start index map to next location in self 
columns - let mut index = 0; - 'rb: for (field, arr) in rb { - // for field in rb look at same field in columns or insert. - // fill with null while traversing if rb field name is greater than column name - while let Some(col) = self.columns.get_mut(index) { - match col.name().cmp(field.name()) { - Ordering::Equal => { - update_column(&mut col.column, Arc::clone(arr)); - // goto next field in rb - index += 1; - continue 'rb; - } - Ordering::Greater => { - let mut new_column = MutableColumn::new( - field.name().to_owned(), - MutableColumnArray::new(field.data_type()), - ); - update_column( - &mut new_column.column, - new_null_array(field.data_type(), self.len), - ); - update_column(&mut new_column.column, Arc::clone(arr)); - self.columns.insert(index, new_column); - index += 1; - continue 'rb; - } - Ordering::Less => { - col.column.push_nulls(num_rows); - index += 1; - } - } - } - - // if inner loop finishes this means this column is suppose to be at the end of columns - let mut new_column = MutableColumn::new( - field.name().to_owned(), - MutableColumnArray::new(field.data_type()), - ); - update_column( - &mut new_column.column, - new_null_array(field.data_type(), self.len), - ); - update_column(&mut new_column.column, Arc::clone(arr)); - self.columns.push(new_column); - index += 1; - } - - // fill any columns yet to be updated with nulls - for col in self.columns[index..].iter_mut() { - col.column.push_nulls(num_rows) - } - - self.len += num_rows - } - - pub fn into_recordbatch(self) -> RecordBatch { - let mut fields = Vec::with_capacity(self.columns.len()); - let mut arrays = Vec::with_capacity(self.columns.len()); - - for MutableColumn { name, column } in self.columns { - let field = Field::new(name, column.data_type(), true); - fields.push(field); - arrays.push(column.into_array()); - } - - RecordBatch::try_new(Arc::new(Schema::new(fields)), arrays).unwrap() - } - - pub fn recordbatch_cloned(&self) -> RecordBatch { - let mut fields = Vec::with_capacity(self.columns.len()); - let mut arrays = Vec::with_capacity(self.columns.len()); - - for MutableColumn { name, column } in &self.columns { - let field = Field::new(name, column.data_type(), true); - fields.push(field); - arrays.push(column.cloned_array()); - } - - RecordBatch::try_new(Arc::new(Schema::new(fields)), arrays).unwrap() - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn current_schema(&self) -> Schema { - Schema::new(self.columns.iter().map(|x| x.feild()).collect_vec()) - } -} - -fn update_column(col: &mut MutableColumnArray, arr: Arc) { - match col { - MutableColumnArray::Utf8(col) => downcast::(&arr) - .iter() - .for_each(|v| col.append_option(v)), - MutableColumnArray::Boolean(col) => downcast::(&arr) - .iter() - .for_each(|v| col.append_option(v)), - MutableColumnArray::Int64(col) => downcast::(&arr) - .iter() - .for_each(|v| col.append_option(v)), - MutableColumnArray::UInt64(col) => downcast::(&arr) - .iter() - .for_each(|v| col.append_option(v)), - MutableColumnArray::Timestamp(col) => downcast::(&arr) - .iter() - .for_each(|v| col.append_option(v)), - MutableColumnArray::Float64(col) => downcast::(&arr) - .iter() - .for_each(|v| col.append_option(v)), - MutableColumnArray::List(col) => match col { - UnitListBuilder::Utf8(col) => { - let arr = into_vec_array(&arr); - let iter = arr - .iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())); - col.extend(iter); - } - UnitListBuilder::Boolean(col) => { - let arr = into_vec_array(&arr); - let iter = arr - .iter() - .map(|x| x.as_ref().map(|x| 
downcast::(x).iter())); - col.extend(iter); - } - UnitListBuilder::Int64(col) => { - let arr = into_vec_array(&arr); - let iter = arr - .iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())); - col.extend(iter); - } - UnitListBuilder::UInt64(col) => { - let arr = into_vec_array(&arr); - let iter = arr - .iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())); - col.extend(iter); - } - UnitListBuilder::Timestamp(col) => { - let arr = into_vec_array(&arr); - let iter = arr.iter().map(|x| { - x.as_ref() - .map(|x| downcast::(x).iter()) - }); - col.extend(iter); - } - UnitListBuilder::Float64(col) => { - let arr = into_vec_array(&arr); - let iter = arr - .iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())); - col.extend(iter); - } - UnitListBuilder::List(col) => match col { - NestedListBuilder::Utf8(col) => { - let arr = into_vec_vec_array(&arr); - let iter = arr.iter().map(|x| { - x.as_ref().map(|arr| { - arr.iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())) - }) - }); - - col.extend(iter) - } - NestedListBuilder::Boolean(col) => { - let arr = into_vec_vec_array(&arr); - let iter = arr.iter().map(|x| { - x.as_ref().map(|arr| { - arr.iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())) - }) - }); - - col.extend(iter) - } - NestedListBuilder::Int64(col) => { - let arr = into_vec_vec_array(&arr); - - let iter = arr.iter().map(|x| { - x.as_ref().map(|arr| { - arr.iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())) - }) - }); - - col.extend(iter) - } - NestedListBuilder::UInt64(col) => { - let arr = into_vec_vec_array(&arr); - - let iter = arr.iter().map(|x| { - x.as_ref().map(|arr| { - arr.iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())) - }) - }); - - col.extend(iter) - } - NestedListBuilder::Timestamp(col) => { - let arr = into_vec_vec_array(&arr); - - let iter = arr.iter().map(|x| { - x.as_ref().map(|arr| { - arr.iter().map(|x| { - x.as_ref() - .map(|x| downcast::(x).iter()) - }) - }) - }); - - col.extend(iter) - } - NestedListBuilder::Float64(col) => { - let arr = into_vec_vec_array(&arr); - - let iter = arr.iter().map(|x| { - x.as_ref().map(|arr| { - arr.iter() - .map(|x| x.as_ref().map(|x| downcast::(x).iter())) - }) - }); - - col.extend(iter) - } - }, - }, - }; -} - -fn downcast(arr: &dyn Array) -> &T { - arr.as_any().downcast_ref::().unwrap() -} - -type VecArray = Vec>>; - -fn into_vec_array(arr: &dyn Array) -> VecArray { - arr.as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect() -} - -fn into_vec_vec_array(arr: &dyn Array) -> Vec> { - arr.as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|arr| { - arr.map(|arr| { - arr.as_any() - .downcast_ref::() - .unwrap() - .iter() - .collect() - }) - }) - .collect() -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow_array::{BooleanArray, RecordBatch}; - use arrow_schema::{DataType, Field, Schema, TimeUnit}; - - use super::{MutableColumnArray, MutableColumns}; - - macro_rules! check_array_builder { - ($t:expr) => { - assert_eq!(MutableColumnArray::new(&$t).data_type(), $t) - }; - } - - macro_rules! check_unit_list_builder { - ($t:expr) => { - assert_eq!( - MutableColumnArray::new(&DataType::List(Box::new(Field::new("item", $t, true)))) - .data_type(), - DataType::List(Box::new(Field::new("item", $t, true))) - ) - }; - } - - macro_rules! 
check_nested_list_builder { - ($t:expr) => { - assert_eq!( - MutableColumnArray::new(&DataType::List(Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", $t, true))), - true - )))) - .data_type(), - DataType::List(Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", $t, true))), - true - ))) - ) - }; - } - - #[test] - fn create_mutable_col_and_check_datatype() { - check_array_builder!(DataType::Boolean); - check_array_builder!(DataType::Int64); - check_array_builder!(DataType::UInt64); - check_array_builder!(DataType::Float64); - check_array_builder!(DataType::Utf8); - check_array_builder!(DataType::Timestamp(TimeUnit::Millisecond, None)); - check_unit_list_builder!(DataType::Boolean); - check_unit_list_builder!(DataType::Int64); - check_unit_list_builder!(DataType::UInt64); - check_unit_list_builder!(DataType::Float64); - check_unit_list_builder!(DataType::Utf8); - check_unit_list_builder!(DataType::Timestamp(TimeUnit::Millisecond, None)); - check_nested_list_builder!(DataType::Boolean); - check_nested_list_builder!(DataType::Int64); - check_nested_list_builder!(DataType::UInt64); - check_nested_list_builder!(DataType::Float64); - check_nested_list_builder!(DataType::Utf8); - check_nested_list_builder!(DataType::Timestamp(TimeUnit::Millisecond, None)); - } - - #[test] - fn empty_columns_push_single_col() { - let mut columns = MutableColumns::default(); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let col1 = Arc::new(BooleanArray::from(vec![true, false, true])); - let rb = RecordBatch::try_new(Arc::new(schema), vec![col1]).unwrap(); - - columns.push(rb); - - assert_eq!(columns.columns.len(), 1) - } - - #[test] - fn empty_columns_push_empty_rb() { - let mut columns = MutableColumns::default(); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let rb = RecordBatch::new_empty(Arc::new(schema)); - - columns.push(rb); - - assert_eq!(columns.columns.len(), 1); - assert_eq!(columns.len, 0); - } - - #[test] - fn one_empty_column_push_new_empty_column_before() { - let mut columns = MutableColumns::default(); - - let schema = Schema::new(vec![Field::new("b", DataType::Boolean, true)]); - let rb = RecordBatch::new_empty(Arc::new(schema)); - columns.push(rb); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let rb = RecordBatch::new_empty(Arc::new(schema)); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 0); - } - - #[test] - fn one_column_push_new_column_before() { - let mut columns = MutableColumns::default(); - - let schema = Schema::new(vec![Field::new("b", DataType::Boolean, true)]); - let col2 = Arc::new(BooleanArray::from(vec![true, false, true])); - let rb = RecordBatch::try_new(Arc::new(schema), vec![col2]).unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 1); - assert_eq!(columns.len, 3); - - let MutableColumnArray::Boolean(builder) = &columns.columns[0].column else {unreachable!()}; - { - let arr = builder.finish_cloned(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(true), Some(false), Some(true)] - ) - } - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let col1 = Arc::new(BooleanArray::from(vec![true, true, true])); - let rb = RecordBatch::try_new(Arc::new(schema), vec![col1]).unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 6); - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[0].column 
else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![None, None, None, Some(true), Some(true), Some(true)] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[1].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(true), Some(false), Some(true), None, None, None] - ) - } - } - - #[test] - fn two_column_push_new_column_before() { - let mut columns = MutableColumns::default(); - let schema = Schema::new(vec![ - Field::new("b", DataType::Boolean, true), - Field::new("c", DataType::Boolean, true), - ]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(BooleanArray::from(vec![false, true, false])), - Arc::new(BooleanArray::from(vec![false, false, true])), - ], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 3); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(BooleanArray::from(vec![true, false, false]))], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 3); - assert_eq!(columns.len, 6); - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[0].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![None, None, None, Some(true), Some(false), Some(false)] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[1].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(false), Some(true), Some(false), None, None, None] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[2].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(false), Some(false), Some(true), None, None, None] - ) - } - } - - #[test] - fn two_column_push_new_column_middle() { - let mut columns = MutableColumns::default(); - let schema = Schema::new(vec![ - Field::new("a", DataType::Boolean, true), - Field::new("c", DataType::Boolean, true), - ]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(BooleanArray::from(vec![true, false, false])), - Arc::new(BooleanArray::from(vec![false, false, true])), - ], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 3); - - let schema = Schema::new(vec![Field::new("b", DataType::Boolean, true)]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(BooleanArray::from(vec![false, true, false]))], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 3); - assert_eq!(columns.len, 6); - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[0].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(true), Some(false), Some(false), None, None, None] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[1].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![None, None, None, Some(false), Some(true), Some(false)] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[2].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(false), Some(false), Some(true), None, None, None] - ) - } - } 
- - #[test] - fn two_column_push_new_column_after() { - let mut columns = MutableColumns::default(); - let schema = Schema::new(vec![ - Field::new("a", DataType::Boolean, true), - Field::new("b", DataType::Boolean, true), - ]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(BooleanArray::from(vec![true, false, false])), - Arc::new(BooleanArray::from(vec![false, true, false])), - ], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 3); - - let schema = Schema::new(vec![Field::new("c", DataType::Boolean, true)]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(BooleanArray::from(vec![false, false, true]))], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 3); - assert_eq!(columns.len, 6); - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[0].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(true), Some(false), Some(false), None, None, None] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[1].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(false), Some(true), Some(false), None, None, None] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[2].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![None, None, None, Some(false), Some(false), Some(true)] - ) - } - } - - #[test] - fn two_empty_column_push_new_column_before() { - let mut columns = MutableColumns::default(); - let schema = Schema::new(vec![ - Field::new("b", DataType::Boolean, true), - Field::new("c", DataType::Boolean, true), - ]); - let rb = RecordBatch::new_empty(Arc::new(schema)); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 0); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(BooleanArray::from(vec![true, false, false]))], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 3); - assert_eq!(columns.len, 3); - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[0].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(true), Some(false), Some(false)] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[1].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!(arr.iter().collect::>(), vec![None, None, None]) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[2].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!(arr.iter().collect::>(), vec![None, None, None]) - } - } - - #[test] - fn two_empty_column_push_new_column_middle() { - let mut columns = MutableColumns::default(); - let schema = Schema::new(vec![ - Field::new("a", DataType::Boolean, true), - Field::new("c", DataType::Boolean, true), - ]); - let rb = RecordBatch::new_empty(Arc::new(schema)); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 0); - - let schema = Schema::new(vec![Field::new("b", DataType::Boolean, true)]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(BooleanArray::from(vec![false, true, false]))], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 3); - 
assert_eq!(columns.len, 3); - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[0].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!(arr.iter().collect::>(), vec![None, None, None]) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[1].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(false), Some(true), Some(false)] - ) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[2].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!(arr.iter().collect::>(), vec![None, None, None]) - } - } - - #[test] - fn two_empty_column_push_new_column_after() { - let mut columns = MutableColumns::default(); - let schema = Schema::new(vec![ - Field::new("a", DataType::Boolean, true), - Field::new("b", DataType::Boolean, true), - ]); - let rb = RecordBatch::new_empty(Arc::new(schema)); - columns.push(rb); - - assert_eq!(columns.columns.len(), 2); - assert_eq!(columns.len, 0); - - let schema = Schema::new(vec![Field::new("c", DataType::Boolean, true)]); - let rb = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(BooleanArray::from(vec![false, false, true]))], - ) - .unwrap(); - columns.push(rb); - - assert_eq!(columns.columns.len(), 3); - assert_eq!(columns.len, 3); - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[0].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!(arr.iter().collect::>(), vec![None, None, None]) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[1].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!(arr.iter().collect::>(), vec![None, None, None]) - } - - let MutableColumnArray::Boolean(builder) = &mut columns.columns[2].column else {unreachable!()}; - { - let arr = builder.finish(); - assert_eq!( - arr.iter().collect::>(), - vec![Some(false), Some(false), Some(true)] - ) - } - } -} diff --git a/server/src/handlers/http/ingest.rs b/server/src/handlers/http/ingest.rs index 8d3883f1c..8ef22e734 100644 --- a/server/src/handlers/http/ingest.rs +++ b/server/src/handlers/http/ingest.rs @@ -16,9 +16,11 @@ * */ +use std::collections::HashMap; + use actix_web::http::header::ContentType; use actix_web::{HttpRequest, HttpResponse}; -use arrow_schema::Schema; +use arrow_schema::Field; use bytes::Bytes; use http::StatusCode; use serde_json::Value; @@ -27,7 +29,6 @@ use crate::event::error::EventError; use crate::event::format::EventFormat; use crate::event::{self, format}; use crate::handlers::{PREFIX_META, PREFIX_TAGS, SEPARATOR, STREAM_NAME_HEADER_KEY}; -use crate::metadata::error::stream_info::MetadataError; use crate::metadata::STREAM_INFO; use crate::utils::header_parsing::{collect_labelled_headers, ParseHeaderError}; @@ -61,14 +62,21 @@ pub async fn post_event(req: HttpRequest, body: Bytes) -> Result Result<(), PostError> { - let schema = STREAM_INFO.schema(&stream_name)?; - let (size, rb) = into_event_batch(req, body, &schema)?; + let (size, rb, is_first_event) = { + let hash_map = STREAM_INFO.read().unwrap(); + let schema = &hash_map + .get(&stream_name) + .ok_or(PostError::StreamNotFound(stream_name.clone()))? + .schema; + into_event_batch(req, body, schema)? 
+ }; event::Event { rb, stream_name, origin_format: "json", origin_size: size as u64, + is_first_event, } .process() .await?; @@ -76,12 +84,11 @@ async fn push_logs(stream_name: String, req: HttpRequest, body: Bytes) -> Result Ok(()) } -// This function is decoupled from handler itself for testing purpose fn into_event_batch( req: HttpRequest, body: Bytes, - schema: &Schema, -) -> Result<(usize, arrow_array::RecordBatch), PostError> { + schema: &HashMap, +) -> Result<(usize, arrow_array::RecordBatch, bool), PostError> { let tags = collect_labelled_headers(&req, PREFIX_TAGS, SEPARATOR)?; let metadata = collect_labelled_headers(&req, PREFIX_META, SEPARATOR)?; let size = body.len(); @@ -91,14 +98,14 @@ fn into_event_batch( tags, metadata, }; - let rb = event.into_recordbatch(schema)?; - Ok((size, rb)) + let (rb, is_first) = event.into_recordbatch(schema)?; + Ok((size, rb, is_first)) } #[derive(Debug, thiserror::Error)] pub enum PostError { #[error("{0}")] - StreamNotFound(#[from] MetadataError), + StreamNotFound(String), #[error("Could not deserialize into JSON object, {0}")] SerdeError(#[from] serde_json::Error), #[error("Header Error: {0}")] @@ -133,11 +140,13 @@ impl actix_web::ResponseError for PostError { #[cfg(test)] mod tests { + use std::collections::HashMap; + use actix_web::test::TestRequest; use arrow_array::{ types::Int64Type, ArrayRef, Float64Array, Int64Array, ListArray, StringArray, }; - use arrow_schema::{DataType, Field, Schema}; + use arrow_schema::{DataType, Field}; use bytes::Bytes; use serde_json::json; @@ -181,23 +190,26 @@ mod tests { .append_header((PREFIX_META.to_string() + "C", "meta1")) .to_http_request(); - let (size, rb) = into_event_batch( + let (size, rb, _) = into_event_batch( req, Bytes::from(serde_json::to_vec(&json).unwrap()), - &Schema::empty(), + &HashMap::default(), ) .unwrap(); assert_eq!(size, 28); assert_eq!(rb.num_rows(), 1); assert_eq!(rb.num_columns(), 6); - assert_eq!(rb.column(0).as_int64_arr(), &Int64Array::from_iter([1])); assert_eq!( - rb.column(1).as_utf8_arr(), + rb.column_by_name("a").unwrap().as_int64_arr(), + &Int64Array::from_iter([1]) + ); + assert_eq!( + rb.column_by_name("b").unwrap().as_utf8_arr(), &StringArray::from_iter_values(["hello"]) ); assert_eq!( - rb.column(2).as_float64_arr(), + rb.column_by_name("c").unwrap().as_float64_arr(), &Float64Array::from_iter([4.23]) ); assert_eq!( @@ -224,18 +236,21 @@ mod tests { let req = TestRequest::default().to_http_request(); - let (_, rb) = into_event_batch( + let (_, rb, _) = into_event_batch( req, Bytes::from(serde_json::to_vec(&json).unwrap()), - &Schema::empty(), + &HashMap::default(), ) .unwrap(); assert_eq!(rb.num_rows(), 1); assert_eq!(rb.num_columns(), 5); - assert_eq!(rb.column(0).as_int64_arr(), &Int64Array::from_iter([1])); assert_eq!( - rb.column(1).as_utf8_arr(), + rb.column_by_name("a").unwrap().as_int64_arr(), + &Int64Array::from_iter([1]) + ); + assert_eq!( + rb.column_by_name("b").unwrap().as_utf8_arr(), &StringArray::from_iter_values(["hello"]) ); } @@ -247,15 +262,15 @@ mod tests { "b": "hello", }); - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), + let schema = HashMap::from([ + ("a".to_string(), Field::new("a", DataType::Int64, true)), + ("b".to_string(), Field::new("b", DataType::Utf8, true)), + ("c".to_string(), Field::new("c", DataType::Float64, true)), ]); let req = TestRequest::default().to_http_request(); - let (_, rb) = into_event_batch( + let (_, rb, _) 
= into_event_batch( req, Bytes::from(serde_json::to_vec(&json).unwrap()), &schema, @@ -264,9 +279,12 @@ mod tests { assert_eq!(rb.num_rows(), 1); assert_eq!(rb.num_columns(), 5); - assert_eq!(rb.column(0).as_int64_arr(), &Int64Array::from_iter([1])); assert_eq!( - rb.column(1).as_utf8_arr(), + rb.column_by_name("a").unwrap().as_int64_arr(), + &Int64Array::from_iter([1]) + ); + assert_eq!( + rb.column_by_name("b").unwrap().as_utf8_arr(), &StringArray::from_iter_values(["hello"]) ); } @@ -278,10 +296,10 @@ mod tests { "b": 1, // type mismatch }); - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), + let schema = HashMap::from([ + ("a".to_string(), Field::new("a", DataType::Int64, true)), + ("b".to_string(), Field::new("b", DataType::Utf8, true)), + ("c".to_string(), Field::new("c", DataType::Float64, true)), ]); let req = TestRequest::default().to_http_request(); @@ -298,15 +316,15 @@ mod tests { fn empty_object() { let json = json!({}); - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Float64, true), - Field::new("c", DataType::Float64, true), + let schema = HashMap::from([ + ("a".to_string(), Field::new("a", DataType::Int64, true)), + ("b".to_string(), Field::new("b", DataType::Utf8, true)), + ("c".to_string(), Field::new("c", DataType::Float64, true)), ]); let req = TestRequest::default().to_http_request(); - let (_, rb) = into_event_batch( + let (_, rb, _) = into_event_batch( req, Bytes::from(serde_json::to_vec(&json).unwrap()), &schema, @@ -326,7 +344,7 @@ mod tests { assert!(into_event_batch( req, Bytes::from(serde_json::to_vec(&json).unwrap()), - &Schema::empty(), + &HashMap::default(), ) .is_err()) } @@ -353,25 +371,25 @@ mod tests { let req = TestRequest::default().to_http_request(); - let (_, rb) = into_event_batch( + let (_, rb, _) = into_event_batch( req, Bytes::from(serde_json::to_vec(&json).unwrap()), - &Schema::empty(), + &HashMap::default(), ) .unwrap(); assert_eq!(rb.num_rows(), 3); assert_eq!(rb.num_columns(), 6); assert_eq!( - rb.column(0).as_int64_arr(), + rb.column_by_name("a").unwrap().as_int64_arr(), &Int64Array::from(vec![None, Some(1), Some(1)]) ); assert_eq!( - rb.column(1).as_utf8_arr(), + rb.column_by_name("b").unwrap().as_utf8_arr(), &StringArray::from(vec![Some("hello"), Some("hello"), Some("hello"),]) ); assert_eq!( - rb.column(2).as_float64_arr(), + rb.column_by_name("c").unwrap().as_float64_arr(), &Float64Array::from(vec![None, Some(1.22), None,]) ); } @@ -396,15 +414,14 @@ mod tests { }, ]); - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), + let schema = HashMap::from([ + ("a".to_string(), Field::new("a", DataType::Int64, true)), + ("b".to_string(), Field::new("b", DataType::Utf8, true)), + ("c".to_string(), Field::new("c", DataType::Float64, true)), ]); - let req = TestRequest::default().to_http_request(); - let (_, rb) = into_event_batch( + let (_, rb, _) = into_event_batch( req, Bytes::from(serde_json::to_vec(&json).unwrap()), &schema, @@ -414,15 +431,15 @@ mod tests { assert_eq!(rb.num_rows(), 3); assert_eq!(rb.num_columns(), 6); assert_eq!( - rb.column(0).as_int64_arr(), + rb.column_by_name("a").unwrap().as_int64_arr(), &Int64Array::from(vec![None, Some(1), Some(1)]) ); assert_eq!( - rb.column(1).as_utf8_arr(), + rb.column_by_name("b").unwrap().as_utf8_arr(), 
diff --git a/server/src/metadata.rs b/server/src/metadata.rs
index 6ee2cdeeb..a6724c378 100644
--- a/server/src/metadata.rs
+++ b/server/src/metadata.rs
@@ -17,7 +17,8 @@
  */

 use arrow_array::RecordBatch;
-use arrow_schema::Schema;
+use arrow_schema::{Field, Schema};
+use itertools::Itertools;
 use once_cell::sync::Lazy;
 use std::collections::HashMap;
 use std::sync::{Arc, RwLock};
@@ -37,21 +38,12 @@ pub static STREAM_INFO: Lazy<StreamInfo> = Lazy::new(StreamInfo::default);
 #[derive(Debug, Deref, DerefMut, Default)]
 pub struct StreamInfo(RwLock<HashMap<String, LogStreamMetadata>>);

-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct LogStreamMetadata {
-    pub schema: Arc<Schema>,
+    pub schema: HashMap<String, Field>,
     pub alerts: Alerts,
 }

-impl Default for LogStreamMetadata {
-    fn default() -> Self {
-        Self {
-            schema: Arc::new(Schema::empty()),
-            alerts: Alerts::default(),
-        }
-    }
-}
-
 // It is very unlikely that panic will occur when dealing with metadata.
 pub const LOCK_EXPECT: &str = "no method in metadata should panic while holding a lock";
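With `schema` stored as `HashMap<String, Field>`, per-field lookups and membership checks become constant-time map operations rather than scans over a `Schema`'s field list. A small standalone sketch (field names and types are arbitrary):

```rust
use std::collections::HashMap;

use arrow_schema::{DataType, Field};

fn main() {
    // Sketch: per-stream schema kept as a map keyed by field name.
    let mut schema: HashMap<String, Field> = HashMap::new();
    schema.insert("a".to_string(), Field::new("a", DataType::Int64, true));

    // Membership checks and single-field lookups are O(1) map
    // operations instead of linear scans over Schema::fields.
    assert!(schema.contains_key("a"));
    assert!(schema.get("b").is_none());
}
```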
@@ -95,7 +87,17 @@ impl StreamInfo {
             .ok_or(MetadataError::StreamMetaNotFound(stream_name.to_string()))
             .map(|metadata| &metadata.schema)?;

-        Ok(Arc::clone(schema))
+        // Sort fields on read from the hashmap, as their order can differ between reads.
+        // This provides a stable output order if the schema is unchanged between calls.
+        let fields = schema
+            .values()
+            .sorted_by_key(|field| field.name())
+            .cloned()
+            .collect();
+
+        let schema = Schema::new(fields);
+
+        Ok(Arc::new(schema))
     }

     pub fn set_alert(&self, stream_name: &str, alerts: Alerts) -> Result<(), MetadataError> {
@@ -130,8 +132,9 @@ impl StreamInfo {
         let alerts = storage.get_alerts(&stream.name).await?;
         let schema = storage.get_schema(&stream.name).await?;

-        let schema = Arc::new(update_schema_from_staging(&stream.name, schema));
-
+        let schema = update_schema_from_staging(&stream.name, schema);
+        let schema =
+            HashMap::from_iter(schema.fields.into_iter().map(|v| (v.name().to_owned(), v)));
         let metadata = LogStreamMetadata { schema, alerts };

         let mut map = self.write().expect(LOCK_EXPECT);
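Because `HashMap` iteration order is unspecified, any `Schema` materialized from the map must sort its fields to stay deterministic across reads. A standalone sketch of the sort-on-read pattern used in `schema()` above (the example fields are arbitrary):

```rust
use std::collections::HashMap;
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use itertools::Itertools;

fn main() {
    let map: HashMap<String, Field> = HashMap::from([
        ("b".to_string(), Field::new("b", DataType::Utf8, true)),
        ("a".to_string(), Field::new("a", DataType::Int64, true)),
    ]);

    // HashMap iteration order is unspecified, so sort fields by name
    // before materializing a Schema; repeated reads then agree.
    let fields: Vec<Field> = map.values().sorted_by_key(|f| f.name()).cloned().collect();
    let schema = Arc::new(Schema::new(fields));

    assert_eq!(schema.field(0).name(), "a");
    assert_eq!(schema.field(1).name(), "b");
}
```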
diff --git a/server/src/utils/arrow.rs b/server/src/utils/arrow.rs
index 6b8dff675..812151844 100644
--- a/server/src/utils/arrow.rs
+++ b/server/src/utils/arrow.rs
@@ -80,3 +80,7 @@ mod tests {
         assert_eq!(new_rb.num_rows(), 3)
     }
 }
+
+pub fn get_field<'a>(schema: &'a Schema, name: &str) -> Option<&'a arrow_schema::Field> {
+    schema.fields.iter().find(|field| field.name() == name)
+}
diff --git a/server/src/utils/arrow/batch_adapter.rs b/server/src/utils/arrow/batch_adapter.rs
index 3510547e2..5f3a91935 100644
--- a/server/src/utils/arrow/batch_adapter.rs
+++ b/server/src/utils/arrow/batch_adapter.rs
@@ -23,12 +23,18 @@ use datafusion::arrow::record_batch::RecordBatch;

 use std::sync::Arc;

+// This function takes a new event's record batch and the
+// current schema of the log stream. It returns a new record
+// batch with nulls added to the fields that don't exist
+// in the record batch (i.e. the event) but are present in
+// the log stream schema.
+// This is necessary because all the record batches in a log
+// stream need to have all the fields.
 pub fn adapt_batch(table_schema: &Schema, batch: RecordBatch) -> RecordBatch {
     let batch_schema = &*batch.schema();
-
-    let mut cols: Vec<ArrayRef> = Vec::with_capacity(table_schema.fields().len());
     let batch_cols = batch.columns().to_vec();

+    let mut cols: Vec<ArrayRef> = Vec::with_capacity(table_schema.fields().len());
     for table_field in table_schema.fields() {
         if let Some((batch_idx, _)) = batch_schema.column_with_name(table_field.name().as_str()) {
             cols.push(Arc::clone(&batch_cols[batch_idx]));
@@ -38,6 +44,5 @@ pub fn adapt_batch(table_schema: &Schema, batch: RecordBatch) -> RecordBatch {
     }

     let merged_schema = Arc::new(table_schema.clone());
-
     RecordBatch::try_new(merged_schema, cols).unwrap()
 }
diff --git a/server/src/utils/arrow/merged_reader.rs b/server/src/utils/arrow/merged_reader.rs
index 4c9feb5f7..3ce701193 100644
--- a/server/src/utils/arrow/merged_reader.rs
+++ b/server/src/utils/arrow/merged_reader.rs
@@ -25,31 +25,11 @@ use arrow_schema::Schema;
 use itertools::kmerge_by;

 use super::adapt_batch;
-
-#[derive(Debug)]
-pub struct Reader {
-    reader: StreamReader<File>,
-    timestamp_col_index: usize,
-}
-
-impl From<StreamReader<File>> for Reader {
-    fn from(reader: StreamReader<File>) -> Self {
-        let timestamp_col_index = reader
-            .schema()
-            .all_fields()
-            .binary_search_by(|field| field.name().as_str().cmp("p_timestamp"))
-            .expect("schema should have this field");
-
-        Self {
-            reader,
-            timestamp_col_index,
-        }
-    }
-}
+use crate::event::DEFAULT_TIMESTAMP_KEY;

 #[derive(Debug)]
 pub struct MergedRecordReader {
-    pub readers: Vec<Reader>,
+    pub readers: Vec<StreamReader<File>>,
 }

 impl MergedRecordReader {
@@ -58,47 +38,48 @@ impl MergedRecordReader {
         for file in files {
             let reader = StreamReader::try_new(File::open(file).unwrap(), None).map_err(|_| ())?;
-            readers.push(reader.into());
+            readers.push(reader);
         }

         Ok(Self { readers })
     }

     pub fn merged_iter(self, schema: &Schema) -> impl Iterator<Item = RecordBatch> + '_ {
-        let adapted_readers = self.readers.into_iter().map(move |reader| {
-            reader
-                .reader
-                .flatten()
-                .zip(std::iter::repeat(reader.timestamp_col_index))
-        });
-
-        kmerge_by(
-            adapted_readers,
-            |(a, a_col): &(RecordBatch, usize), (b, b_col): &(RecordBatch, usize)| {
-                let a: &TimestampMillisecondArray = a
-                    .column(*a_col)
-                    .as_any()
-                    .downcast_ref::<TimestampMillisecondArray>()
-                    .unwrap();
+        let adapted_readers = self.readers.into_iter().map(move |reader| reader.flatten());

-                let b: &TimestampMillisecondArray = b
-                    .column(*b_col)
-                    .as_any()
-                    .downcast_ref::<TimestampMillisecondArray>()
-                    .unwrap();
-
-                a.value(0) < b.value(0)
-            },
-        )
-        .map(|(batch, _)| adapt_batch(schema, batch))
+        kmerge_by(adapted_readers, |a: &RecordBatch, b: &RecordBatch| {
+            let a_time = get_timestamp_millis(a);
+            let b_time = get_timestamp_millis(b);
+            a_time < b_time
+        })
+        .map(|batch| adapt_batch(schema, batch))
     }

     pub fn merged_schema(&self) -> Schema {
         Schema::try_merge(
             self.readers
                 .iter()
-                .map(|reader| reader.reader.schema().as_ref().clone()),
+                .map(|reader| reader.schema().as_ref().clone()),
         )
         .unwrap()
     }
 }
+
+fn get_timestamp_millis(batch: &RecordBatch) -> i64 {
+    match batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<TimestampMillisecondArray>()
+    {
+        // Ideally we expect the first column to be a timestamp (because we add the timestamp column first in the writer)
+        Some(array) => array.value(0),
+        // In case the first column is not a timestamp, we fall back to looking up the default timestamp column by name
+        None => batch
+            .column_by_name(DEFAULT_TIMESTAMP_KEY)
+            .unwrap()
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .unwrap()
+            .value(0),
+    }
+}
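The merge strategy above relies on each staging file already being internally sorted by time: `kmerge_by` repeatedly takes, across all readers, the batch whose leading timestamp is smallest, producing one globally ordered stream. A toy standalone example of the same k-way merge, with plain integers standing in for record batches:

```rust
use itertools::kmerge_by;

fn main() {
    // Two "files", each already sorted by timestamp; plain i64s stand
    // in for record batches keyed by their first-row timestamp.
    let file_a = vec![1i64, 4, 7];
    let file_b = vec![2i64, 3, 9];

    // kmerge_by keeps pulling from whichever stream's next element is
    // smallest, yielding one globally sorted sequence.
    let merged: Vec<i64> = kmerge_by([file_a, file_b], |a: &i64, b: &i64| a < b).collect();
    assert_eq!(merged, vec![1, 2, 3, 4, 7, 9]);
}
```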