Skip to content
756 changes: 753 additions & 3 deletions datafusion/expr-common/src/signature.rs

Large diffs are not rendered by default.

285 changes: 285 additions & 0 deletions datafusion/expr/src/arguments.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Argument resolution logic for named function parameters

use crate::Expr;
use datafusion_common::{plan_err, Result};
use std::collections::HashMap;

/// Maps the arguments of a function call onto the function's parameter order.
///
/// A call may mix positional notation (`func(1, 2)`) with named notation
/// (`func(a => 1, b => 2)`). When at least one argument is named, the arguments
/// are validated and rearranged so that each one sits at the index of the
/// parameter it binds to.
///
/// # Rules
/// - Every positional argument must precede every named argument
/// - Named arguments may appear in any order after the positional ones
/// - A parameter may be bound at most once
/// - Names are matched against `param_names` by exact string comparison; this
///   function performs no case normalization, so any SQL identifier
///   normalization (such as lowercasing unquoted names) has to happen before
///   the names are passed in
///
/// # Arguments
/// * `param_names` - The function's parameter names, in declaration order
/// * `args` - The argument expressions, in call order
/// * `arg_names` - One entry per argument: `Some(name)` when the argument was
///   passed by name, `None` when it was passed positionally
///
/// # Returns
/// The argument expressions in parameter order. A call that is entirely
/// positional is returned unchanged, without consulting `param_names`.
///
/// # Errors
/// Returns a planning error when `args` and `arg_names` differ in length, when
/// a positional argument follows a named one, when a name is not in
/// `param_names`, when a parameter is bound more than once, or when the
/// supplied arguments leave a gap among the leading parameters.
///
/// # Examples
/// ```text
/// Given parameters ["a", "b", "c"]
/// And call: func(10, c => 30, b => 20)
/// Returns: [Expr(10), Expr(20), Expr(30)]
/// ```
pub fn resolve_function_arguments(
    param_names: &[String],
    args: Vec<Expr>,
    arg_names: Vec<Option<String>>,
) -> Result<Vec<Expr>> {
    let (arg_count, name_count) = (args.len(), arg_names.len());
    if arg_count != name_count {
        return plan_err!(
            "Internal error: args length ({}) != arg_names length ({})",
            arg_count,
            name_count
        );
    }

    // Fast path: a purely positional call is already in parameter order.
    let uses_named_notation = arg_names.iter().any(Option::is_some);
    if !uses_named_notation {
        return Ok(args);
    }

    // Ordering is checked first; the reordering step relies on positional
    // arguments occupying the leading indices.
    validate_argument_order(&arg_names)?;
    reorder_named_arguments(param_names, args, arg_names)
}

/// Ensures that no positional argument appears after a named one.
///
/// `arg_names` holds one entry per call argument: `Some(name)` for an argument
/// passed by name and `None` for one passed positionally.
///
/// # Errors
/// Returns a planning error that reports the zero-based index of the first
/// positional argument found after a named argument.
fn validate_argument_order(arg_names: &[Option<String>]) -> Result<()> {
    // Nothing can be out of order before the first named argument shows up.
    let first_named = match arg_names.iter().position(Option::is_some) {
        Some(index) => index,
        None => return Ok(()),
    };

    // From that point on, every argument has to carry a name.
    let stray_positional = arg_names
        .iter()
        .enumerate()
        .skip(first_named)
        .find(|(_, name)| name.is_none());

    match stray_positional {
        Some((index, _)) => plan_err!(
            "Positional argument at position {} follows named argument. \
             All positional arguments must come before named arguments.",
            index
        ),
        None => Ok(()),
    }
}

/// Places every argument into the slot of the parameter it binds to.
///
/// A positional argument binds to the parameter at its own call index; a named
/// argument binds to the parameter whose name matches exactly. The caller is
/// expected to have verified beforehand that positional arguments precede named
/// ones, because positional arguments are written to their call index without
/// further checks.
///
/// # Arguments
/// * `param_names` - The function's parameter names, in declaration order
/// * `args` - The argument expressions, in call order
/// * `arg_names` - `Some(name)` for a named argument, `None` for a positional one
///
/// # Returns
/// The expressions bound to the first `args.len()` parameters, in parameter
/// order. Parameters beyond that count are treated as omitted optional
/// trailing parameters.
///
/// # Errors
/// Returns a planning error when there are more positional arguments than
/// parameters, when a name is not in `param_names`, when a parameter slot is
/// already occupied, or when one of the first `args.len()` parameters received
/// no argument.
fn reorder_named_arguments(
    param_names: &[String],
    args: Vec<Expr>,
    arg_names: Vec<Option<String>>,
) -> Result<Vec<Expr>> {
    let param_count = param_names.len();
    // Remember how many arguments were supplied; `args` is consumed below.
    let supplied = args.len();

    let positional = arg_names.iter().filter(|name| name.is_none()).count();
    if positional > param_count {
        return plan_err!(
            "Too many positional arguments: expected at most {}, got {}",
            param_count,
            positional
        );
    }

    // Parameter name -> slot index, so each named argument resolves in O(1).
    let mut slot_of: HashMap<&str, usize> = HashMap::with_capacity(param_count);
    for (slot, param) in param_names.iter().enumerate() {
        slot_of.insert(param.as_str(), slot);
    }

    // One slot per declared parameter; `None` marks a parameter not yet bound.
    let mut slots: Vec<Option<Expr>> = vec![None; param_count];

    for (position, (expr, label)) in args.into_iter().zip(arg_names).enumerate() {
        match label {
            // Positional argument: binds to the parameter at the same index.
            None => slots[position] = Some(expr),
            // Named argument: look up the slot by name.
            Some(label) => {
                let slot = match slot_of.get(label.as_str()) {
                    Some(&slot) => slot,
                    None => {
                        return Err(datafusion_common::plan_datafusion_err!(
                            "Unknown parameter name '{}'. Valid parameters are: [{}]",
                            label,
                            param_names.join(", ")
                        ));
                    }
                };

                let target = &mut slots[slot];
                if target.is_some() {
                    return plan_err!("Parameter '{}' specified multiple times", label);
                }
                *target = Some(expr);
            }
        }
    }

    // Only the first `supplied` parameters are required, which lets callers
    // leave optional trailing parameters out; a hole among them is an error.
    if let Some(gap) = (0..supplied).find(|&index| slots[index].is_none()) {
        return plan_err!("Missing required parameter '{}'", param_names[gap]);
    }

    // Drop the unbound trailing slots and unwrap the rest.
    slots.truncate(supplied);
    Ok(slots.into_iter().flatten().collect())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::lit;

    /// Builds an owned parameter-name list from string literals.
    fn params(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Shorthand for the label of a named argument.
    fn named(name: &str) -> Option<String> {
        Some(name.to_string())
    }

    /// Asserts that `result` is an error whose message contains `expected`.
    fn assert_error_contains(result: Result<Vec<Expr>>, expected: &str) {
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains(expected),
            "expected error containing {expected:?}, got: {message}"
        );
    }

    #[test]
    fn test_all_positional() {
        let param_names = params(&["a", "b"]);

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![None, None];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Purely positional calls must come back untouched, in call order
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_all_named() {
        let param_names = params(&["a", "b"]);

        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![named("a"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Names already follow parameter order, so the order is preserved
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_named_reordering() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(c => 3.0, a => 1, b => "hello")
        let args = vec![lit(3.0), lit(1), lit("hello")];
        let arg_names = vec![named("c"), named("a"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Should be reordered to [a, b, c] = [1, "hello", 3.0]
        assert_eq!(result, vec![lit(1), lit("hello"), lit(3.0)]);
    }

    #[test]
    fn test_mixed_positional_and_named() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(1, c => 3.0, b => "hello")
        let args = vec![lit(1), lit(3.0), lit("hello")];
        let arg_names = vec![None, named("c"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Should be reordered to [a, b, c] = [1, "hello", 3.0]
        assert_eq!(result, vec![lit(1), lit("hello"), lit(3.0)]);
    }

    #[test]
    fn test_trailing_parameters_may_be_omitted() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(1, b => "hello") - trailing 'c' is left out
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![None, named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

        // Only the supplied leading parameters are returned
        assert_eq!(result, vec![lit(1), lit("hello")]);
    }

    #[test]
    fn test_positional_after_named_error() {
        let param_names = params(&["a", "b"]);

        // Call with: func(a => 1, "hello") - ERROR
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![named("a"), None];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Positional argument");
    }

    #[test]
    fn test_unknown_parameter_name() {
        let param_names = params(&["a", "b"]);

        // Call with: func(x => 1, b => "hello") - ERROR
        let args = vec![lit(1), lit("hello")];
        let arg_names = vec![named("x"), named("b")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Unknown parameter");
    }

    #[test]
    fn test_duplicate_parameter_name() {
        let param_names = params(&["a", "b"]);

        // Call with: func(a => 1, a => 2) - ERROR
        let args = vec![lit(1), lit(2)];
        let arg_names = vec![named("a"), named("a")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "specified multiple times");
    }

    #[test]
    fn test_named_conflicts_with_positional() {
        let param_names = params(&["a", "b"]);

        // Call with: func(1, a => 2) - 'a' is already bound positionally
        let args = vec![lit(1), lit(2)];
        let arg_names = vec![None, named("a")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "specified multiple times");
    }

    #[test]
    fn test_too_many_positional_arguments() {
        let param_names = params(&["a"]);

        // Call with: func(1, 2, a => 3) - two positionals for one parameter
        let args = vec![lit(1), lit(2), lit(3)];
        let arg_names = vec![None, None, named("a")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Too many positional arguments");
    }

    #[test]
    fn test_missing_required_parameter() {
        let param_names = params(&["a", "b", "c"]);

        // Call with: func(a => 1, c => 3.0) - missing 'b'
        let args = vec![lit(1), lit(3.0)];
        let arg_names = vec![named("a"), named("c")];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "Missing required parameter");
    }

    #[test]
    fn test_mismatched_args_and_names_lengths() {
        let param_names = params(&["a", "b"]);

        // One argument but no name entries - the inputs are inconsistent
        let args = vec![lit(1)];
        let arg_names = vec![];

        let result = resolve_function_arguments(&param_names, args, arg_names);
        assert_error_contains(result, "args length (1) != arg_names length (0)");
    }
}
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ mod udaf;
mod udf;
mod udwf;

pub mod arguments;
pub mod conditional_expressions;
pub mod execution_props;
pub mod expr;
Expand Down
51 changes: 50 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ pub fn generate_signature_error_msg(
) -> String {
let candidate_signatures = func_signature
.type_signature
.to_string_repr()
.to_string_repr_with_names(func_signature.parameter_names.as_deref())
.iter()
.map(|args_str| format!("\t{func_name}({args_str})"))
.collect::<Vec<String>>()
Expand Down Expand Up @@ -1295,6 +1295,7 @@ mod tests {
Cast, ExprFunctionExt, WindowFunctionDefinition,
};
use arrow::datatypes::{UnionFields, UnionMode};
use datafusion_expr_common::signature::{TypeSignature, Volatility};

#[test]
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
Expand Down Expand Up @@ -1714,4 +1715,52 @@ mod tests {
DataType::List(Arc::new(Field::new("my_union", union_type, true)));
assert!(!can_hash(&list_union_type));
}

#[test]
fn test_generate_signature_error_msg_with_parameter_names() {
    // Two candidate signatures that share the same leading parameters
    let two_args = TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]);
    let three_args =
        TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]);
    let parameter_names: Vec<String> = ["str", "start_pos", "length"]
        .iter()
        .map(|name| name.to_string())
        .collect();

    let sig = Signature::one_of(vec![two_args, three_args], Volatility::Immutable)
        .with_parameter_names(parameter_names)
        .expect("valid parameter names");

    // Only one argument is provided when building the message
    let provided_types = [DataType::Utf8];
    let error_msg = generate_signature_error_msg("substr", sig, &provided_types);

    // NOTE(review): the first expected text is a prefix of the second, so the
    // first check alone does not prove the two-argument candidate is listed.
    assert!(
        error_msg.contains("str: Utf8, start_pos: Int64"),
        "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}"
    );
    assert!(
        error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"),
        "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}"
    );
}

#[test]
fn test_generate_signature_error_msg_without_parameter_names() {
    // No parameter names are attached to this signature
    let candidates = vec![TypeSignature::Any(2), TypeSignature::Any(3)];
    let sig = Signature::one_of(candidates, Volatility::Immutable);

    let provided_types = [DataType::Int32];
    let error_msg = generate_signature_error_msg("my_func", sig, &provided_types);

    // The candidate list should contain the plain type-only rendering
    assert!(
        error_msg.contains("Any, Any"),
        "Expected 'Any, Any' without parameter names, got: {error_msg}"
    );
}
}
3 changes: 3 additions & 0 deletions datafusion/functions-nested/src/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ impl ArrayReplace {
},
),
volatility: Volatility::Immutable,
parameter_names: None,
},
aliases: vec![String::from("list_replace")],
}
Expand Down Expand Up @@ -186,6 +187,7 @@ impl ArrayReplaceN {
},
),
volatility: Volatility::Immutable,
parameter_names: None,
},
aliases: vec![String::from("list_replace_n")],
}
Expand Down Expand Up @@ -265,6 +267,7 @@ impl ArrayReplaceAll {
},
),
volatility: Volatility::Immutable,
parameter_names: None,
},
aliases: vec![String::from("list_replace_all")],
}
Expand Down
Loading