Skip to content

Commit 3d78bf4

Browse files
authored
Keep output as scalar for scalar function if all inputs are scalar (#7967)
* Keep output as scalar for scalar function if all inputs are scalar
* Add end-to-end tests
1 parent 0d4dc36 commit 3d78bf4

File tree

3 files changed

+92
-1
lines changed

3 files changed

+92
-1
lines changed

datafusion/physical-expr/src/functions.rs

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -357,6 +357,8 @@ where
357357
ColumnarValue::Array(a) => Some(a.len()),
358358
});
359359

360+
let is_scalar = len.is_none();
361+
360362
let inferred_length = len.unwrap_or(1);
361363
let args = args
362364
.iter()
@@ -373,7 +375,14 @@ where
373375
.collect::<Vec<ArrayRef>>();
374376

375377
let result = (inner)(&args);
376-
result.map(ColumnarValue::Array)
378+
379+
if is_scalar {
380+
// If all inputs are scalar, keeps output as scalar
381+
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
382+
result.map(ColumnarValue::Scalar)
383+
} else {
384+
result.map(ColumnarValue::Array)
385+
}
377386
})
378387
}
379388

datafusion/physical-expr/src/planner.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,37 @@ pub fn create_physical_expr(
448448
}
449449
}
450450
}
451+
452+
#[cfg(test)]
453+
mod tests {
454+
use super::*;
455+
use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
456+
use arrow_schema::{DataType, Field, Schema};
457+
use datafusion_common::{DFSchema, Result};
458+
use datafusion_expr::{col, left, Literal};
459+
460+
#[test]
461+
fn test_create_physical_expr_scalar_input_output() -> Result<()> {
462+
let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit()));
463+
464+
let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]);
465+
let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?;
466+
let p = create_physical_expr(&expr, &df_schema, &schema, &ExecutionProps::new())?;
467+
468+
let batch = RecordBatch::try_new(
469+
Arc::new(schema),
470+
vec![Arc::new(StringArray::from_iter_values(vec![
471+
"A", "B", "C", "D",
472+
]))],
473+
)?;
474+
let result = p.evaluate(&batch)?;
475+
let result = result.into_array(4);
476+
477+
assert_eq!(
478+
&result,
479+
&(Arc::new(BooleanArray::from(vec![true, false, false, false,])) as ArrayRef)
480+
);
481+
482+
Ok(())
483+
}
484+
}

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,3 +1878,51 @@ query T
18781878
SELECT CONCAT('Hello', 'World')
18791879
----
18801880
HelloWorld
1881+
1882+
statement ok
1883+
CREATE TABLE simple_string(
1884+
letter STRING,
1885+
letter2 STRING
1886+
) as VALUES
1887+
('A', 'APACHE'),
1888+
('B', 'APACHE'),
1889+
('C', 'APACHE'),
1890+
('D', 'APACHE')
1891+
;
1892+
1893+
query TT
1894+
EXPLAIN SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string;
1895+
----
1896+
logical_plan
1897+
Projection: simple_string.letter, simple_string.letter = Utf8("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1))
1898+
--TableScan: simple_string projection=[letter]
1899+
physical_plan
1900+
ProjectionExec: expr=[letter@0 as letter, letter@0 = A as simple_string.letter = left(Utf8("APACHE"),Int64(1))]
1901+
--MemoryExec: partitions=1, partition_sizes=[1]
1902+
1903+
query TB
1904+
SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string;
1905+
----
1906+
----
1907+
A true
1908+
B false
1909+
C false
1910+
D false
1911+
1912+
query TT
1913+
EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string;
1914+
----
1915+
logical_plan
1916+
Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1))
1917+
--TableScan: simple_string projection=[letter, letter2]
1918+
physical_plan
1919+
ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))]
1920+
--MemoryExec: partitions=1, partition_sizes=[1]
1921+
1922+
query TB
1923+
SELECT letter, letter = LEFT(letter2, 1) FROM simple_string;
1924+
----
1925+
A true
1926+
B false
1927+
C false
1928+
D false

0 commit comments

Comments
 (0)