@@ -2644,10 +2644,10 @@ fn create_case_expr(

#[cfg(test)]
mod tests {
-    use futures::{poll, StreamExt};
-    use std::{sync::Arc, task::Poll};
-
-    use arrow::array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray};
+    use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
+    use arrow::array::{
+        Array, DictionaryArray, Int32Array, Int32Builder, RecordBatch, StringArray,
+    };
    use arrow::datatypes::{DataType, Field, Fields, Schema};
    use datafusion::catalog::memory::DataSourceExec;
    use datafusion::datasource::listing::PartitionedFile;
@@ -2656,17 +2656,24 @@ mod tests {
        FileGroup, FileScanConfigBuilder, FileSource, ParquetSource,
    };
    use datafusion::error::DataFusionError;
+    use datafusion::execution::runtime_env::RuntimeEnvBuilder;
    use datafusion::logical_expr::ScalarUDF;
    use datafusion::physical_plan::ExecutionPlan;
+    use datafusion::prelude::SessionConfig;
    use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
+    use futures::{poll, StreamExt};
+    use std::hash::Hasher;
+    use std::time::{Duration, Instant};
+    use std::{sync::Arc, task::Poll};
    use tempfile::TempDir;
+    use tokio::runtime::Runtime;
    use tokio::sync::mpsc;
-
-    use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
+    use tokio::time::timeout;

    use crate::execution::operators::ExecutionError;
    use crate::parquet::parquet_support::SparkParquetOptions;
    use crate::parquet::schema_adapter::SparkSchemaAdapterFactory;
+    use datafusion_comet_proto::spark_expression::data_type::{ListInfo, StructInfo};
    use datafusion_comet_proto::spark_expression::expr::ExprStruct;
    use datafusion_comet_proto::{
        spark_expression::expr::ExprStruct::*,
@@ -3448,4 +3455,294 @@ mod tests {
        assert_batches_eq!(expected, &[actual]);
        Ok(())
    }
+
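+    // Builds a Spark protobuf plan (Scan -> partial HashAggregate) that groups on a Map
+    // column by canonicalizing it into a key-sorted list of entries, runs it through
+    // PhysicalPlanner, and verifies the per-group counts.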
+    #[test]
+    fn test_map_grouping_aggregation_using_spark_plan() {
+        use arrow::array::{Int64Builder, StringBuilder};
+        use arrow::datatypes::{Field, Schema};
+        use datafusion::prelude::SessionContext;
+        use datafusion_comet_proto::spark_expression::{
+            agg_expr::ExprStruct as AggExprStruct,
+            data_type::{DataTypeInfo, MapInfo},
+            expr::ExprStruct,
+            literal::Value,
+            AggExpr, BoundReference, Count, DataType as SparkDataType, Expr, Literal,
+        };
+        use datafusion_comet_proto::spark_operator::{
+            operator::OpStruct, AggregateMode, HashAggregate, Operator,
+        };
+        use std::sync::Arc;
+        eprintln!("Testing Map grouping support using Spark plan...");
+
+        // Create the map array using Arrow's MapBuilder
+        let string_builder = StringBuilder::new();
+        let int_builder = Int64Builder::new();
+        let mut map_builder = arrow::array::MapBuilder::new(None, string_builder, int_builder);
+
+        // Add some map entries to test grouping:
+        // Row 0: {"a": 1, "b": 2}
+        map_builder.keys().append_value("a");
+        map_builder.values().append_value(1);
+        map_builder.keys().append_value("b");
+        map_builder.values().append_value(2);
+        map_builder.append(true).unwrap();
+
+        // Row 1: {"b": 2, "a": 1} - same entries as row 0 in a different insertion order,
+        // so canonicalization should place it in the same group
+        map_builder.keys().append_value("b");
+        map_builder.values().append_value(2);
+        map_builder.keys().append_value("a");
+        map_builder.values().append_value(1);
+        map_builder.append(true).unwrap();
+
+        // Row 2: {"a": 1} - a different map, forms its own group
+        map_builder.keys().append_value("a");
+        map_builder.values().append_value(1);
+        map_builder.append(true).unwrap();
+
+        // Row 3: {"c": 3}
+        map_builder.keys().append_value("c");
+        map_builder.values().append_value(3);
+        map_builder.append(true).unwrap();
+
+        let map_array = map_builder.finish();
+
+        // Create the input schema: a single non-nullable Map column for the scan operator
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "map_col",
+            map_array.data_type().clone(),
+            false,
+        )]));
+
+        // Create Spark protobuf structures for the HashAggregate
+        let map_type = SparkDataType {
+            type_id: 15, // MAP
+            type_info: Some(Box::new(DataTypeInfo {
+                datatype_struct: Some(
+                    spark_expression::data_type::data_type_info::DatatypeStruct::Map(Box::new(
+                        MapInfo {
+                            key_type: Some(Box::new(SparkDataType {
+                                type_id: 7, // STRING
+                                type_info: None,
+                            })),
+                            value_type: Some(Box::new(SparkDataType {
+                                type_id: 4, // INT64, matching the Int64Builder values above
+                                type_info: None,
+                            })),
+                            value_contains_null: true,
+                        },
+                    )),
+                ),
+            })),
+        };
+
+        let key_dt = SparkDataType {
+            type_id: 7, // STRING
+            type_info: None,
+        };
+        let value_dt = SparkDataType {
+            type_id: 4, // INT64
+            type_info: None,
+        };
+        let value_contains_null: bool = true; // mirrors MapInfo.value_contains_null above
+
+        // 1) Struct<key, value>
+        let entries_struct = SparkDataType {
+            type_id: 16, // STRUCT
+            type_info: Some(Box::new(DataTypeInfo {
+                datatype_struct: Some(
+                    spark_expression::data_type::data_type_info::DatatypeStruct::Struct(
+                        StructInfo {
+                            field_names: vec!["keys".to_string(), "values".to_string()],
+                            field_datatypes: vec![
+                                key_dt.clone(),   // key type K
+                                value_dt.clone(), // value type V
+                            ],
+                            field_nullable: vec![
+                                false,               // key is non-null (Arrow/Spark rule)
+                                value_contains_null, // value nullability copied from the Map
+                            ],
+                        },
+                    ),
+                ),
+            })),
+        };
+
+        // 2) List (array) around the entries struct: contains_null = false
+        let list_of_entries = SparkDataType {
+            type_id: 14, // ARRAY
+            type_info: Some(Box::new(DataTypeInfo {
+                datatype_struct: Some(
+                    spark_expression::data_type::data_type_info::DatatypeStruct::List(Box::new(
+                        ListInfo {
+                            element_type: Some(Box::new(entries_struct)),
+                            contains_null: false, // entries are never null
+                        },
+                    )),
+                ),
+            })),
+        };
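+        // Grouping happens on this canonical list-of-entries form rather than on the raw
+        // map value: once the entries are sorted by key (map_sort below), maps with the
+        // same contents compare equal regardless of their original entry order.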
+
+        // Create a BoundReference for the map column (index 0)
+        let bound_ref = Expr {
+            expr_struct: Some(Bound(BoundReference {
+                index: 0,
+                datatype: Some(map_type.clone()),
+            })),
+        };
+
+        // Create the Count aggregation expression (COUNT over a literal 1)
+        let count_agg = AggExpr {
+            expr_struct: Some(AggExprStruct::Count(Count {
+                children: vec![Expr {
+                    expr_struct: Some(ExprStruct::Literal(Literal {
+                        datatype: Some(SparkDataType {
+                            type_id: 3, // INT32
+                            type_info: None,
+                        }),
+                        is_null: false,
+                        value: Some(Value::IntVal(1)),
+                    })),
+                }],
+            })),
+        };
+
+        let map_sort_expr = Expr {
+            expr_struct: Some(ScalarFunc(spark_expression::ScalarFunc {
+                func: "map_sort".to_string(),
+                args: vec![bound_ref],
+                return_type: Some(map_type.clone()),
+            })),
+        };
+
+        let map_canonicalize_expr = Expr {
+            expr_struct: Some(ScalarFunc(spark_expression::ScalarFunc {
+                func: "map_canonicalize".to_string(),
+                args: vec![map_sort_expr],
+                return_type: Some(list_of_entries.clone()),
+            })),
+        };
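+        // "map_sort" and "map_canonicalize" are assumed here to be scalar functions
+        // registered with the planner by this change: map_sort orders the entries by key,
+        // and map_canonicalize converts the sorted map into the List<Struct<key, value>>
+        // type declared by list_of_entries above.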
+
+        println!("Map canonicalize expression: {:?}", map_canonicalize_expr);
+
+        // Create the HashAggregate protobuf structure
+        let hash_agg = HashAggregate {
+            grouping_exprs: vec![map_canonicalize_expr],
+            agg_exprs: vec![count_agg],
+            result_exprs: vec![],
+            mode: AggregateMode::Partial as i32,
+        };
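+        // Partial mode emits one row per group: the canonical grouping key followed by the
+        // (partial) COUNT value, which is what the assertions below read back.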
+
+        // Create a Scan child operator
+        let child_op = Operator {
+            plan_id: 1,
+            op_struct: Some(OpStruct::Scan(spark_operator::Scan {
+                fields: vec![map_type],
+                source: "test_scan".to_string(),
+            })),
+            children: vec![],
+        };
+
+        // Create the HashAgg operator
+        let hash_agg_op = Operator {
+            plan_id: 2,
+            op_struct: Some(OpStruct::HashAgg(hash_agg)),
+            children: vec![child_op],
+        };
+
+        // Build the session context and planner
+        let ctx = Arc::new(SessionContext::new());
+        let planner = PhysicalPlanner::new(ctx.clone(), 0);
+
+        // Create the plan
+        let mut inputs = Vec::new();
+        let (mut scans, exec_plan) = planner.create_plan(&hash_agg_op, &mut inputs, 1).unwrap();
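+        // create_plan returns the native ExecutionPlan together with the scan operators it
+        // created; the test feeds those scans manually via set_input_batch below.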
+
+        // Build the input batch and set it on the scan
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(map_array)]).unwrap();
+        if let Some(scan) = scans.get_mut(0) {
+            let columns: Vec<Arc<dyn Array>> = batch.columns().iter().cloned().collect();
+            let input_batch = InputBatch::new(columns, Some(batch.num_rows()));
+            scan.set_input_batch(input_batch);
+        }
+
+        // Execute the plan
+        let task_ctx = ctx.task_ctx();
+        let mut stream = exec_plan.native_plan.execute(0, task_ctx).unwrap();
+
+        // Create a runtime to drive non-blocking polls and draining
+        let runtime = Runtime::new().unwrap();
+
+        // Force one poll to consume the data batch from ScanStream
+        runtime.block_on(async {
+            let _ = poll!(stream.next());
+        });
+
+        // Signal EOF to finalize the aggregation
+        if let Some(scan) = scans.get_mut(0) {
+            scan.set_input_batch(InputBatch::EOF);
+        }
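+        // The scan is fed incrementally: each poll consumes whatever was last set with
+        // set_input_batch, and InputBatch::EOF marks the input as exhausted so the
+        // aggregation can emit its final groups.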
+
+        // Drain the stream to completion and collect the results
+        let result_batches = runtime.block_on(async move {
+            let mut collected = Vec::new();
+            while let Some(item) = stream.next().await {
+                match item {
+                    Ok(batch) => {
+                        println!("Received batch: {:?}", batch);
+                        collected.push(batch);
+                    }
+                    Err(e) => panic!("stream error: {e:?}"),
+                }
+            }
+            collected
+        });
+
+        assert!(
+            !result_batches.is_empty(),
+            "Should have at least one result batch"
+        );
+
+        // Verify the expected groups and counts directly
+        use arrow::array::{Int64Array, ListArray, StringArray, StructArray};
+        use std::collections::HashMap;
+
+        // Build the actual results as a map of "k=v,k=v" -> count
+        let mut actual: HashMap<String, i64> = HashMap::new();
+        for rb in &result_batches {
+            let groups = rb.column(0).as_any().downcast_ref::<ListArray>().unwrap();
+            let counts = rb.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
+
+            for row in 0..rb.num_rows() {
+                // Extract the key/value pairs for this group and canonicalize them
+                let sub = groups.value(row);
+                let sa = sub.as_any().downcast_ref::<StructArray>().unwrap();
+                let keys = sa.column(0).as_any().downcast_ref::<StringArray>().unwrap();
+                let vals = sa.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
+
+                let mut pairs: Vec<(String, i64)> = (0..sa.len())
+                    .filter(|&i| sa.is_valid(i))
+                    .map(|i| (keys.value(i).to_string(), vals.value(i)))
+                    .collect();
+                pairs.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
+                let key = pairs
+                    .into_iter()
+                    .map(|(k, v)| format!("{k}={v}"))
+                    .collect::<Vec<_>>()
+                    .join(",");
+
+                let cnt = counts.value(row);
+                actual.insert(key, cnt);
+            }
+        }
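+        // Keying the results by the canonical "k=v,..." string keeps the check independent
+        // of the order in which groups happen to be emitted.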
+
+        // Expected groups
+        let mut expected: HashMap<String, i64> = HashMap::new();
+        expected.insert("a=1,b=2".to_string(), 2);
+        expected.insert("a=1".to_string(), 1);
+        expected.insert("c=3".to_string(), 1);
+
+        assert_eq!(actual, expected);
+    }
}