Skip to content

Commit bfd10f7

Browse files
committed
address comment
Signed-off-by: jayzhan211 <[email protected]>
1 parent 656c6a9 commit bfd10f7

File tree

5 files changed

+60
-18
lines changed

5 files changed

+60
-18
lines changed

datafusion/common/src/scalar.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use crate::cast::{
3030
};
3131
use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
3232
use crate::hash_utils::create_hashes;
33-
use crate::utils::wrap_into_list_array;
33+
use crate::utils::array_into_list_array;
3434
use arrow::buffer::{NullBuffer, OffsetBuffer};
3535
use arrow::compute::kernels::numeric::*;
3636
use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder};
@@ -1667,7 +1667,7 @@ impl ScalarValue {
16671667
} else {
16681668
Self::iter_to_array(values.iter().cloned()).unwrap()
16691669
};
1670-
Arc::new(wrap_into_list_array(values))
1670+
Arc::new(array_into_list_array(values))
16711671
}
16721672

16731673
/// Converts a scalar value into an array of `size` rows.
@@ -2058,7 +2058,7 @@ impl ScalarValue {
20582058
let list_array = as_list_array(array);
20592059
let nested_array = list_array.value(index);
20602060
// Produces a single element `ListArray` with the value at `index`.
2061-
let arr = Arc::new(wrap_into_list_array(nested_array));
2061+
let arr = Arc::new(array_into_list_array(nested_array));
20622062

20632063
ScalarValue::List(arr)
20642064
}
@@ -2067,7 +2067,7 @@ impl ScalarValue {
20672067
let list_array = as_fixed_size_list_array(array)?;
20682068
let nested_array = list_array.value(index);
20692069
// Produces a single element `ListArray` with the value at `index`.
2070-
let arr = Arc::new(wrap_into_list_array(nested_array));
2070+
let arr = Arc::new(array_into_list_array(nested_array));
20712071

20722072
ScalarValue::List(arr)
20732073
}
@@ -3052,7 +3052,7 @@ mod tests {
30523052

30533053
let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8);
30543054

3055-
let expected = wrap_into_list_array(Arc::new(StringArray::from(vec![
3055+
let expected = array_into_list_array(Arc::new(StringArray::from(vec![
30563056
"rust",
30573057
"arrow",
30583058
"data-fusion",
@@ -3091,9 +3091,9 @@ mod tests {
30913091
#[test]
30923092
fn iter_to_array_string_test() {
30933093
let arr1 =
3094-
wrap_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
3094+
array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
30953095
let arr2 =
3096-
wrap_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"])));
3096+
array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"])));
30973097

30983098
let scalars = vec![
30993099
ScalarValue::List(Arc::new(arr1)),
@@ -4335,13 +4335,13 @@ mod tests {
43354335
// Define list-of-structs scalars
43364336

43374337
let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap();
4338-
let nl0 = ScalarValue::List(Arc::new(wrap_into_list_array(nl0_array)));
4338+
let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array)));
43394339

43404340
let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap();
4341-
let nl1 = ScalarValue::List(Arc::new(wrap_into_list_array(nl1_array)));
4341+
let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array)));
43424342

43434343
let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap();
4344-
let nl2 = ScalarValue::List(Arc::new(wrap_into_list_array(nl2_array)));
4344+
let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array)));
43454345

43464346
// iter_to_array for list-of-struct
43474347
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();

datafusion/common/src/utils.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
//! This module provides the bisect function, which implements binary search.
1919
20+
use crate::error::_internal_err;
2021
use crate::{DataFusionError, Result, ScalarValue};
2122
use arrow::array::{ArrayRef, PrimitiveArray};
2223
use arrow::buffer::OffsetBuffer;
2324
use arrow::compute;
2425
use arrow::compute::{partition, SortColumn, SortOptions};
2526
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
2627
use arrow::record_batch::RecordBatch;
27-
use arrow_array::ListArray;
28+
use arrow_array::{Array, ListArray};
2829
use sqlparser::ast::Ident;
2930
use sqlparser::dialect::GenericDialect;
3031
use sqlparser::parser::Parser;
@@ -338,7 +339,7 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
338339

339340
/// Wrap an array into a single element `ListArray`.
340341
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
341-
pub fn wrap_into_list_array(arr: ArrayRef) -> ListArray {
342+
pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
342343
let offsets = OffsetBuffer::from_lengths([arr.len()]);
343344
ListArray::new(
344345
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
@@ -348,6 +349,47 @@ pub fn wrap_into_list_array(arr: ArrayRef) -> ListArray {
348349
)
349350
}
350351

352+
/// Wrap several arrays into one `ListArray`, where each input array becomes one list element.
353+
///
354+
/// Example:
355+
/// ```
356+
/// use arrow::array::{Int32Array, ListArray, ArrayRef};
357+
/// use arrow::datatypes::{Int32Type, Field};
358+
/// use std::sync::Arc;
359+
///
360+
/// let arr1 = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
361+
/// let arr2 = Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef;
362+
///
363+
/// let list_arr = datafusion_common::utils::arrays_into_list_array([arr1, arr2]).unwrap();
364+
///
365+
/// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(
366+
/// vec![
367+
/// Some(vec![Some(1), Some(2), Some(3)]),
368+
/// Some(vec![Some(4), Some(5), Some(6)]),
369+
/// ]
370+
/// );
371+
///
372+
/// assert_eq!(list_arr, expected);
/// ```
373+
pub fn arrays_into_list_array(
374+
arr: impl IntoIterator<Item = ArrayRef>,
375+
) -> Result<ListArray> {
376+
let arr = arr.into_iter().collect::<Vec<_>>();
377+
if arr.is_empty() {
378+
return _internal_err!("Cannot wrap empty array into list array");
379+
}
380+
381+
let lens = arr.iter().map(|x| x.len()).collect::<Vec<_>>();
382+
// The element type is taken from the first array; all inputs are assumed to share it (no explicit check is made here)
383+
let data_type = arr[0].data_type().to_owned();
384+
let values = arr.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
385+
Ok(ListArray::new(
386+
Arc::new(Field::new("item", data_type, true)),
387+
OffsetBuffer::from_lengths(lens),
388+
arrow::compute::concat(values.as_slice())?,
389+
None,
390+
))
391+
}
392+
351393
/// An extension trait for smart pointers. Provides an interface to get a
352394
/// raw pointer to the data (with metadata stripped away).
353395
///

datafusion/physical-expr/src/aggregate/array_agg.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::array::ArrayRef;
2424
use arrow::datatypes::{DataType, Field};
2525
use arrow_array::Array;
2626
use datafusion_common::cast::as_list_array;
27-
use datafusion_common::utils::wrap_into_list_array;
27+
use datafusion_common::utils::array_into_list_array;
2828
use datafusion_common::Result;
2929
use datafusion_common::ScalarValue;
3030
use datafusion_expr::Accumulator;
@@ -161,7 +161,7 @@ impl Accumulator for ArrayAggAccumulator {
161161
}
162162

163163
let concated_array = arrow::compute::concat(&element_arrays)?;
164-
let list_array = wrap_into_list_array(concated_array);
164+
let list_array = array_into_list_array(concated_array);
165165

166166
Ok(ScalarValue::List(Arc::new(list_array)))
167167
}

datafusion/physical-expr/src/aggregate/array_agg_distinct.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ mod tests {
185185
use arrow_array::types::Int32Type;
186186
use arrow_array::{Array, ListArray};
187187
use arrow_buffer::OffsetBuffer;
188-
use datafusion_common::utils::wrap_into_list_array;
188+
use datafusion_common::utils::array_into_list_array;
189189
use datafusion_common::{internal_err, DataFusionError};
190190

191191
// arrow::compute::sort can't sort a ListArray directly, so we need to sort the inner primitive array and wrap it back into a ListArray.
@@ -201,7 +201,7 @@ mod tests {
201201
};
202202

203203
let arr = arrow::compute::sort(&arr, None).unwrap();
204-
let list_arr = wrap_into_list_array(arr);
204+
let list_arr = array_into_list_array(arr);
205205
ScalarValue::List(Arc::new(list_arr))
206206
}
207207

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use arrow_buffer::NullBuffer;
2929
use datafusion_common::cast::{
3030
as_generic_string_array, as_int64_array, as_list_array, as_string_array,
3131
};
32-
use datafusion_common::utils::wrap_into_list_array;
32+
use datafusion_common::utils::array_into_list_array;
3333
use datafusion_common::{
3434
exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result,
3535
};
@@ -412,7 +412,7 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
412412
// Either an empty array or all nulls:
413413
DataType::Null => {
414414
let array = new_null_array(&DataType::Null, arrays.len());
415-
Ok(Arc::new(wrap_into_list_array(array)))
415+
Ok(Arc::new(array_into_list_array(array)))
416416
}
417417
data_type => array_array(arrays, data_type),
418418
}

0 commit comments

Comments
 (0)