Skip to content
Merged
6 changes: 3 additions & 3 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,9 +518,9 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType {
/// Compute the number of dimensions in a list data type.
///
/// Every level of `List`, `LargeList` or `FixedSizeList` nesting counts as one
/// dimension. Any other data type contributes zero, which also terminates the
/// recursion.
pub fn list_ndims(data_type: &DataType) -> u64 {
    match data_type {
        // All three list variants wrap a single child field, so one arm
        // handles them: count this level and recurse into the child type.
        DataType::List(field)
        | DataType::LargeList(field)
        | DataType::FixedSizeList(field, _) => 1 + list_ndims(field.data_type()),
        _ => 0,
    }
}
Expand Down
5 changes: 5 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ use arrow::datatypes::DataType;
/// return results with this timezone.
pub const TIMEZONE_WILDCARD: &str = "+TZ";

/// Constant that is used as a placeholder for any valid fixed size list.
/// This is used where a function can accept a fixed size list type with any
/// valid length. It exists to avoid the need to enumerate all possible fixed size list lengths.
///
/// The placeholder value is `i32::MIN`. It is intended to be written in the
/// size position of a fixed size list data type to mean "any length".
pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN;

/// A function's volatility, which defines the function's eligibility for certain optimizations
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum Volatility {
Expand Down
110 changes: 107 additions & 3 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use crate::signature::{ArrayFunctionSignature, TIMEZONE_WILDCARD};
use std::sync::Arc;

use crate::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
};
use crate::{Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
Expand Down Expand Up @@ -379,13 +383,28 @@ fn coerced_from<'a>(
List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()),

// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
List(_) | LargeList(_)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
}
// should be able to coerce wildcard fixed size list to non wildcard fixed size list
FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD) => match type_from {
FixedSizeList(f_from, size_from) => {
match coerced_from(f_into.data_type(), f_from.data_type()) {
Some(data_type) if &data_type != f_into.data_type() => {
let new_field =
Arc::new(f_into.as_ref().clone().with_data_type(data_type));
Some(FixedSizeList(new_field, *size_from))
}
Some(_) => Some(FixedSizeList(f_into.clone(), *size_from)),
_ => None,
}
}
_ => None,
},

Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
match type_from {
Expand Down Expand Up @@ -415,8 +434,12 @@ fn coerced_from<'a>(

#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::Volatility;

use super::*;
use arrow::datatypes::{DataType, TimeUnit};
use arrow::datatypes::{DataType, Field, TimeUnit};

#[test]
fn test_maybe_data_types() {
Expand Down Expand Up @@ -492,4 +515,85 @@ mod tests {

Ok(())
}

#[test]
fn test_fixed_list_wildcard_coerce() -> Result<()> {
let inner = Arc::new(Field::new("item", DataType::Int32, false));
let current_types = vec![
DataType::FixedSizeList(inner.clone(), 2), // able to coerce for any size
];

let signature = Signature::exact(
vec![DataType::FixedSizeList(
inner.clone(),
FIXED_SIZE_LIST_WILDCARD,
)],
Volatility::Stable,
);

let coerced_data_types = data_types(&current_types, &signature).unwrap();
assert_eq!(coerced_data_types, current_types);

// make sure it can't coerce to a different size
let signature = Signature::exact(
vec![DataType::FixedSizeList(inner.clone(), 3)],
Volatility::Stable,
);
let coerced_data_types = data_types(&current_types, &signature);
assert!(coerced_data_types.is_err());
Comment on lines +537 to +543
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good assert, though what I had in mind is that it probably wouldn't work for current_types = vec![DataType::FixedSizeList(inner.clone(), 3)], even though in that case it should, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so like making sure the same type works?

        let current_types = vec![
            DataType::FixedSizeList(inner.clone(), 2), // able to coerce for any size
        ];

        // make sure it works with the same type.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(inner.clone(), 2)],
            Volatility::Stable,
        );
        let coerced_data_types = data_types(&current_types, &signature).unwrap();
        assert_eq!(coerced_data_types, current_types);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly.


// make sure it works with the same type.
let signature = Signature::exact(
vec![DataType::FixedSizeList(inner.clone(), 2)],
Volatility::Stable,
);
let coerced_data_types = data_types(&current_types, &signature).unwrap();
assert_eq!(coerced_data_types, current_types);

Ok(())
}

#[test]
fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
    // Builds a two-level fixed size list: an outer list of `outer_len`
    // elements, each an inner list of `inner_len` elements of type `leaf`.
    // Both levels use a non-nullable child field named "item".
    let nested = |leaf: DataType, inner_len: i32, outer_len: i32| {
        let leaf_field = Arc::new(Field::new("item", leaf, false));
        let inner_list = DataType::FixedSizeList(leaf_field, inner_len);
        let inner_field = Arc::new(Field::new("item", inner_list, false));
        DataType::FixedSizeList(inner_field, outer_len)
    };

    // Target type: wildcard length at both nesting levels, Int32 leaves.
    let type_into = nested(
        DataType::Int32,
        FIXED_SIZE_LIST_WILDCARD,
        FIXED_SIZE_LIST_WILDCARD,
    );
    // Source type: concrete lengths (inner 4, outer 3), Int8 leaves.
    let type_from = nested(DataType::Int8, 4, 3);
    // Expected result: the source's lengths with the target's leaf type.
    let expected = nested(DataType::Int32, 4, 3);

    assert_eq!(coerced_from(&type_into, &type_from), Some(expected));

    Ok(())
}
}