@@ -28,6 +28,7 @@ use crate::{Accumulator, Expr};
2828use crate :: { AccumulatorFactoryFunction , ReturnTypeFunction , Signature } ;
2929use arrow:: datatypes:: { DataType , Field } ;
3030use datafusion_common:: { exec_err, not_impl_err, Result } ;
31+ use sqlparser:: ast:: NullTreatment ;
3132use std:: any:: Any ;
3233use std:: fmt:: { self , Debug , Formatter } ;
3334use std:: sync:: Arc ;
@@ -139,8 +140,15 @@ impl AggregateUDF {
139140 ///
140141 /// This utility allows using the UDAF without requiring access to
141142 /// the registry, such as with the DataFrame API.
142- pub fn call ( & self , args : Vec < Expr > ) -> AggregateFunction {
143- AggregateFunction :: new_udf ( Arc :: new ( self . clone ( ) ) , args, false , None , None , None )
143+ pub fn call ( & self , args : Vec < Expr > ) -> Expr {
144+ Expr :: AggregateFunction ( AggregateFunction :: new_udf (
145+ Arc :: new ( self . clone ( ) ) ,
146+ args,
147+ false ,
148+ None ,
149+ None ,
150+ None ,
151+ ) )
144152 }
145153
146154 /// Returns this function's name
@@ -599,3 +607,49 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
599607 ( self . accumulator ) ( acc_args)
600608 }
601609}
610+
611+ pub trait AggregateUDFExprBuilder {
612+ fn order_by ( self , order_by : Vec < Expr > ) -> Expr ;
613+ fn filter ( self , filter : Box < Expr > ) -> Expr ;
614+ fn null_treatment ( self , null_treatment : NullTreatment ) -> Expr ;
615+ fn distinct ( self ) -> Expr ;
616+ }
617+
618+ impl AggregateUDFExprBuilder for Expr {
619+ fn order_by ( self , order_by : Vec < Expr > ) -> Expr {
620+ match self {
621+ Expr :: AggregateFunction ( mut udaf) => {
622+ udaf. order_by = Some ( order_by) ;
623+ Expr :: AggregateFunction ( udaf)
624+ }
625+ _ => self ,
626+ }
627+ }
628+ fn filter ( self , filter : Box < Expr > ) -> Expr {
629+ match self {
630+ Expr :: AggregateFunction ( mut udaf) => {
631+ udaf. filter = Some ( filter) ;
632+ Expr :: AggregateFunction ( udaf)
633+ }
634+ _ => self ,
635+ }
636+ }
637+ fn null_treatment ( self , null_treatment : NullTreatment ) -> Expr {
638+ match self {
639+ Expr :: AggregateFunction ( mut udaf) => {
640+ udaf. null_treatment = Some ( null_treatment) ;
641+ Expr :: AggregateFunction ( udaf)
642+ }
643+ _ => self ,
644+ }
645+ }
646+ fn distinct ( self ) -> Expr {
647+ match self {
648+ Expr :: AggregateFunction ( mut udaf) => {
649+ udaf. distinct = true ;
650+ Expr :: AggregateFunction ( udaf)
651+ }
652+ _ => self ,
653+ }
654+ }
655+ }
0 commit comments