Skip to content

fix: data type for static schema #1235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 105 additions & 46 deletions src/event/format/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use anyhow::anyhow;
use arrow_array::RecordBatch;
use arrow_json::reader::{infer_json_schema_from_iterator, ReaderBuilder};
use arrow_schema::{DataType, Field, Fields, Schema};
use chrono::{DateTime, NaiveDateTime, Utc};
use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc};
use datafusion::arrow::util::bit_util::round_upto_multiple_of_64;
use itertools::Itertools;
use serde_json::Value;
Expand Down Expand Up @@ -62,6 +62,7 @@ impl EventFormat for Event {
schema: &HashMap<String, Arc<Field>>,
time_partition: Option<&String>,
schema_version: SchemaVersion,
static_schema_flag: bool,
) -> Result<(Self::Data, Vec<Arc<Field>>, bool), anyhow::Error> {
let stream_schema = schema;

Expand Down Expand Up @@ -111,7 +112,7 @@ impl EventFormat for Event {

if value_arr
.iter()
.any(|value| fields_mismatch(&schema, value, schema_version))
.any(|value| fields_mismatch(&schema, value, schema_version, static_schema_flag))
{
return Err(anyhow!(
"Could not process this event due to mismatch in datatype"
Expand Down Expand Up @@ -253,73 +254,131 @@ fn collect_keys<'a>(values: impl Iterator<Item = &'a Value>) -> Result<Vec<&'a s
Ok(keys)
}

fn fields_mismatch(schema: &[Arc<Field>], body: &Value, schema_version: SchemaVersion) -> bool {
fn fields_mismatch(
schema: &[Arc<Field>],
body: &Value,
schema_version: SchemaVersion,
static_schema_flag: bool,
) -> bool {
for (name, val) in body.as_object().expect("body is of object variant") {
if val.is_null() {
continue;
}
let Some(field) = get_field(schema, name) else {
return true;
};
if !valid_type(field.data_type(), val, schema_version) {
if !valid_type(field, val, schema_version, static_schema_flag) {
return true;
}
}
false
}

fn valid_type(data_type: &DataType, value: &Value, schema_version: SchemaVersion) -> bool {
match data_type {
fn valid_type(
field: &Field,
value: &Value,
schema_version: SchemaVersion,
static_schema_flag: bool,
) -> bool {
match field.data_type() {
DataType::Boolean => value.is_boolean(),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => value.is_i64(),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
validate_int(value, static_schema_flag)
}
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => value.is_u64(),
DataType::Float16 | DataType::Float32 => value.is_f64(),
// All numbers can be cast as Float64 from schema version v1
DataType::Float64 if schema_version == SchemaVersion::V1 => value.is_number(),
DataType::Float64 if schema_version != SchemaVersion::V1 => value.is_f64(),
DataType::Float64 => validate_float(value, schema_version, static_schema_flag),
DataType::Utf8 => value.is_string(),
DataType::List(field) => {
let data_type = field.data_type();
if let Value::Array(arr) = value {
for elem in arr {
if elem.is_null() {
continue;
}
if !valid_type(data_type, elem, schema_version) {
return false;
}
}
}
true
}
DataType::List(field) => validate_list(field, value, schema_version, static_schema_flag),
DataType::Struct(fields) => {
if let Value::Object(val) = value {
for (key, value) in val {
let field = (0..fields.len())
.find(|idx| fields[*idx].name() == key)
.map(|idx| &fields[idx]);

if let Some(field) = field {
if value.is_null() {
continue;
}
if !valid_type(field.data_type(), value, schema_version) {
return false;
}
} else {
return false;
}
}
true
} else {
false
validate_struct(fields, value, schema_version, static_schema_flag)
}
DataType::Date32 => {
if let Value::String(s) = value {
return NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok();
}
false
}
DataType::Timestamp(_, _) => value.is_string() || value.is_number(),
_ => {
error!("Unsupported datatype {:?}, value {:?}", data_type, value);
unreachable!()
error!(
"Unsupported datatype {:?}, value {:?}",
field.data_type(),
value
);
false
}
}
}

/// Reports whether `value` is acceptable for a signed integer column.
///
/// When `static_schema_flag` is set, a JSON string is accepted if its
/// whitespace-trimmed text parses as an `i64`. Every other value (and every
/// value when the flag is unset) must satisfy `Value::is_i64`.
fn validate_int(value: &Value, static_schema_flag: bool) -> bool {
    match value {
        // static schema: numeric strings may be cast to integers
        Value::String(text) if static_schema_flag => text.trim().parse::<i64>().is_ok(),
        other => other.is_i64(),
    }
}

/// Reports whether `value` is acceptable for a `Float64` column.
///
/// * With `static_schema_flag` set, a JSON string is accepted when its
///   whitespace-trimmed text parses as an `f64`; any non-string value must be
///   a JSON number.
/// * Otherwise the decision depends on `schema_version`: `SchemaVersion::V1`
///   accepts any JSON number, every other version requires `Value::is_f64`.
fn validate_float(value: &Value, schema_version: SchemaVersion, static_schema_flag: bool) -> bool {
    if static_schema_flag {
        // allow casting string to float for static schema
        // Borrow the string instead of cloning the whole `Value`.
        if let Value::String(s) = value {
            // Integer literals ("42", "-7", "+3") are valid `f64` input, so a
            // separate `i64` parse is not needed.
            // NOTE(review): `f64::from_str` also accepts "inf", "infinity" and
            // "nan" — confirm these should pass validation.
            return s.trim().parse::<f64>().is_ok();
        }
        return value.is_number();
    }
    match schema_version {
        SchemaVersion::V1 => value.is_number(),
        _ => value.is_f64(),
    }
}

/// Reports whether `value` is acceptable for a list column whose element
/// field is `field`.
///
/// Each non-null element of a JSON array is checked with `valid_type` against
/// `field`; the first failing element makes the result `false`. A value that
/// is not a JSON array is accepted unchanged (returns `true`).
fn validate_list(
    field: &Field,
    value: &Value,
    schema_version: SchemaVersion,
    static_schema_flag: bool,
) -> bool {
    let Value::Array(items) = value else {
        // non-array values are not rejected here
        return true;
    };
    items
        .iter()
        .filter(|item| !item.is_null())
        .all(|item| valid_type(field, item, schema_version, static_schema_flag))
}

/// Reports whether `value` is acceptable for a struct column made of `fields`.
///
/// `value` must be a JSON object. Every key must name one of `fields`; a key
/// with no matching field makes the result `false`. Null entries are skipped,
/// and every other entry is checked with `valid_type` against its field.
fn validate_struct(
    fields: &Fields,
    value: &Value,
    schema_version: SchemaVersion,
    static_schema_flag: bool,
) -> bool {
    let Value::Object(entries) = value else {
        return false;
    };
    entries.iter().all(|(key, entry)| {
        match fields.iter().find(|candidate| candidate.name() == key) {
            // key is not part of the struct definition
            None => false,
            // nulls are allowed for any known field
            Some(_) if entry.is_null() => true,
            Some(candidate) => valid_type(candidate, entry, schema_version, static_schema_flag),
        }
    })
}

Expand Down
9 changes: 7 additions & 2 deletions src/event/format/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ pub trait EventFormat: Sized {
schema: &HashMap<String, Arc<Field>>,
time_partition: Option<&String>,
schema_version: SchemaVersion,
static_schema_flag: bool,
) -> Result<(Self::Data, EventSchema, bool), AnyError>;

fn decode(data: Self::Data, schema: Arc<Schema>) -> Result<RecordBatch, AnyError>;
Expand All @@ -117,8 +118,12 @@ pub trait EventFormat: Sized {
schema_version: SchemaVersion,
) -> Result<(RecordBatch, bool), AnyError> {
let p_timestamp = self.get_p_timestamp();
let (data, mut schema, is_first) =
self.to_data(storage_schema, time_partition, schema_version)?;
let (data, mut schema, is_first) = self.to_data(
storage_schema,
time_partition,
schema_version,
static_schema_flag,
)?;

if get_field(&schema, DEFAULT_TIMESTAMP_KEY).is_some() {
return Err(anyhow!(
Expand Down
1 change: 1 addition & 0 deletions src/query/stream_schema_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ fn cast_or_none(scalar: &ScalarValue) -> Option<CastRes<'_>> {
ScalarValue::UInt32(val) => val.map(|val| CastRes::Int(val as i64)),
ScalarValue::UInt64(val) => val.map(|val| CastRes::Int(val as i64)),
ScalarValue::Utf8(val) => val.as_ref().map(|val| CastRes::String(val)),
ScalarValue::Date32(val) => val.map(|val| CastRes::Int(val as i64)),
ScalarValue::TimestampMillisecond(val, _) => val.map(CastRes::Int),
_ => None,
}
Expand Down
1 change: 1 addition & 0 deletions src/static_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ pub fn convert_static_schema_to_arrow_schema(
"boolean" => DataType::Boolean,
"string" => DataType::Utf8,
"datetime" => DataType::Timestamp(TimeUnit::Millisecond, None),
"date" => DataType::Date32,
"string_list" => {
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))
}
Expand Down
Loading