Skip to content

Commit a586a57

Browse files
committed
implement_comet_native_lpad_expr
1 parent ceb9efd commit a586a57

File tree

2 files changed

+34
-132
lines changed

2 files changed

+34
-132
lines changed

native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 34 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
use arrow::array::builder::GenericStringBuilder;
1919
use arrow::array::cast::as_dictionary_array;
2020
use arrow::array::types::Int32Type;
21-
use arrow::array::{make_array, Array, AsArray, DictionaryArray};
21+
use arrow::array::{make_array, Array, DictionaryArray};
2222
use arrow::array::{ArrayRef, OffsetSizeTrait};
2323
use arrow::datatypes::DataType;
24-
use datafusion::common::{cast::as_generic_string_array, DataFusionError, HashMap, ScalarValue};
24+
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
2525
use datafusion::physical_plan::ColumnarValue;
2626
use std::fmt::Write;
2727
use std::sync::Arc;
@@ -42,48 +42,18 @@ fn spark_read_side_padding2(
4242
) -> Result<ColumnarValue, DataFusionError> {
4343
match args {
4444
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
45-
let rpad_arg = RPadArgument::ConstLength(*length);
4645
match array.data_type() {
47-
DataType::Utf8 => {
48-
spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
49-
}
50-
DataType::LargeUtf8 => {
51-
spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
52-
}
53-
// Dictionary support required for SPARK-48498
54-
DataType::Dictionary(_, value_type) => {
55-
let dict = as_dictionary_array::<Int32Type>(array);
56-
let col = if value_type.as_ref() == &DataType::Utf8 {
57-
spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
58-
} else {
59-
spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
60-
};
61-
// col consists of an array, so arg of to_array() is not used. Can be anything
62-
let values = col.to_array(0)?;
63-
let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
64-
Ok(ColumnarValue::Array(make_array(result.into())))
65-
}
66-
other => Err(DataFusionError::Internal(format!(
67-
"Unsupported data type {other:?} for function rpad/read_side_padding",
68-
))),
69-
}
70-
}
71-
[ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => {
72-
let rpad_arg = RPadArgument::ColArray(Arc::clone(array_int));
73-
match array.data_type() {
74-
DataType::Utf8 => {
75-
spark_read_side_padding_internal::<i32>(array, truncate, rpad_arg)
76-
}
46+
DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length, truncate),
7747
DataType::LargeUtf8 => {
78-
spark_read_side_padding_internal::<i64>(array, truncate, rpad_arg)
48+
spark_read_side_padding_internal::<i64>(array, *length, truncate)
7949
}
8050
// Dictionary support required for SPARK-48498
8151
DataType::Dictionary(_, value_type) => {
8252
let dict = as_dictionary_array::<Int32Type>(array);
8353
let col = if value_type.as_ref() == &DataType::Utf8 {
84-
spark_read_side_padding_internal::<i32>(dict.values(), truncate, rpad_arg)?
54+
spark_read_side_padding_internal::<i32>(dict.values(), *length, truncate)?
8555
} else {
86-
spark_read_side_padding_internal::<i64>(dict.values(), truncate, rpad_arg)?
56+
spark_read_side_padding_internal::<i64>(dict.values(), *length, truncate)?
8757
};
8858
// col consists of an array, so arg of to_array() is not used. Can be anything
8959
let values = col.to_array(0)?;
@@ -101,101 +71,44 @@ fn spark_read_side_padding2(
10171
}
10272
}
10373

104-
enum RPadArgument {
105-
ConstLength(i32),
106-
ColArray(ArrayRef),
107-
}
108-
10974
fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
11075
array: &ArrayRef,
76+
length: i32,
11177
truncate: bool,
112-
rpad_argument: RPadArgument,
11378
) -> Result<ColumnarValue, DataFusionError> {
11479
let string_array = as_generic_string_array::<T>(array)?;
115-
match rpad_argument {
116-
RPadArgument::ColArray(array_int) => {
117-
let int_pad_array = array_int.as_primitive::<Int32Type>();
118-
let mut str_pad_value_map = HashMap::new();
119-
for i in 0..string_array.len() {
120-
if string_array.is_null(i) || int_pad_array.is_null(i) {
121-
continue; // skip nulls
122-
}
123-
str_pad_value_map.insert(string_array.value(i), int_pad_array.value(i));
124-
}
125-
126-
let mut builder = GenericStringBuilder::<T>::with_capacity(
127-
str_pad_value_map.len(),
128-
str_pad_value_map.len() * int_pad_array.len(),
129-
);
130-
131-
for string in string_array.iter() {
132-
match string {
133-
Some(string) => {
134-
// It looks Spark's UTF8String is closer to chars rather than graphemes
135-
// https://stackoverflow.com/a/46290728
136-
let char_len = string.chars().count();
137-
let length: usize = 0.max(*str_pad_value_map.get(string).unwrap()) as usize;
138-
let space_string = " ".repeat(length);
139-
if length <= char_len {
140-
if truncate {
141-
let idx = string
142-
.char_indices()
143-
.nth(length)
144-
.map(|(i, _)| i)
145-
.unwrap_or(string.len());
146-
builder.append_value(&string[..idx]);
147-
} else {
148-
builder.append_value(string);
149-
}
150-
} else {
151-
// write_str updates only the value buffer, not null nor offset buffer
152-
// This is convenient for concatenating str(s)
153-
builder.write_str(string)?;
154-
builder.append_value(&space_string[char_len..]);
155-
}
156-
}
157-
_ => builder.append_null(),
158-
}
159-
}
160-
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
161-
}
162-
RPadArgument::ConstLength(length) => {
163-
let length = 0.max(length) as usize;
164-
let space_string = " ".repeat(length);
80+
let length = 0.max(length) as usize;
81+
let space_string = " ".repeat(length);
16582

166-
let mut builder = GenericStringBuilder::<T>::with_capacity(
167-
string_array.len(),
168-
string_array.len() * length,
169-
);
83+
let mut builder =
84+
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
17085

171-
for string in string_array.iter() {
172-
match string {
173-
Some(string) => {
174-
// It looks Spark's UTF8String is closer to chars rather than graphemes
175-
// https://stackoverflow.com/a/46290728
176-
let char_len = string.chars().count();
177-
if length <= char_len {
178-
if truncate {
179-
let idx = string
180-
.char_indices()
181-
.nth(length)
182-
.map(|(i, _)| i)
183-
.unwrap_or(string.len());
184-
builder.append_value(&string[..idx]);
185-
} else {
186-
builder.append_value(string);
187-
}
188-
} else {
189-
// write_str updates only the value buffer, not null nor offset buffer
190-
// This is convenient for concatenating str(s)
191-
builder.write_str(string)?;
192-
builder.append_value(&space_string[char_len..]);
193-
}
86+
for string in string_array.iter() {
87+
match string {
88+
Some(string) => {
89+
// It looks Spark's UTF8String is closer to chars rather than graphemes
90+
// https://stackoverflow.com/a/46290728
91+
let char_len = string.chars().count();
92+
if length <= char_len {
93+
if truncate {
94+
let idx = string
95+
.char_indices()
96+
.nth(length)
97+
.map(|(i, _)| i)
98+
.unwrap_or(string.len());
99+
builder.append_value(&string[..idx]);
100+
} else {
101+
builder.append_value(string);
194102
}
195-
_ => builder.append_null(),
103+
} else {
104+
// write_str updates only the value buffer, not null nor offset buffer
105+
// This is convenient for concatenating str(s)
106+
builder.write_str(string)?;
107+
builder.append_value(&space_string[char_len..]);
196108
}
197109
}
198-
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
110+
_ => builder.append_null(),
199111
}
200112
}
113+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
201114
}

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -323,17 +323,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
323323
}
324324
}
325325

326-
test("fix_rpad") {
327-
withTable("t1") {
328-
sql("create table t1(c1 varchar(100), c2 int) using parquet")
329-
sql("insert into t1 values('IfIWasARoadIWouldBeBent', 10)")
330-
sql("insert into t1 values('IfIWereATrainIwouldBeLate', 9)")
331-
sql("insert into t1 values(NULL, 10)")
332-
val res = sql("select rpad(c1,c2) , rpad(c1,5) from t1 order by c1")
333-
checkSparkAnswerAndOperator(res)
334-
}
335-
}
336-
337326
test("test_lpad_expression") {
338327
withTable("t1") {
339328
sql("create table t1(c1 varchar(100), c2 int) using parquet")

0 commit comments

Comments
 (0)