diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 7560e0c2d5..45859c5fb2 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -17,7 +17,7 @@
 
 use std::{
     any::Any,
-    fmt::{Display, Formatter},
+    fmt::{Debug, Display, Formatter},
     hash::{Hash, Hasher},
     sync::Arc,
 };
@@ -31,7 +31,8 @@ use arrow::{
 };
 use arrow_array::{
     types::{Int16Type, Int32Type, Int64Type, Int8Type},
-    Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
+    Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait,
+    PrimitiveArray,
 };
 use arrow_schema::{DataType, Schema};
 use chrono::{TimeZone, Timelike};
@@ -107,6 +108,81 @@ macro_rules! cast_utf8_to_timestamp {
     }};
 }
 
+/// Casts a float array to a string array using Spark's float-to-string formatting rules.
+///
+/// `$type` is the primitive float type (`f32`/`f64`), `$output_type` is the concrete Arrow
+/// array type that `$from` is downcast to (the expansion panics if the downcast fails), and
+/// `$offset_type` is the offset type of the produced string array. `$eval_mode` is accepted
+/// for parity with the other cast helpers but is not consulted yet.
+macro_rules! cast_float_to_string {
+    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
+
+        fn cast<OffsetSize>(
+            from: &dyn Array,
+            _eval_mode: EvalMode,
+        ) -> CometResult<ArrayRef>
+        where
+            OffsetSize: OffsetSizeTrait, {
+            let array = from.as_any().downcast_ref::<$output_type>().unwrap();
+
+            // If the absolute number is less than 10,000,000 and greater than or equal to 0.001,
+            // the result is expressed without scientific notation with at least one digit on
+            // either side of the decimal point. Otherwise, Spark uses a mantissa followed by E
+            // and an exponent. The mantissa has an optional leading minus sign followed by one
+            // digit to the left of the decimal point, and the minimal number of digits greater
+            // than zero to the right. The exponent has an optional leading minus sign.
+            // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html
+
+            const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
+            const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;
+
+            let output_array = array
+                .iter()
+                .map(|value| match value {
+                    Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
+                    Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
+                    Some(value)
+                        if (value.abs() < UPPER_SCIENTIFIC_BOUND
+                            && value.abs() >= LOWER_SCIENTIFIC_BOUND)
+                            || value.abs() == 0.0 =>
+                    {
+                        let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };
+
+                        Ok(Some(format!("{value}{trailing_zero}")))
+                    }
+                    Some(value)
+                        if value.abs() >= UPPER_SCIENTIFIC_BOUND
+                            || value.abs() < LOWER_SCIENTIFIC_BOUND =>
+                    {
+                        let formatted = format!("{value:E}");
+
+                        if formatted.contains('.') {
+                            Ok(Some(formatted))
+                        } else {
+                            // `formatted` is already in scientific notation and can be split up by E
+                            // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0
+                            let prepare_number: Vec<&str> = formatted.split('E').collect();
+
+                            let coefficient = prepare_number[0];
+
+                            let exponent = prepare_number[1];
+
+                            Ok(Some(format!("{coefficient}.0E{exponent}")))
+                        }
+                    }
+                    // Only NaN reaches this arm: every comparison above is false for it.
+                    Some(value) => Ok(Some(value.to_string())),
+                    _ => Ok(None),
+                })
+                .collect::<Result<GenericStringArray<OffsetSize>, CometError>>()?;
+
+            Ok(Arc::new(output_array))
+        }
+
+        cast::<$offset_type>($from, $eval_mode)
+    }};
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -185,6 +261,18 @@ impl Cast {
                     ),
                 }
             }
+            (DataType::Float64, DataType::Utf8) => {
+                Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)?
+            }
+            (DataType::Float64, DataType::LargeUtf8) => {
+                Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)?
+            }
+            (DataType::Float32, DataType::Utf8) => {
+                Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)?
+            }
+            (DataType::Float32, DataType::LargeUtf8) => {
+                Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
+            }
             _ => {
                 // when we have no Spark-specific casting we delegate to DataFusion
                 cast_with_options(&array, to_type, &CAST_OPTIONS)?
@@ -248,6 +336,28 @@ impl Cast {
         Ok(cast_array)
     }
 
+    /// Casts a `Float64Array` to a Spark-compatible string array with `OffsetSize` offsets.
+    fn spark_cast_float64_to_utf8<OffsetSize>(
+        from: &dyn Array,
+        _eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef>
+    where
+        OffsetSize: OffsetSizeTrait,
+    {
+        cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
+    }
+
+    /// Casts a `Float32Array` to a Spark-compatible string array with `OffsetSize` offsets.
+    fn spark_cast_float32_to_utf8<OffsetSize>(
+        from: &dyn Array,
+        _eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef>
+    where
+        OffsetSize: OffsetSizeTrait,
+    {
+        cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
+    }
+
     fn spark_cast_utf8_to_boolean<OffsetSize>(
         from: &dyn Array,
         eval_mode: EvalMode,
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index a31f4e6822..3be7dcb64f 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -329,9 +329,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
   }
 
-  ignore("cast FloatType to StringType") {
+  test("cast FloatType to StringType") {
     // https://github.com/apache/datafusion-comet/issues/312
-    castTest(generateFloats(), DataTypes.StringType)
+    val r = new Random(0)
+    val values = Seq(
+      Float.MaxValue,
+      Float.MinValue,
+      Float.NaN,
+      Float.PositiveInfinity,
+      Float.NegativeInfinity,
+      1.0f,
+      -1.0f,
+      Short.MinValue.toFloat,
+      Short.MaxValue.toFloat,
+      0.0f) ++
+      Range(0, dataSize).map(_ => r.nextFloat())
+    castTest(withNulls(values).toDF("a"), DataTypes.StringType)
   }
 
   ignore("cast FloatType to TimestampType") {
@@ -374,9 +387,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
   }
 
-  ignore("cast DoubleType to StringType") {
+  test("cast DoubleType to StringType") {
     // https://github.com/apache/datafusion-comet/issues/312
-    castTest(generateDoubles(), DataTypes.StringType)
+    val r = new Random(0)
+    val values = Seq(
+      Double.MaxValue,
+      Double.MinValue,
+      Double.NaN,
+      Double.PositiveInfinity,
+      Double.NegativeInfinity,
+      0.0d) ++
+      Range(0, dataSize).map(_ => r.nextDouble())
+    castTest(withNulls(values).toDF("a"), DataTypes.StringType)
   }
 
   ignore("cast DoubleType to TimestampType") {