Skip to content

Commit 8298695

Browse files
vaibhawvipulSteve Vaughan Jr
authored andcommitted
feat: Implement Spark-compatible CAST from string to timestamp types (apache#335)
* casting str to timestamp * fix format * fixing failed tests, using char as pattern * bug fixes * hangling microsecond * make format * bug fixes and core refactor * format code * removing print statements * clippy error * enabling cast timestamp test case * code refactor * comet spark test case * adding all the supported format in test * fallback spark when timestamp not utc * bug fix * bug fix * adding an explainer commit * fix test case * bug fix * bug fix * better error handling for unwrap in fn parse_str_to_time_only_timestamp * remove unwrap from macro * improving error handling * adding tests for invalid inputs * removed all unwraps from timestamp cast functions * code format
1 parent b48f8f3 commit 8298695

File tree

3 files changed

+391
-8
lines changed

3 files changed

+391
-8
lines changed

core/src/execution/datafusion/expressions/cast.rs

Lines changed: 314 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use std::{
2525
use crate::errors::{CometError, CometResult};
2626
use arrow::{
2727
compute::{cast_with_options, CastOptions},
28+
datatypes::TimestampMicrosecondType,
2829
record_batch::RecordBatch,
2930
util::display::FormatOptions,
3031
};
@@ -33,10 +34,12 @@ use arrow_array::{
3334
Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
3435
};
3536
use arrow_schema::{DataType, Schema};
37+
use chrono::{TimeZone, Timelike};
3638
use datafusion::logical_expr::ColumnarValue;
3739
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
3840
use datafusion_physical_expr::PhysicalExpr;
3941
use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
42+
use regex::Regex;
4043

4144
use crate::execution::datafusion::expressions::utils::{
4245
array_with_timezone, down_cast_any_ref, spark_cast,
@@ -86,6 +89,24 @@ macro_rules! cast_utf8_to_int {
8689
}};
8790
}
8891

92+
macro_rules! cast_utf8_to_timestamp {
93+
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
94+
let len = $array.len();
95+
let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
96+
for i in 0..len {
97+
if $array.is_null(i) {
98+
cast_array.append_null()
99+
} else if let Ok(Some(cast_value)) = $cast_method($array.value(i).trim(), $eval_mode) {
100+
cast_array.append_value(cast_value);
101+
} else {
102+
cast_array.append_null()
103+
}
104+
}
105+
let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
106+
result
107+
}};
108+
}
109+
89110
impl Cast {
90111
pub fn new(
91112
child: Arc<dyn PhysicalExpr>,
@@ -125,6 +146,9 @@ impl Cast {
125146
(DataType::LargeUtf8, DataType::Boolean) => {
126147
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
127148
}
149+
(DataType::Utf8, DataType::Timestamp(_, _)) => {
150+
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
151+
}
128152
(
129153
DataType::Utf8,
130154
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -200,6 +224,30 @@ impl Cast {
200224
Ok(cast_array)
201225
}
202226

227+
fn cast_string_to_timestamp(
228+
array: &ArrayRef,
229+
to_type: &DataType,
230+
eval_mode: EvalMode,
231+
) -> CometResult<ArrayRef> {
232+
let string_array = array
233+
.as_any()
234+
.downcast_ref::<GenericStringArray<i32>>()
235+
.expect("Expected a string array");
236+
237+
let cast_array: ArrayRef = match to_type {
238+
DataType::Timestamp(_, _) => {
239+
cast_utf8_to_timestamp!(
240+
string_array,
241+
eval_mode,
242+
TimestampMicrosecondType,
243+
timestamp_parser
244+
)
245+
}
246+
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
247+
};
248+
Ok(cast_array)
249+
}
250+
203251
fn spark_cast_utf8_to_boolean<OffsetSize>(
204252
from: &dyn Array,
205253
eval_mode: EvalMode,
@@ -510,9 +558,273 @@ impl PhysicalExpr for Cast {
510558
}
511559
}
512560

561+
fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
562+
let value = value.trim();
563+
if value.is_empty() {
564+
return Ok(None);
565+
}
566+
// Define regex patterns and corresponding parsing functions
567+
let patterns = &[
568+
(
569+
Regex::new(r"^\d{4}$").unwrap(),
570+
parse_str_to_year_timestamp as fn(&str) -> CometResult<Option<i64>>,
571+
),
572+
(
573+
Regex::new(r"^\d{4}-\d{2}$").unwrap(),
574+
parse_str_to_month_timestamp,
575+
),
576+
(
577+
Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(),
578+
parse_str_to_day_timestamp,
579+
),
580+
(
581+
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
582+
parse_str_to_hour_timestamp,
583+
),
584+
(
585+
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
586+
parse_str_to_minute_timestamp,
587+
),
588+
(
589+
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
590+
parse_str_to_second_timestamp,
591+
),
592+
(
593+
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
594+
parse_str_to_microsecond_timestamp,
595+
),
596+
(
597+
Regex::new(r"^T\d{1,2}$").unwrap(),
598+
parse_str_to_time_only_timestamp,
599+
),
600+
];
601+
602+
let mut timestamp = None;
603+
604+
// Iterate through patterns and try matching
605+
for (pattern, parse_func) in patterns {
606+
if pattern.is_match(value) {
607+
timestamp = parse_func(value)?;
608+
break;
609+
}
610+
}
611+
612+
if timestamp.is_none() {
613+
if eval_mode == EvalMode::Ansi {
614+
return Err(CometError::CastInvalidValue {
615+
value: value.to_string(),
616+
from_type: "STRING".to_string(),
617+
to_type: "TIMESTAMP".to_string(),
618+
});
619+
} else {
620+
return Ok(None);
621+
}
622+
}
623+
624+
match timestamp {
625+
Some(ts) => Ok(Some(ts)),
626+
None => Err(CometError::Internal(
627+
"Failed to parse timestamp".to_string(),
628+
)),
629+
}
630+
}
631+
632+
fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult<Option<i64>> {
633+
let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0);
634+
635+
// Check if datetime is not None
636+
let utc_datetime = match datetime.single() {
637+
Some(dt) => dt.with_timezone(&chrono::Utc),
638+
None => {
639+
return Err(CometError::Internal(
640+
"Failed to parse timestamp".to_string(),
641+
));
642+
}
643+
};
644+
645+
Ok(Some(utc_datetime.timestamp_micros()))
646+
}
647+
648+
fn parse_hms_timestamp(
649+
year: i32,
650+
month: u32,
651+
day: u32,
652+
hour: u32,
653+
minute: u32,
654+
second: u32,
655+
microsecond: u32,
656+
) -> CometResult<Option<i64>> {
657+
let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second);
658+
659+
// Check if datetime is not None
660+
let utc_datetime = match datetime.single() {
661+
Some(dt) => dt
662+
.with_timezone(&chrono::Utc)
663+
.with_nanosecond(microsecond * 1000),
664+
None => {
665+
return Err(CometError::Internal(
666+
"Failed to parse timestamp".to_string(),
667+
));
668+
}
669+
};
670+
671+
let result = match utc_datetime {
672+
Some(dt) => dt.timestamp_micros(),
673+
None => {
674+
return Err(CometError::Internal(
675+
"Failed to parse timestamp".to_string(),
676+
));
677+
}
678+
};
679+
680+
Ok(Some(result))
681+
}
682+
683+
fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option<i64>> {
684+
let values: Vec<_> = value
685+
.split(|c| c == 'T' || c == '-' || c == ':' || c == '.')
686+
.collect();
687+
let year = values[0].parse::<i32>().unwrap_or_default();
688+
let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
689+
let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
690+
let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
691+
let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
692+
let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
693+
let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));
694+
695+
match timestamp_type {
696+
"year" => parse_ymd_timestamp(year, 1, 1),
697+
"month" => parse_ymd_timestamp(year, month, 1),
698+
"day" => parse_ymd_timestamp(year, month, day),
699+
"hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0),
700+
"minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0),
701+
"second" => parse_hms_timestamp(year, month, day, hour, minute, second, 0),
702+
"microsecond" => parse_hms_timestamp(year, month, day, hour, minute, second, microsecond),
703+
_ => Err(CometError::CastInvalidValue {
704+
value: value.to_string(),
705+
from_type: "STRING".to_string(),
706+
to_type: "TIMESTAMP".to_string(),
707+
}),
708+
}
709+
}
710+
711+
fn parse_str_to_year_timestamp(value: &str) -> CometResult<Option<i64>> {
712+
get_timestamp_values(value, "year")
713+
}
714+
715+
fn parse_str_to_month_timestamp(value: &str) -> CometResult<Option<i64>> {
716+
get_timestamp_values(value, "month")
717+
}
718+
719+
fn parse_str_to_day_timestamp(value: &str) -> CometResult<Option<i64>> {
720+
get_timestamp_values(value, "day")
721+
}
722+
723+
fn parse_str_to_hour_timestamp(value: &str) -> CometResult<Option<i64>> {
724+
get_timestamp_values(value, "hour")
725+
}
726+
727+
fn parse_str_to_minute_timestamp(value: &str) -> CometResult<Option<i64>> {
728+
get_timestamp_values(value, "minute")
729+
}
730+
731+
fn parse_str_to_second_timestamp(value: &str) -> CometResult<Option<i64>> {
732+
get_timestamp_values(value, "second")
733+
}
734+
735+
fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult<Option<i64>> {
736+
get_timestamp_values(value, "microsecond")
737+
}
738+
739+
fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
740+
let values: Vec<&str> = value.split('T').collect();
741+
let time_values: Vec<u32> = values[1]
742+
.split(':')
743+
.map(|v| v.parse::<u32>().unwrap_or(0))
744+
.collect();
745+
746+
let datetime = chrono::Utc::now();
747+
let timestamp = datetime
748+
.with_hour(time_values.first().copied().unwrap_or_default())
749+
.and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
750+
.and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
751+
.and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000))
752+
.map(|dt| dt.to_utc().timestamp_micros())
753+
.unwrap_or_default();
754+
755+
Ok(Some(timestamp))
756+
}
757+
513758
#[cfg(test)]
514-
mod test {
515-
use super::{cast_string_to_i8, EvalMode};
759+
mod tests {
760+
use super::*;
761+
use arrow::datatypes::TimestampMicrosecondType;
762+
use arrow_array::StringArray;
763+
use arrow_schema::TimeUnit;
764+
765+
#[test]
766+
fn timestamp_parser_test() {
767+
// write for all formats
768+
assert_eq!(
769+
timestamp_parser("2020", EvalMode::Legacy).unwrap(),
770+
Some(1577836800000000) // this is in milliseconds
771+
);
772+
assert_eq!(
773+
timestamp_parser("2020-01", EvalMode::Legacy).unwrap(),
774+
Some(1577836800000000)
775+
);
776+
assert_eq!(
777+
timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(),
778+
Some(1577836800000000)
779+
);
780+
assert_eq!(
781+
timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(),
782+
Some(1577880000000000)
783+
);
784+
assert_eq!(
785+
timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(),
786+
Some(1577882040000000)
787+
);
788+
assert_eq!(
789+
timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(),
790+
Some(1577882096000000)
791+
);
792+
assert_eq!(
793+
timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy).unwrap(),
794+
Some(1577882096123456)
795+
);
796+
// assert_eq!(
797+
// timestamp_parser("T2", EvalMode::Legacy).unwrap(),
798+
// Some(1714356000000000) // this value needs to change everyday.
799+
// );
800+
}
801+
802+
#[test]
803+
fn test_cast_string_to_timestamp() {
804+
let array: ArrayRef = Arc::new(StringArray::from(vec![
805+
Some("2020-01-01T12:34:56.123456"),
806+
Some("T2"),
807+
]));
808+
809+
let string_array = array
810+
.as_any()
811+
.downcast_ref::<GenericStringArray<i32>>()
812+
.expect("Expected a string array");
813+
814+
let eval_mode = EvalMode::Legacy;
815+
let result = cast_utf8_to_timestamp!(
816+
&string_array,
817+
eval_mode,
818+
TimestampMicrosecondType,
819+
timestamp_parser
820+
);
821+
822+
assert_eq!(
823+
result.data_type(),
824+
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
825+
);
826+
assert_eq!(result.len(), 2);
827+
}
516828

517829
#[test]
518830
fn test_cast_string_as_i8() {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
585585
// Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
586586
evalMode.toString
587587
}
588+
589+
val supportedTimezone = (child.dataType, dt) match {
590+
case (DataTypes.StringType, DataTypes.TimestampType)
591+
if !timeZoneId.contains("UTC") =>
592+
withInfo(expr, s"Unsupported timezone ${timeZoneId} for timestamp cast")
593+
false
594+
case _ => true
595+
}
596+
588597
val supportedCast = (child.dataType, dt) match {
589598
case (DataTypes.StringType, DataTypes.TimestampType)
590599
if !CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get() =>
@@ -593,7 +602,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
593602
false
594603
case _ => true
595604
}
596-
if (supportedCast) {
605+
606+
if (supportedCast && supportedTimezone) {
597607
castToProto(timeZoneId, dt, childExpr, evalModeStr)
598608
} else {
599609
// no need to call withInfo here since it was called when determining

0 commit comments

Comments
 (0)