8 changes: 6 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -20,8 +20,8 @@ use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mu
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round,
spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
SparkDateTrunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
@@ -114,6 +114,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_rpad);
make_comet_scalar_udf!("rpad", func, without data_type)
}
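// lpad reuses the shared padding kernel (see read_side_padding.rs below) with left-padding enabled.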
"lpad" => {
let func = Arc::new(spark_lpad);
make_comet_scalar_udf!("lpad", func, without data_type)
}
"round" => {
make_comet_scalar_udf!("round", spark_round, data_type)
}
native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
@@ -17,4 +17,4 @@

mod read_side_padding;

pub use read_side_padding::{spark_read_side_padding, spark_rpad};
pub use read_side_padding::{spark_lpad, spark_read_side_padding, spark_rpad};
native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -28,17 +28,23 @@ use std::sync::Arc;
const SPACE: &str = " ";
/// Similar to DataFusion `rpad`, but does not truncate when the string is already longer than `length`
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
spark_read_side_padding2(args, false)
spark_read_side_padding2(args, false, false)
}

/// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
spark_read_side_padding2(args, true)
spark_read_side_padding2(args, true, false)
}

/// Custom `lpad` because DataFusion's `lpad` has differences in unicode handling
pub fn spark_lpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
spark_read_side_padding2(args, true, true)
}

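// The three public wrappers above differ only in the flags they pass:
// read_side_padding = (truncate: false, left: false), rpad = (true, false),
// lpad = (true, true). Truncation applies when the input is already longer
// than the requested length.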
fn spark_read_side_padding2(
args: &[ColumnarValue],
truncate: bool,
is_left_pad: bool,
) -> Result<ColumnarValue, DataFusionError> {
match args {
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
@@ -48,12 +54,14 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
is_left_pad,
),
// Dictionary support required for SPARK-48498
DataType::Dictionary(_, value_type) => {
@@ -64,13 +72,15 @@
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
is_left_pad,
)?
} else {
spark_read_side_padding_internal::<i64>(
dict.values(),
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
is_left_pad,
)?
};
// col is already an array, so the argument of to_array() is unused and can be anything
@@ -91,12 +101,14 @@
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
string,
is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
string,
is_left_pad,
),
// Dictionary support required for SPARK-48498
DataType::Dictionary(_, value_type) => {
@@ -107,13 +119,15 @@
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
is_left_pad,
)?
} else {
spark_read_side_padding_internal::<i64>(
dict.values(),
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
is_left_pad,
)?
};
// col is already an array, so the argument of to_array() is unused and can be anything
@@ -122,7 +136,7 @@
Ok(ColumnarValue::Array(make_array(result.into())))
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function rpad/read_side_padding",
"Unsupported data type {other:?} for function rpad/lpad/read_side_padding",
))),
}
}
@@ -132,15 +146,17 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
SPACE,
is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
SPACE,
is_left_pad,
),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function rpad/read_side_padding",
"Unsupported data type {other:?} for function rpad/lpad/read_side_padding",
))),
},
[ColumnarValue::Array(array), ColumnarValue::Array(array_int), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] => {
@@ -150,20 +166,22 @@
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
string,
is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
string,
is_left_pad,
),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function rpad/read_side_padding",
))),
}
}
other => Err(DataFusionError::Internal(format!(
"Unsupported arguments {other:?} for function rpad/read_side_padding",
"Unsupported arguments {other:?} for function rpad/lpad/read_side_padding",
))),
}
}
@@ -173,6 +191,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
truncate: bool,
pad_type: ColumnarValue,
pad_string: &str,
is_left_pad: bool,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
match pad_type {
@@ -191,6 +210,7 @@
length.unwrap() as usize,
truncate,
pad_string,
is_left_pad,
)?),
_ => builder.append_null(),
}
@@ -212,6 +232,7 @@
length,
truncate,
pad_string,
is_left_pad,
)?),
_ => builder.append_null(),
}
@@ -226,6 +247,7 @@ fn add_padding_string(
length: usize,
truncate: bool,
pad_string: &str,
is_left_pad: bool,
) -> Result<String, DataFusionError> {
// It looks like Spark's UTF8String is closer to chars than to graphemes
// https://stackoverflow.com/a/46290728
@@ -250,6 +272,14 @@
} else {
let pad_needed = length - char_len;
let pad: String = pad_string.chars().cycle().take(pad_needed).collect();
Ok(string + &pad)
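// Place the generated pad before the input for lpad, after it for rpad/read_side_padding: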
let mut result = String::with_capacity(string.len() + pad.len());
if is_left_pad {
result.push_str(&pad);
result.push_str(&string);
} else {
result.push_str(&string);
result.push_str(&pad);
}
Ok(result)
}
}
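For reference, a minimal standalone sketch of the padding rule implemented above: lengths are counted in chars rather than bytes or graphemes, and a multi-character pad string is cycled to produce exactly the missing chars. The helper name is hypothetical and this is a simplification; the real add_padding_string also handles truncation and error cases.

fn left_pad(s: &str, length: usize, pad: &str) -> String {
    // Spark's UTF8String length is measured in chars, not bytes or graphemes.
    let char_len = s.chars().count();
    if char_len >= length {
        return s.to_string();
    }
    // Repeat the pad string and take exactly the number of chars still needed.
    let fill: String = pad.chars().cycle().take(length - char_len).collect();
    format!("{fill}{s}")
}

fn main() {
    assert_eq!(left_pad("hi", 5, "??"), "???hi");
    assert_eq!(left_pad("తెలుగు", 8, " "), "  తెలుగు"); // 6 chars, so 2 spaces are prepended
}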
2 changes: 1 addition & 1 deletion native/spark-expr/src/static_invoke/mod.rs
@@ -17,4 +17,4 @@

mod char_varchar_utils;

pub use char_varchar_utils::{spark_read_side_padding, spark_rpad};
pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad};
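With the re-export in place, the new kernel can be exercised directly as a plain function over Arrow data. A minimal sketch, assuming the crate is consumed as datafusion_comet_spark_expr and that ColumnarValue/ScalarValue come from the DataFusion sub-crates (the import paths are assumptions, not taken from this diff):

use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use datafusion_comet_spark_expr::spark_lpad;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::ColumnarValue;

fn main() -> Result<(), DataFusionError> {
    let strings: ArrayRef = Arc::new(StringArray::from(vec![Some("ab"), Some("abcdef"), None]));
    let args = [
        ColumnarValue::Array(strings),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(4))),
    ];
    // Expected Spark semantics: "ab" -> "  ab", "abcdef" -> "abcd" (truncated), null -> null.
    if let ColumnarValue::Array(padded) = spark_lpad(&args)? {
        println!("{padded:?}");
    }
    Ok(())
}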
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -175,6 +175,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[StringRepeat] -> CometStringRepeat,
classOf[StringReplace] -> CometScalarFunction("replace"),
classOf[StringRPad] -> CometStringRPad,
classOf[StringLPad] -> CometStringLPad,
classOf[StringSpace] -> CometScalarFunction("string_space"),
classOf[StringTranslate] -> CometScalarFunction("translate"),
classOf[StringTrim] -> CometScalarFunction("trim"),
31 changes: 30 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringRepeat, StringRPad, Substring, Upper}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
import org.apache.spark.sql.types.{DataTypes, LongType, StringType}

import org.apache.comet.CometConf
@@ -168,6 +168,35 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
}
}

object CometStringLPad extends CometExpressionSerde[StringLPad] {

/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
* native code.
*
* @param expr
* The Spark expression.
* @param inputs
* The input attributes.
* @param binding
* Whether the attributes are bound (this is only relevant in aggregate expressions).
* @return
* Protocol buffer representation, or None if the expression could not be converted. In this
* case it is expected that the input expression will have been tagged with reasons why it
* could not be converted.
*/
override def convert(
expr: StringLPad,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
scalarFunctionExprToProto(
"lpad",
exprToProtoInternal(expr.str, inputs, binding),
exprToProtoInternal(expr.len, inputs, binding),
exprToProtoInternal(expr.pad, inputs, binding))
}
}

trait CommonStringExprs {

def stringDecode(
18 changes: 18 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -431,6 +431,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("test lpad expression support") {
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
withParquetTable(data, "t1") {
val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
checkSparkAnswerAndOperator(res)
}
}

test("LPAD with character support other than default space") {
val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
withParquetTable(data, "t1") {
val res = sql(
""" select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'), hex(lpad(unhex('aabb'), 5)),
rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
checkSparkAnswerAndOperator(res)
}
}

test("dictionary arithmetic") {
// TODO: test ANSI mode
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {