@@ -21,7 +21,7 @@ use std::cmp::Ordering;
2121use  std:: collections:: { HashMap ,  HashSet } ; 
2222use  std:: fmt:: { self ,  Debug ,  Display ,  Formatter } ; 
2323use  std:: hash:: { Hash ,  Hasher } ; 
24- use  std:: sync:: Arc ; 
24+ use  std:: sync:: { Arc ,   OnceLock } ; 
2525
2626use  super :: dml:: CopyTo ; 
2727use  super :: DdlStatement ; 
@@ -2965,6 +2965,15 @@ impl Aggregate {
29652965                . into_iter ( ) 
29662966                . map ( |( q,  f) | ( q,  f. as_ref ( ) . clone ( ) . with_nullable ( true ) . into ( ) ) ) 
29672967                . collect :: < Vec < _ > > ( ) ; 
2968+             qualified_fields. push ( ( 
2969+                 None , 
2970+                 Field :: new ( 
2971+                     Self :: INTERNAL_GROUPING_ID , 
2972+                     Self :: grouping_id_type ( qualified_fields. len ( ) ) , 
2973+                     false , 
2974+                 ) 
2975+                 . into ( ) , 
2976+             ) ) ; 
29682977        } 
29692978
29702979        qualified_fields. extend ( exprlist_to_fields ( aggr_expr. as_slice ( ) ,  & input) ?) ; 
@@ -3016,9 +3025,19 @@ impl Aggregate {
30163025        } ) 
30173026    } 
30183027
3028+     fn  is_grouping_set ( & self )  -> bool  { 
3029+         matches ! ( self . group_expr. as_slice( ) ,  [ Expr :: GroupingSet ( _) ] ) 
3030+     } 
3031+ 
30193032    /// Get the output expressions. 
30203033fn  output_expressions ( & self )  -> Result < Vec < & Expr > >  { 
3034+         static  INTERNAL_ID_EXPR :  OnceLock < Expr >  = OnceLock :: new ( ) ; 
30213035        let  mut  exprs = grouping_set_to_exprlist ( self . group_expr . as_slice ( ) ) ?; 
3036+         if  self . is_grouping_set ( )  { 
3037+             exprs. push ( INTERNAL_ID_EXPR . get_or_init ( || { 
3038+                 Expr :: Column ( Column :: from_name ( Self :: INTERNAL_GROUPING_ID ) ) 
3039+             } ) ) ; 
3040+         } 
30223041        exprs. extend ( self . aggr_expr . iter ( ) ) ; 
30233042        debug_assert ! ( exprs. len( )  == self . schema. fields( ) . len( ) ) ; 
30243043        Ok ( exprs) 
@@ -3030,6 +3049,41 @@ impl Aggregate {
30303049pub  fn  group_expr_len ( & self )  -> Result < usize >  { 
30313050        grouping_set_expr_count ( & self . group_expr ) 
30323051    } 
3052+ 
3053+     /// Returns the data type of the grouping id. 
3054+ /// The grouping ID value is a bitmask where each set bit 
3055+ /// indicates that the corresponding grouping expression is 
3056+ /// null 
3057+ pub  fn  grouping_id_type ( group_exprs :  usize )  -> DataType  { 
3058+         if  group_exprs <= 8  { 
3059+             DataType :: UInt8 
3060+         }  else  if  group_exprs <= 16  { 
3061+             DataType :: UInt16 
3062+         }  else  if  group_exprs <= 32  { 
3063+             DataType :: UInt32 
3064+         }  else  { 
3065+             DataType :: UInt64 
3066+         } 
3067+     } 
3068+ 
    /// Name of the internal column used when the aggregation is a grouping set.
    ///
    /// This column contains a bitmask where each bit represents a grouping
    /// expression. The least significant bit corresponds to the rightmost
    /// grouping expression. A bit value of 0 indicates that the corresponding
    /// column is included in the grouping set, while a value of 1 means it is
    /// excluded.
    ///
    /// For example, for the grouping expressions `CUBE(a, b)`, the grouping ID
    /// column will have the following values:
    ///
    /// ```text
    /// 0b00: Both `a` and `b` are included
    /// 0b01: `b` is excluded
    /// 0b10: `a` is excluded
    /// 0b11: Both `a` and `b` are excluded
    /// ```
    ///
    /// This internal column is necessary because excluded columns are replaced
    /// with `NULL` values. To handle these cases correctly, we must distinguish
    /// between an actual `NULL` value in a column and a column being excluded
    /// from the set.
    pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
30333087} 
30343088
30353089// Manual implementation needed because of `schema` field. Comparison excludes this field. 
0 commit comments