Skip to content

Commit a9a0c2f

Browse files
committed
change to accept multi args
Signed-off-by: jayzhan211 <[email protected]>
1 parent 0d4dc36 commit a9a0c2f

File tree

5 files changed

+36
-29
lines changed

5 files changed

+36
-29
lines changed

datafusion/common/src/scalar.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,12 +1840,12 @@ impl ScalarValue {
18401840
let arr = Decimal128Array::from(vals)
18411841
.with_precision_and_scale(*precision, *scale)
18421842
.unwrap();
1843-
wrap_into_list_array(Arc::new(arr))
1843+
wrap_into_list_array(&[&arr]).unwrap()
18441844
}
18451845

18461846
DataType::Null => {
18471847
let arr = new_null_array(&DataType::Null, values.len());
1848-
wrap_into_list_array(arr)
1848+
wrap_into_list_array(&[&arr]).unwrap()
18491849
}
18501850
_ => panic!(
18511851
"Unsupported data type {:?} for ScalarValue::list_to_array",
@@ -2242,18 +2242,14 @@ impl ScalarValue {
22422242
let list_array = as_list_array(array);
22432243
let nested_array = list_array.value(index);
22442244
// Produces a single element `ListArray` with the value at `index`.
2245-
let arr = Arc::new(wrap_into_list_array(nested_array));
2246-
2247-
ScalarValue::List(arr)
2245+
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nested_array])?))
22482246
}
22492247
// TODO: There is no test for FixedSizeList now, add it later
22502248
DataType::FixedSizeList(_, _) => {
22512249
let list_array = as_fixed_size_list_array(array)?;
22522250
let nested_array = list_array.value(index);
22532251
// Produces a single element `ListArray` with the value at `index`.
2254-
let arr = Arc::new(wrap_into_list_array(nested_array));
2255-
2256-
ScalarValue::List(arr)
2252+
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nested_array])?))
22572253
}
22582254
DataType::Date32 => {
22592255
typed_cast!(array, index, Date32Array, Date32)
@@ -3236,11 +3232,12 @@ mod tests {
32363232

32373233
let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8);
32383234

3239-
let expected = wrap_into_list_array(Arc::new(StringArray::from(vec![
3235+
let expected = wrap_into_list_array(&[&StringArray::from(vec![
32403236
"rust",
32413237
"arrow",
32423238
"data-fusion",
3243-
])));
3239+
])])
3240+
.unwrap();
32443241
let result = as_list_array(&array);
32453242
assert_eq!(result, &expected);
32463243
}
@@ -3274,10 +3271,10 @@ mod tests {
32743271

32753272
#[test]
32763273
fn iter_to_array_string_test() {
3277-
let arr1 =
3278-
wrap_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
3274+
let arr1 = wrap_into_list_array(&[&StringArray::from(vec!["foo", "bar", "baz"])])
3275+
.unwrap();
32793276
let arr2 =
3280-
wrap_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"])));
3277+
wrap_into_list_array(&[&StringArray::from(vec!["rust", "world"])]).unwrap();
32813278

32823279
let scalars = vec![
32833280
ScalarValue::List(Arc::new(arr1)),
@@ -4519,13 +4516,16 @@ mod tests {
45194516
// Define list-of-structs scalars
45204517

45214518
let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap();
4522-
let nl0 = ScalarValue::List(Arc::new(wrap_into_list_array(nl0_array)));
4519+
let nl0 =
4520+
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nl0_array]).unwrap()));
45234521

45244522
let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap();
4525-
let nl1 = ScalarValue::List(Arc::new(wrap_into_list_array(nl1_array)));
4523+
let nl1 =
4524+
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nl1_array]).unwrap()));
45264525

45274526
let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap();
4528-
let nl2 = ScalarValue::List(Arc::new(wrap_into_list_array(nl2_array)));
4527+
let nl2 =
4528+
ScalarValue::List(Arc::new(wrap_into_list_array(&[&nl2_array]).unwrap()));
45294529

45304530
// iter_to_array for list-of-struct
45314531
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();

datafusion/common/src/utils.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::compute;
2424
use arrow::compute::{partition, SortColumn, SortOptions};
2525
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
2626
use arrow::record_batch::RecordBatch;
27-
use arrow_array::ListArray;
27+
use arrow_array::{Array, ListArray};
2828
use sqlparser::ast::Ident;
2929
use sqlparser::dialect::GenericDialect;
3030
use sqlparser::parser::Parser;
@@ -336,16 +336,24 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
336336
count
337337
}
338338

339-
/// Wrap an array into a single element `ListArray`.
339+
/// Wrap arrays into a single element `ListArray`.
340340
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
341-
pub fn wrap_into_list_array(arr: ArrayRef) -> ListArray {
342-
let offsets = OffsetBuffer::from_lengths([arr.len()]);
343-
ListArray::new(
344-
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
345-
offsets,
346-
arr,
341+
pub fn wrap_into_list_array(arr: &[&dyn Array]) -> Result<ListArray> {
342+
if arr.is_empty() {
343+
return Err(DataFusionError::Internal(
344+
"Cannot wrap empty array into list array".to_owned(),
345+
));
346+
}
347+
348+
let lens = arr.iter().map(|x| x.len()).collect::<Vec<_>>();
349+
// Assume data type is consistent
350+
let data_type = arr[0].data_type().to_owned();
351+
Ok(ListArray::new(
352+
Arc::new(Field::new("item", data_type, true)),
353+
OffsetBuffer::from_lengths(lens),
354+
arrow::compute::concat(arr)?,
347355
None,
348-
)
356+
))
349357
}
350358

351359
/// An extension trait for smart pointers. Provides an interface to get a

datafusion/physical-expr/src/aggregate/array_agg.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ impl Accumulator for ArrayAggAccumulator {
161161
}
162162

163163
let concated_array = arrow::compute::concat(&element_arrays)?;
164-
let list_array = wrap_into_list_array(concated_array);
165-
164+
let list_array = wrap_into_list_array(&[&concated_array])?;
166165
Ok(ScalarValue::List(Arc::new(list_array)))
167166
}
168167

datafusion/physical-expr/src/aggregate/array_agg_distinct.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ mod tests {
201201
};
202202

203203
let arr = arrow::compute::sort(&arr, None).unwrap();
204-
let list_arr = wrap_into_list_array(arr);
204+
let list_arr = wrap_into_list_array(&[&arr]).unwrap();
205205
ScalarValue::List(Arc::new(list_arr))
206206
}
207207

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
412412
// Either an empty array or all nulls:
413413
DataType::Null => {
414414
let array = new_null_array(&DataType::Null, arrays.len());
415-
Ok(Arc::new(wrap_into_list_array(array)))
415+
Ok(Arc::new(wrap_into_list_array(&[&array])?))
416416
}
417417
data_type => array_array(arrays, data_type),
418418
}

0 commit comments

Comments
 (0)