@@ -545,34 +545,39 @@ impl ScalarUDFImpl for CastToI64UDF {
545545 fn return_type ( & self , _args : & [ DataType ] ) -> Result < DataType > {
546546 Ok ( DataType :: Int64 )
547547 }
548- // Wrap with Expr::Cast() to Int64
548+
549+ // Demonstrate simplifying a UDF
549550 fn simplify (
550551 & self ,
551552 mut args : Vec < Expr > ,
552553 info : & dyn SimplifyInfo ,
553554 ) -> Result < ExprSimplifyResult > {
555+ // DataFusion should have ensured the function is called with just a
556+ // single argument
557+ assert_eq ! ( args. len( ) , 1 ) ;
558+ let arg = args. pop ( ) . unwrap ( ) ;
559+
554560 // Note that Expr::cast_to requires an ExprSchema but simplify gets a
555561 // SimplifyInfo so we have to replicate some of the casting logic here.
556- let source_type = info. get_data_type ( & args[ 0 ] ) ?;
557- if source_type == DataType :: Int64 {
558- Ok ( ExprSimplifyResult :: Original ( args) )
562+
563+ let source_type = info. get_data_type ( & arg) ?;
564+ let new_expr = if source_type == DataType :: Int64 {
565+ // the argument's data type is already the correct type
566+ arg
559567 } else {
560- // DataFusion should have ensured the function is called with just a
561- // single argument
562- assert_eq ! ( args. len( ) , 1 ) ;
563- let e = args. pop ( ) . unwrap ( ) ;
564- Ok ( ExprSimplifyResult :: Simplified ( Expr :: Cast (
565- datafusion_expr:: Cast {
566- expr : Box :: new ( e) ,
567- data_type : DataType :: Int64 ,
568- } ,
569- ) ) )
570- }
568+ // need to use an actual cast to get the correct type
569+ Expr :: Cast ( datafusion_expr:: Cast {
570+ expr : Box :: new ( arg) ,
571+ data_type : DataType :: Int64 ,
572+ } )
573+ } ;
574+ // return the newly written argument to DataFusion
575+ Ok ( ExprSimplifyResult :: Simplified ( new_expr) )
571576 }
577+
572578 // Casting should be done in `simplify`, so we just return the first argument
573- fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
574- assert_eq ! ( args. len( ) , 1 ) ;
575- Ok ( args. first ( ) . unwrap ( ) . clone ( ) )
579+ fn invoke ( & self , _args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
580+ unimplemented ! ( "Function should not be evaluated" )
576581 }
577582}
578583
0 commit comments