1717
1818use async_recursion:: async_recursion;
1919use datafusion:: arrow:: datatypes:: {
20- DataType , Field , Fields , IntervalUnit , Schema , TimeUnit ,
20+ DataType , Field , FieldRef , Fields , IntervalUnit , Schema , TimeUnit ,
2121} ;
2222use datafusion:: common:: {
2323 not_impl_err, substrait_datafusion_err, substrait_err, DFSchema , DFSchemaRef ,
2424} ;
2525
2626use datafusion:: execution:: FunctionRegistry ;
2727use datafusion:: logical_expr:: {
28- aggregate_function, expr:: find_df_window_func, BinaryExpr , Case , EmptyRelation , Expr ,
29- LogicalPlan , Operator , ScalarUDF , Values ,
28+ aggregate_function, expr:: find_df_window_func, Aggregate , BinaryExpr , Case ,
29+ EmptyRelation , Expr , ExprSchemable , LogicalPlan , Operator , Projection , ScalarUDF ,
30+ Values ,
3031} ;
3132use datafusion:: logical_expr:: {
32- expr, Cast , Extension , GroupingSet , Like , LogicalPlanBuilder , Partitioning ,
33+ col , expr, Cast , Extension , GroupingSet , Like , LogicalPlanBuilder , Partitioning ,
3334 Repartition , Subquery , WindowFrameBound , WindowFrameUnits ,
3435} ;
3536use datafusion:: prelude:: JoinType ;
@@ -212,6 +213,7 @@ pub async fn from_substrait_plan(
212213 None => not_impl_err ! ( "Cannot parse empty extension" ) ,
213214 } )
214215 . collect :: < Result < HashMap < _ , _ > > > ( ) ?;
216+
215217 // Parse relations
216218 match plan. relations . len ( ) {
217219 1 => {
@@ -221,7 +223,29 @@ pub async fn from_substrait_plan(
221223 Ok ( from_substrait_rel ( ctx, rel, & function_extension) . await ?)
222224 } ,
223225 plan_rel:: RelType :: Root ( root) => {
224- Ok ( from_substrait_rel ( ctx, root. input . as_ref ( ) . unwrap ( ) , & function_extension) . await ?)
226+ let plan = from_substrait_rel ( ctx, root. input . as_ref ( ) . unwrap ( ) , & function_extension) . await ?;
227+ if root. names . is_empty ( ) {
228+ // Backwards compatibility for plans missing names
229+ return Ok ( plan) ;
230+ }
231+ let renamed_schema = make_renamed_schema ( plan. schema ( ) , & root. names ) ?;
232+ if renamed_schema. equivalent_names_and_types ( plan. schema ( ) ) {
233+ // Nothing to do if the schema is already equivalent
234+ return Ok ( plan) ;
235+ }
236+
237+ match plan {
238+ // If the last node of the plan produces expressions, bake the renames into those expressions.
239+ // This isn't necessary for correctness, but helps with roundtrip tests.
240+ LogicalPlan :: Projection ( p) => Ok ( LogicalPlan :: Projection ( Projection :: try_new ( rename_expressions ( p. expr , p. input . schema ( ) , renamed_schema) ?, p. input ) ?) ) ,
241+ LogicalPlan :: Aggregate ( a) => {
242+ let new_aggr_exprs = rename_expressions ( a. aggr_expr , a. input . schema ( ) , renamed_schema) ?;
243+ Ok ( LogicalPlan :: Aggregate ( Aggregate :: try_new ( a. input , a. group_expr , new_aggr_exprs) ?) )
244+ } ,
245+ // There are probably more plans where we could bake things in, can add them later as needed.
246+ // Otherwise, add a new Project to handle the renaming.
247+ _ => Ok ( LogicalPlan :: Projection ( Projection :: try_new ( rename_expressions ( plan. schema ( ) . columns ( ) . iter ( ) . map ( |c| col ( c. to_owned ( ) ) ) , plan. schema ( ) , renamed_schema) ?, Arc :: new ( plan) ) ?) )
248+ }
225249 }
226250 } ,
227251 None => plan_err ! ( "Cannot parse plan relation: None" )
@@ -234,6 +258,105 @@ pub async fn from_substrait_plan(
234258 }
235259}
236260
261+ fn rename_expressions (
262+ exprs : impl IntoIterator < Item = Expr > ,
263+ input_schema : & DFSchema ,
264+ new_schema : DFSchemaRef ,
265+ ) -> Result < Vec < Expr > > {
266+ exprs
267+ . into_iter ( )
268+ . zip ( new_schema. fields ( ) )
269+ . map ( |( old_expr, new_field) | {
270+ if & old_expr. get_type ( input_schema) ? == new_field. data_type ( ) {
271+ // Alias column if needed
272+ old_expr. alias_if_changed ( new_field. name ( ) . into ( ) )
273+ } else {
274+ // Use Cast to rename inner struct fields + alias column if needed
275+ Expr :: Cast ( Cast :: new (
276+ Box :: new ( old_expr) ,
277+ new_field. data_type ( ) . to_owned ( ) ,
278+ ) )
279+ . alias_if_changed ( new_field. name ( ) . into ( ) )
280+ }
281+ } )
282+ . collect ( )
283+ }
284+
285+ fn make_renamed_schema (
286+ schema : & DFSchemaRef ,
287+ dfs_names : & Vec < String > ,
288+ ) -> Result < DFSchemaRef > {
289+ fn rename_inner_fields (
290+ dtype : & DataType ,
291+ dfs_names : & Vec < String > ,
292+ name_idx : & mut usize ,
293+ ) -> Result < DataType > {
294+ match dtype {
295+ DataType :: Struct ( fields) => {
296+ let fields = fields
297+ . iter ( )
298+ . map ( |f| {
299+ let name = next_struct_field_name ( 0 , dfs_names, name_idx) ?;
300+ Ok ( ( * * f) . to_owned ( ) . with_name ( name) . with_data_type (
301+ rename_inner_fields ( f. data_type ( ) , dfs_names, name_idx) ?,
302+ ) )
303+ } )
304+ . collect :: < Result < _ > > ( ) ?;
305+ Ok ( DataType :: Struct ( fields) )
306+ }
307+ DataType :: List ( inner) => Ok ( DataType :: List ( FieldRef :: new (
308+ ( * * inner) . to_owned ( ) . with_data_type ( rename_inner_fields (
309+ inner. data_type ( ) ,
310+ dfs_names,
311+ name_idx,
312+ ) ?) ,
313+ ) ) ) ,
314+ DataType :: LargeList ( inner) => Ok ( DataType :: LargeList ( FieldRef :: new (
315+ ( * * inner) . to_owned ( ) . with_data_type ( rename_inner_fields (
316+ inner. data_type ( ) ,
317+ dfs_names,
318+ name_idx,
319+ ) ?) ,
320+ ) ) ) ,
321+ _ => Ok ( dtype. to_owned ( ) ) ,
322+ }
323+ }
324+
325+ let mut name_idx = 0 ;
326+
327+ let ( qualifiers, fields) : ( _ , Vec < Field > ) = schema
328+ . iter ( )
329+ . map ( |( q, f) | {
330+ let name = next_struct_field_name ( 0 , dfs_names, & mut name_idx) ?;
331+ Ok ( (
332+ q. cloned ( ) ,
333+ ( * * f)
334+ . to_owned ( )
335+ . with_name ( name)
336+ . with_data_type ( rename_inner_fields (
337+ f. data_type ( ) ,
338+ dfs_names,
339+ & mut name_idx,
340+ ) ?) ,
341+ ) )
342+ } )
343+ . collect :: < Result < Vec < _ > > > ( ) ?
344+ . into_iter ( )
345+ . unzip ( ) ;
346+
347+ if name_idx != dfs_names. len ( ) {
348+ return substrait_err ! (
349+ "Names list must match exactly to nested schema, but found {} uses for {} names" ,
350+ name_idx,
351+ dfs_names. len( ) ) ;
352+ }
353+
354+ Ok ( Arc :: new ( DFSchema :: from_field_specific_qualified_schema (
355+ qualifiers,
356+ & Arc :: new ( Schema :: new ( fields) ) ,
357+ ) ?) )
358+ }
359+
237360/// Convert Substrait Rel to DataFusion DataFrame
238361#[ async_recursion]
239362pub async fn from_substrait_rel (
0 commit comments