diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 4b863927e6..393f57662e 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -20,8 +20,8 @@ use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mu
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
-    spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
+    spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round,
+    spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
     SparkDateTrunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
@@ -114,6 +114,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(spark_rpad);
             make_comet_scalar_udf!("rpad", func, without data_type)
         }
+        "lpad" => {
+            let func = Arc::new(spark_lpad);
+            make_comet_scalar_udf!("lpad", func, without data_type)
+        }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
         }
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
index 0a8d8f3c55..5bb94a7ad5 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
@@ -17,4 +17,4 @@
 
 mod read_side_padding;
 
-pub use read_side_padding::{spark_read_side_padding, spark_rpad};
+pub use read_side_padding::{spark_lpad, spark_read_side_padding, spark_rpad};
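The hunks above wire the new expression into the native function registry: spark_lpad is imported alongside spark_rpad, and the "lpad" dispatch arm builds the scalar UDF exactly the way "rpad" does, using `without data_type` (no return type is fixed at registration; the result type follows the string input). For orientation, here is a minimal sketch of what the native entry point then receives. This is not part of the PR; the arrow/datafusion paths are the commonly used re-exports, and the crate path for spark_lpad is an assumption:

use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_lpad; // assumed crate path, adjust to the workspace layout

// Exercises the `[Array, Scalar(Int32)]` arm of spark_read_side_padding2 (next file).
fn lpad_two_arg_demo() -> Result<()> {
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["ab", "abcdef"]));
    let args = [
        ColumnarValue::Array(strings),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(4))),
    ];
    // Pads on the left with spaces; unlike read-side padding (truncate = false),
    // lpad truncates longer inputs: "ab" -> "  ab", "abcdef" -> "abcd".
    let _padded = spark_lpad(&args)?;
    Ok(())
}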
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index 166bb6ddf9..d969b6279b 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -28,17 +28,23 @@ use std::sync::Arc;
 const SPACE: &str = " ";
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
 pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    spark_read_side_padding2(args, false)
+    spark_read_side_padding2(args, false, false)
 }
 
 /// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
 pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    spark_read_side_padding2(args, true)
+    spark_read_side_padding2(args, true, false)
+}
+
+/// Custom `lpad` because DataFusion's `lpad` has differences in unicode handling
+pub fn spark_lpad(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    spark_read_side_padding2(args, true, true)
 }
 
 fn spark_read_side_padding2(
     args: &[ColumnarValue],
     truncate: bool,
+    is_left_pad: bool,
 ) -> Result<ColumnarValue> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
@@ -48,12 +54,14 @@ fn spark_read_side_padding2(
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     SPACE,
+                    is_left_pad,
                 ),
                 DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                     array,
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     SPACE,
+                    is_left_pad,
                 ),
                 // Dictionary support required for SPARK-48498
                 DataType::Dictionary(_, value_type) => {
@@ -64,6 +72,7 @@ fn spark_read_side_padding2(
                             truncate,
                             ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                             SPACE,
+                            is_left_pad,
                         )?
                     } else {
                         spark_read_side_padding_internal::<i64>(
@@ -71,6 +80,7 @@ fn spark_read_side_padding2(
                             truncate,
                             ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                             SPACE,
+                            is_left_pad,
                         )?
                     };
                     // col consists of an array, so arg of to_array() is not used. Can be anything
@@ -91,12 +101,14 @@ fn spark_read_side_padding2(
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     string,
+                    is_left_pad,
                 ),
                 DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                     array,
                     truncate,
                     ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                     string,
+                    is_left_pad,
                 ),
                 // Dictionary support required for SPARK-48498
                 DataType::Dictionary(_, value_type) => {
@@ -107,6 +119,7 @@ fn spark_read_side_padding2(
                             truncate,
                             ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                             SPACE,
+                            is_left_pad,
                         )?
                     } else {
                         spark_read_side_padding_internal::<i64>(
@@ -114,6 +127,7 @@ fn spark_read_side_padding2(
                             truncate,
                             ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
                             SPACE,
+                            is_left_pad,
                         )?
                     };
                     // col consists of an array, so arg of to_array() is not used. Can be anything
@@ -122,7 +136,7 @@ fn spark_read_side_padding2(
                     Ok(ColumnarValue::Array(make_array(result.into())))
                 }
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad/read_side_padding",
+                    "Unsupported data type {other:?} for function rpad/lpad/read_side_padding",
                 ))),
             }
         }
@@ -132,15 +146,17 @@ fn spark_read_side_padding2(
                 truncate,
                 ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                 SPACE,
+                is_left_pad,
             ),
             DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                 array,
                 truncate,
                 ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                 SPACE,
+                is_left_pad,
             ),
             other => Err(DataFusionError::Internal(format!(
-                "Unsupported data type {other:?} for function rpad/read_side_padding",
+                "Unsupported data type {other:?} for function rpad/lpad/read_side_padding",
             ))),
         },
         [ColumnarValue::Array(array), ColumnarValue::Array(array_int), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] => {
@@ -150,12 +166,14 @@ fn spark_read_side_padding2(
                     truncate,
                     ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                     string,
+                    is_left_pad,
                 ),
                 DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
                     array,
                     truncate,
                     ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
                     string,
+                    is_left_pad,
                 ),
                 other => Err(DataFusionError::Internal(format!(
                     "Unsupported data type {other:?} for function rpad/read_side_padding",
@@ -163,7 +181,7 @@ fn spark_read_side_padding2(
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function rpad/read_side_padding",
+            "Unsupported arguments {other:?} for function rpad/lpad/read_side_padding",
         ))),
     }
 }
@@ -173,6 +191,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     truncate: bool,
     pad_type: ColumnarValue,
     pad_string: &str,
+    is_left_pad: bool,
 ) -> Result<ColumnarValue> {
     let string_array = as_generic_string_array::<T>(array)?;
     match pad_type {
         ColumnarValue::Array(array_int) => {
@@ -191,6 +210,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
                         length.unwrap() as usize,
                         truncate,
                         pad_string,
+                        is_left_pad,
                     )?),
                     _ => builder.append_null(),
                 }
@@ -212,6 +232,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
                         length,
                         truncate,
                         pad_string,
+                        is_left_pad,
                     )?),
                     _ => builder.append_null(),
                 }
@@ -226,6 +247,7 @@ fn add_padding_string(
     length: usize,
     truncate: bool,
     pad_string: &str,
+    is_left_pad: bool,
 ) -> Result<String> {
     // It looks Spark's UTF8String is closer to chars rather than graphemes
     // https://stackoverflow.com/a/46290728
@@ -250,6 +272,14 @@ fn add_padding_string(
     } else {
         let pad_needed = length - char_len;
         let pad: String = pad_string.chars().cycle().take(pad_needed).collect();
-        Ok(string + &pad)
+        let mut result = String::with_capacity(string.len() + pad.len());
+        if is_left_pad {
+            result.push_str(&pad);
+            result.push_str(&string);
+        } else {
+            result.push_str(&string);
+            result.push_str(&pad);
+        }
+        Ok(result)
     }
 }
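All of the plumbing above exists to thread a single is_left_pad flag down to add_padding_string, where the only behavioral difference is whether the cycled pad goes before or after the input. Two details are easy to miss: padding length is measured in chars (Unicode code points), not bytes, and a multi-character pad is cycled char by char. The following self-contained sketch captures that rule; it is std-only, and this lpad is a simplified stand-in for the real helper with truncate and is_left_pad fixed to true:

// Simplified stand-in for add_padding_string(truncate = true, is_left_pad = true).
fn lpad(s: &str, length: usize, pad: &str) -> String {
    // Spark's UTF8String semantics: count chars (code points), not bytes.
    let char_len = s.chars().count();
    if char_len >= length {
        // lpad/rpad truncate; read-side padding (truncate = false) returns s unchanged.
        return s.chars().take(length).collect();
    }
    let pad_needed = length - char_len;
    // Cycle the pad string char by char until the deficit is filled.
    let padding: String = pad.chars().cycle().take(pad_needed).collect();
    let mut result = String::with_capacity(padding.len() + s.len());
    result.push_str(&padding); // left pad: padding first, then the input
    result.push_str(s);
    result
}

fn main() {
    assert_eq!(lpad("hi", 5, "ab"), "abahi"); // multi-char pad cycles: a, b, a
    assert_eq!(lpad("ab", 4, " "), "  ab");
    // Six code points cut back to two; a byte-based slice could split a code point.
    assert_eq!(lpad("తెలుగు", 2, " "), "తె");
}

The Telugu case mirrors the new test data further down: the cut happens after two chars, a slice that byte indexing could not make safely.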
diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs
index 39735f1569..6a2176b5f9 100644
--- a/native/spark-expr/src/static_invoke/mod.rs
+++ b/native/spark-expr/src/static_invoke/mod.rs
@@ -17,4 +17,4 @@
 
 mod char_varchar_utils;
 
-pub use char_varchar_utils::{spark_read_side_padding, spark_rpad};
+pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 892d8bca63..bb05015c20 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -175,6 +175,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[StringRepeat] -> CometStringRepeat,
     classOf[StringReplace] -> CometScalarFunction("replace"),
     classOf[StringRPad] -> CometStringRPad,
+    classOf[StringLPad] -> CometStringLPad,
     classOf[StringSpace] -> CometScalarFunction("string_space"),
     classOf[StringTranslate] -> CometScalarFunction("translate"),
     classOf[StringTrim] -> CometScalarFunction("trim"),
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index 36df9ed1c2..9c85d8d6ca 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringRepeat, StringRPad, Substring, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
 import org.apache.spark.sql.types.{DataTypes, LongType, StringType}
 
 import org.apache.comet.CometConf
@@ -168,6 +168,35 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
   }
 }
 
+object CometStringLPad extends CometExpressionSerde[StringLPad] {
+
+  /**
+   * Convert a Spark expression into a protocol buffer representation that can be passed into
+   * native code.
+   *
+   * @param expr
+   *   The Spark expression.
+   * @param inputs
+   *   The input attributes.
+   * @param binding
+   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
+   * @return
+   *   Protocol buffer representation, or None if the expression could not be converted. In this
+   *   case it is expected that the input expression will have been tagged with reasons why it
+   *   could not be converted.
+   */
+  override def convert(
+      expr: StringLPad,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    scalarFunctionExprToProto(
+      "lpad",
+      exprToProtoInternal(expr.str, inputs, binding),
+      exprToProtoInternal(expr.len, inputs, binding),
+      exprToProtoInternal(expr.pad, inputs, binding))
+  }
+}
+
 trait CommonStringExprs {
 
   def stringDecode(
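On the JVM side, CometStringLPad mirrors CometStringRPad: StringLPad(str, len, pad) is serialized as the scalar function "lpad" with its three children in that order, so a literal pad string arrives in the `[Array, Scalar(Int32), Scalar(Utf8)]` arm of spark_read_side_padding2, while a non-literal length selects one of the `[Array, Array]` arms. Continuing the earlier sketch (same assumed imports, spark_lpad crate path still an assumption), the three-argument form looks like this on the native side:

use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_lpad; // assumed crate path

// Exercises the `[Array, Scalar(Int32), Scalar(Utf8)]` arm with a custom pad.
fn lpad_three_arg_demo() -> Result<()> {
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["hi"]));
    let args = [
        ColumnarValue::Array(strings),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(5))),
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("??".to_string()))),
    ];
    // "hi" -> "???hi": three pad chars cycled out of "??".
    let _padded = spark_lpad(&args)?;
    Ok(())
}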
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 07663ea91f..f391d52f78 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -431,6 +431,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("lpad expression support") {
+    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
+    withParquetTable(data, "t1") {
+      val res = sql("select lpad(_1, _2), lpad(_1, 2) from t1 order by _1")
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
+  test("lpad with pad characters other than the default space") {
+    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
+    withParquetTable(data, "t1") {
+      val res = sql(
+        """select lpad(_1, _2, '?'), lpad(_1, _2, '??'), lpad(_1, 2, '??'),
+          |  hex(lpad(unhex('aabb'), 5)), rpad(_1, 5, '??') from t1 order by _1""".stripMargin)
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {