Skip to content

Commit 6c7d83a

Browse files
author
Arttu Voutilainen
committed
rename output columns (incl. inner struct fields) according to the given list of names
1 parent d52d064 commit 6c7d83a

File tree

2 files changed

+140
-31
lines changed

2 files changed

+140
-31
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,20 @@
1717

1818
use async_recursion::async_recursion;
1919
use datafusion::arrow::datatypes::{
20-
DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
20+
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
2121
};
2222
use datafusion::common::{
2323
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
2424
};
2525

2626
use datafusion::execution::FunctionRegistry;
2727
use datafusion::logical_expr::{
28-
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr,
29-
LogicalPlan, Operator, ScalarUDF, Values,
28+
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
29+
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF,
30+
Values,
3031
};
3132
use datafusion::logical_expr::{
32-
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
33+
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
3334
Repartition, Subquery, WindowFrameBound, WindowFrameUnits,
3435
};
3536
use datafusion::prelude::JoinType;
@@ -212,6 +213,7 @@ pub async fn from_substrait_plan(
212213
None => not_impl_err!("Cannot parse empty extension"),
213214
})
214215
.collect::<Result<HashMap<_, _>>>()?;
216+
215217
// Parse relations
216218
match plan.relations.len() {
217219
1 => {
@@ -221,7 +223,29 @@ pub async fn from_substrait_plan(
221223
Ok(from_substrait_rel(ctx, rel, &function_extension).await?)
222224
},
223225
plan_rel::RelType::Root(root) => {
224-
Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?)
226+
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?;
227+
if root.names.is_empty() {
228+
// Backwards compatibility for plans missing names
229+
return Ok(plan);
230+
}
231+
let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?;
232+
if renamed_schema.equivalent_names_and_types(plan.schema()) {
233+
// Nothing to do if the schema is already equivalent
234+
return Ok(plan);
235+
}
236+
237+
match plan {
238+
// If the last node of the plan produces expressions, bake the renames into those expressions.
239+
// This isn't necessary for correctness, but helps with roundtrip tests.
240+
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)),
241+
LogicalPlan::Aggregate(a) => {
242+
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?;
243+
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
244+
},
245+
// There are probably more plans where we could bake things in, can add them later as needed.
246+
// Otherwise, add a new Project to handle the renaming.
247+
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?))
248+
}
225249
}
226250
},
227251
None => plan_err!("Cannot parse plan relation: None")
@@ -234,6 +258,105 @@ pub async fn from_substrait_plan(
234258
}
235259
}
236260

261+
fn rename_expressions(
262+
exprs: impl IntoIterator<Item = Expr>,
263+
input_schema: &DFSchema,
264+
new_schema: DFSchemaRef,
265+
) -> Result<Vec<Expr>> {
266+
exprs
267+
.into_iter()
268+
.zip(new_schema.fields())
269+
.map(|(old_expr, new_field)| {
270+
if &old_expr.get_type(input_schema)? == new_field.data_type() {
271+
// Alias column if needed
272+
old_expr.alias_if_changed(new_field.name().into())
273+
} else {
274+
// Use Cast to rename inner struct fields + alias column if needed
275+
Expr::Cast(Cast::new(
276+
Box::new(old_expr),
277+
new_field.data_type().to_owned(),
278+
))
279+
.alias_if_changed(new_field.name().into())
280+
}
281+
})
282+
.collect()
283+
}
284+
285+
fn make_renamed_schema(
286+
schema: &DFSchemaRef,
287+
dfs_names: &Vec<String>,
288+
) -> Result<DFSchemaRef> {
289+
fn rename_inner_fields(
290+
dtype: &DataType,
291+
dfs_names: &Vec<String>,
292+
name_idx: &mut usize,
293+
) -> Result<DataType> {
294+
match dtype {
295+
DataType::Struct(fields) => {
296+
let fields = fields
297+
.iter()
298+
.map(|f| {
299+
let name = next_struct_field_name(0, dfs_names, name_idx)?;
300+
Ok((**f).to_owned().with_name(name).with_data_type(
301+
rename_inner_fields(f.data_type(), dfs_names, name_idx)?,
302+
))
303+
})
304+
.collect::<Result<_>>()?;
305+
Ok(DataType::Struct(fields))
306+
}
307+
DataType::List(inner) => Ok(DataType::List(FieldRef::new(
308+
(**inner).to_owned().with_data_type(rename_inner_fields(
309+
inner.data_type(),
310+
dfs_names,
311+
name_idx,
312+
)?),
313+
))),
314+
DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
315+
(**inner).to_owned().with_data_type(rename_inner_fields(
316+
inner.data_type(),
317+
dfs_names,
318+
name_idx,
319+
)?),
320+
))),
321+
_ => Ok(dtype.to_owned()),
322+
}
323+
}
324+
325+
let mut name_idx = 0;
326+
327+
let (qualifiers, fields): (_, Vec<Field>) = schema
328+
.iter()
329+
.map(|(q, f)| {
330+
let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
331+
Ok((
332+
q.cloned(),
333+
(**f)
334+
.to_owned()
335+
.with_name(name)
336+
.with_data_type(rename_inner_fields(
337+
f.data_type(),
338+
dfs_names,
339+
&mut name_idx,
340+
)?),
341+
))
342+
})
343+
.collect::<Result<Vec<_>>>()?
344+
.into_iter()
345+
.unzip();
346+
347+
if name_idx != dfs_names.len() {
348+
return substrait_err!(
349+
"Names list must match exactly to nested schema, but found {} uses for {} names",
350+
name_idx,
351+
dfs_names.len());
352+
}
353+
354+
Ok(Arc::new(DFSchema::from_field_specific_qualified_schema(
355+
qualifiers,
356+
&Arc::new(Schema::new(fields)),
357+
)?))
358+
}
359+
237360
/// Convert Substrait Rel to DataFusion DataFrame
238361
#[async_recursion]
239362
pub async fn from_substrait_rel(

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,9 @@ async fn implicit_cast() -> Result<()> {
372372
async fn aggregate_case() -> Result<()> {
373373
assert_expected_plan(
374374
"SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
375-
"Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\
375+
"Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\
376376
\n TableScan: data projection=[a]",
377-
false // NULL vs Int64(NULL)
377+
true
378378
)
379379
.await
380380
}
@@ -594,32 +594,23 @@ async fn roundtrip_union_all() -> Result<()> {
594594

595595
#[tokio::test]
596596
async fn simple_intersect() -> Result<()> {
597+
// Substrait treats both COUNT(*) and COUNT(1) the same
597598
assert_expected_plan(
598599
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
599-
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
600+
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\
600601
\n Projection: \
601602
\n LeftSemi Join: data.a = data2.a\
602603
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
603604
\n TableScan: data projection=[a]\
604605
\n TableScan: data2 projection=[a]",
605-
false // COUNT(*) vs COUNT(Int64(1))
606+
true
606607
)
607608
.await
608609
}
609610

610611
#[tokio::test]
611612
async fn simple_intersect_table_reuse() -> Result<()> {
612-
assert_expected_plan(
613-
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);",
614-
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
615-
\n Projection: \
616-
\n LeftSemi Join: data.a = data.a\
617-
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
618-
\n TableScan: data projection=[a]\
619-
\n TableScan: data projection=[a]",
620-
false // COUNT(*) vs COUNT(Int64(1))
621-
)
622-
.await
613+
roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await
623614
}
624615

625616
#[tokio::test]
@@ -699,20 +690,14 @@ async fn all_type_literal() -> Result<()> {
699690

700691
#[tokio::test]
701692
async fn roundtrip_literal_list() -> Result<()> {
702-
assert_expected_plan(
703-
"SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
704-
"Projection: List([[1, 2, 3], [], , []])\
705-
\n TableScan: data projection=[]",
706-
false, // "List(..)" vs "make_array(..)"
707-
)
708-
.await
693+
roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await
709694
}
710695

711696
#[tokio::test]
712697
async fn roundtrip_literal_struct() -> Result<()> {
713698
assert_expected_plan(
714699
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
715-
"Projection: Struct({c0:1,c1:true,c2:})\
700+
"Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\
716701
\n TableScan: data projection=[]",
717702
false, // "Struct(..)" vs "struct(..)"
718703
)
@@ -985,12 +970,13 @@ async fn assert_expected_plan(
985970

986971
println!("{proto:?}");
987972

988-
let plan2str = format!("{plan2:?}");
989-
assert_eq!(expected_plan_str, &plan2str);
990-
991973
if assert_schema {
992974
assert_eq!(plan.schema(), plan2.schema());
993975
}
976+
977+
let plan2str = format!("{plan2:?}");
978+
assert_eq!(expected_plan_str, &plan2str);
979+
994980
Ok(())
995981
}
996982

0 commit comments

Comments
 (0)