
Commit 444cf47

Introduced map canonicalization logic and added a Spark-plan test in the native layer
1 parent 314c191 commit 444cf47

File tree

4 files changed, +385 -7 lines changed

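The new planner test groups rows on a map column, which only works if logically equal maps produce the same grouping key regardless of the order their entries were inserted in. The committed map_sort / map_canonicalize kernels live in files not shown in this view; as a rough, self-contained illustration of the idea only (the helper name normalize_map and the use of std BTreeMap/HashMap here are hypothetical, not the committed implementation), canonicalization for grouping can be sketched like this:

use std::collections::{BTreeMap, HashMap};

// Hypothetical helper (not part of the commit): canonicalize a map's entries by
// sorting them by key, so that {"a": 1, "b": 2} and {"b": 2, "a": 1} compare equal.
fn normalize_map(entries: &[(&str, i64)]) -> Vec<(String, i64)> {
    let sorted: BTreeMap<&str, i64> = entries.iter().cloned().collect();
    sorted.into_iter().map(|(k, v)| (k.to_string(), v)).collect()
}

fn main() {
    // Same rows as the new planner test builds with Arrow's MapBuilder.
    let rows = vec![
        vec![("a", 1), ("b", 2)], // Row 0
        vec![("b", 2), ("a", 1)], // Row 1: same map, different insertion order
        vec![("a", 1)],           // Row 2
        vec![("c", 3)],           // Row 3
    ];

    // Group by the canonical form and count rows per group.
    let mut counts: HashMap<Vec<(String, i64)>, i64> = HashMap::new();
    for row in &rows {
        *counts.entry(normalize_map(row)).or_insert(0) += 1;
    }

    assert_eq!(counts[&normalize_map(&[("a", 1), ("b", 2)])], 2);
    assert_eq!(counts[&normalize_map(&[("a", 1)])], 1);
    assert_eq!(counts[&normalize_map(&[("c", 3)])], 1);
}

Judging from the test below, the native plan expresses the same idea as a pair of scalar functions: map_sort returns the map with entries reordered by key, and map_canonicalize exposes them as a list of key/value structs that the hash aggregate can group on.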

native/core/src/execution/planner.rs

Lines changed: 303 additions & 6 deletions
@@ -2644,10 +2644,10 @@ fn create_case_expr(
 
 #[cfg(test)]
 mod tests {
-    use futures::{poll, StreamExt};
-    use std::{sync::Arc, task::Poll};
-
-    use arrow::array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray};
+    use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
+    use arrow::array::{
+        Array, DictionaryArray, Int32Array, Int32Builder, RecordBatch, StringArray,
+    };
     use arrow::datatypes::{DataType, Field, Fields, Schema};
     use datafusion::catalog::memory::DataSourceExec;
     use datafusion::datasource::listing::PartitionedFile;
@@ -2656,17 +2656,24 @@ mod tests {
         FileGroup, FileScanConfigBuilder, FileSource, ParquetSource,
     };
     use datafusion::error::DataFusionError;
+    use datafusion::execution::runtime_env::RuntimeEnvBuilder;
     use datafusion::logical_expr::ScalarUDF;
     use datafusion::physical_plan::ExecutionPlan;
+    use datafusion::prelude::SessionConfig;
     use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
+    use futures::{poll, StreamExt};
+    use std::hash::Hasher;
+    use std::time::{Duration, Instant};
+    use std::{sync::Arc, task::Poll};
     use tempfile::TempDir;
+    use tokio::runtime::Runtime;
     use tokio::sync::mpsc;
-
-    use crate::execution::{operators::InputBatch, planner::PhysicalPlanner};
+    use tokio::time::timeout;
 
     use crate::execution::operators::ExecutionError;
     use crate::parquet::parquet_support::SparkParquetOptions;
     use crate::parquet::schema_adapter::SparkSchemaAdapterFactory;
+    use datafusion_comet_proto::spark_expression::data_type::{ListInfo, StructInfo};
     use datafusion_comet_proto::spark_expression::expr::ExprStruct;
     use datafusion_comet_proto::{
         spark_expression::expr::ExprStruct::*,
@@ -3448,4 +3455,294 @@ mod tests {
         assert_batches_eq!(expected, &[actual]);
         Ok(())
     }
+
+    #[test]
+    fn test_map_grouping_aggregation_using_spark_plan() {
+        use arrow::array::{Int64Builder, StringBuilder};
+        use arrow::datatypes::{Field, Schema};
+        use datafusion::prelude::SessionContext;
+        use datafusion_comet_proto::spark_expression::{
+            agg_expr::ExprStruct as AggExprStruct,
+            data_type::{DataTypeInfo, MapInfo},
+            expr::ExprStruct,
+            literal::Value,
+            AggExpr, BoundReference, Count, DataType as SparkDataType, Expr, Literal,
+        };
+        use datafusion_comet_proto::spark_operator::{
+            operator::OpStruct, AggregateMode, HashAggregate, Operator,
+        };
+        use std::sync::Arc;
+        eprintln!("Testing Map grouping support using Spark plan...");
+
+        // Create map array using Arrow's MapBuilder
+        let string_builder = StringBuilder::new();
+        let int_builder = Int64Builder::new();
+        let mut map_builder = arrow::array::MapBuilder::new(None, string_builder, int_builder);
+
+        // Add some map entries to test grouping:
+        // Row 0: {"a": 1, "b": 2}
+        map_builder.keys().append_value("a");
+        map_builder.values().append_value(1);
+        map_builder.keys().append_value("b");
+        map_builder.values().append_value(2);
+        map_builder.append(true).unwrap();
+
+        // Row 1: {"b": 2, "a": 1} - same map as row 0, inserted in a different order
+        map_builder.keys().append_value("b");
+        map_builder.values().append_value(2);
+        map_builder.keys().append_value("a");
+        map_builder.values().append_value(1);
+        map_builder.append(true).unwrap();
+
+        // Row 2: {"a": 1} - a group of its own
+        map_builder.keys().append_value("a");
+        map_builder.values().append_value(1);
+        map_builder.append(true).unwrap();
+
+        // Row 3: {"c": 3}
+        map_builder.keys().append_value("c");
+        map_builder.values().append_value(3);
+        map_builder.append(true).unwrap();
+
+        let map_array = map_builder.finish();
+
+        // Create schema - need a scan operator first
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "map_col",
+            map_array.data_type().clone(),
+            false,
+        )]));
+
+        // Create Spark protobuf structures for HashAgg
+        let map_type = SparkDataType {
+            type_id: 15, // MAP
+            type_info: Some(Box::new(DataTypeInfo {
+                datatype_struct: Some(
+                    spark_expression::data_type::data_type_info::DatatypeStruct::Map(Box::new(
+                        MapInfo {
+                            key_type: Some(Box::new(SparkDataType {
+                                type_id: 7, // STRING
+                                type_info: None,
+                            })),
+                            value_type: Some(Box::new(SparkDataType {
+                                type_id: 4, // INT64, matching the Int64Builder values
+                                type_info: None,
+                            })),
+                            value_contains_null: true,
+                        },
+                    )),
+                ),
+            })),
+        };
+
+        let key_dt = SparkDataType {
+            type_id: 7, // STRING
+            type_info: None,
+        };
+        let value_dt = SparkDataType {
+            type_id: 4, // INT64
+            type_info: None,
+        };
+        let value_contains_null: bool = true; // from MapInfo.value_contains_null
+
+        // 1) Struct<key, value>
+        let entries_struct = SparkDataType {
+            type_id: 16, // STRUCT
+            type_info: Some(Box::new(DataTypeInfo {
+                datatype_struct: Some(
+                    spark_expression::data_type::data_type_info::DatatypeStruct::Struct(
+                        StructInfo {
+                            field_names: vec!["keys".to_string(), "values".to_string()],
+                            field_datatypes: vec![
+                                key_dt.clone(),   // key type K
+                                value_dt.clone(), // value type V
+                            ],
+                            field_nullable: vec![
+                                false,               // key is non-null (Arrow/Spark rule)
+                                value_contains_null, // value nullability copied from Map
+                            ],
+                        },
+                    ),
+                ),
+            })),
+        };
+
+        // 2) Array<List> around the entries struct: contains_null = false
+        let list_of_entries = SparkDataType {
+            type_id: 14, // ARRAY
+            type_info: Some(Box::new(DataTypeInfo {
+                datatype_struct: Some(
+                    spark_expression::data_type::data_type_info::DatatypeStruct::List(Box::new(
+                        ListInfo {
+                            element_type: Some(Box::new(entries_struct)),
+                            contains_null: false, // entries are never null
+                        },
+                    )),
+                ),
+            })),
+        };
+
+        // Create BoundReference for map column (index 0)
+        let bound_ref = Expr {
+            expr_struct: Some(Bound(BoundReference {
+                index: 0,
+                datatype: Some(map_type.clone()),
+            })),
+        };
+
+        // Create Count aggregation expression
+        let count_agg = AggExpr {
+            expr_struct: Some(AggExprStruct::Count(Count {
+                children: vec![Expr {
+                    expr_struct: Some(ExprStruct::Literal(Literal {
+                        datatype: Some(SparkDataType {
+                            type_id: 3, // INT32 type ID
+                            type_info: None,
+                        }),
+                        is_null: false,
+                        value: Some(Value::IntVal(1)),
+                    })),
+                }],
+            })),
+        };
+
+        let map_sort_expr = Expr {
+            expr_struct: Some(ScalarFunc(spark_expression::ScalarFunc {
+                func: "map_sort".to_string(),
+                args: vec![bound_ref],
+                return_type: Some(map_type.clone()),
+            })),
+        };
+
+        let map_canonicalize_expr = Expr {
+            expr_struct: Some(ScalarFunc(spark_expression::ScalarFunc {
+                func: "map_canonicalize".to_string(),
+                args: vec![map_sort_expr],
+                return_type: Some(list_of_entries.clone()),
+            })),
+        };
+
+        println!("Map canonicalize expression: {:?}", map_canonicalize_expr);
+
+        // Create HashAggregate protobuf structure
+        let hash_agg = HashAggregate {
+            grouping_exprs: vec![map_canonicalize_expr],
+            agg_exprs: vec![count_agg],
+            result_exprs: vec![],
+            mode: AggregateMode::Partial as i32,
+        };
+
+        // Create a Scan child operator
+        let child_op = Operator {
+            plan_id: 1,
+            op_struct: Some(OpStruct::Scan(spark_operator::Scan {
+                fields: vec![map_type],
+                source: "test_scan".to_string(),
+            })),
+            children: vec![],
+        };
+
+        // Create the HashAgg operator
+        let hash_agg_op = Operator {
+            plan_id: 2,
+            op_struct: Some(OpStruct::HashAgg(hash_agg)),
+            children: vec![child_op],
+        };
+
+        // Build context / planner
+        let ctx = Arc::new(SessionContext::new());
+        let planner = PhysicalPlanner::new(ctx.clone(), 0);
+
+        // Create plan
+        let mut inputs = Vec::new();
+        let (mut scans, exec_plan) = planner.create_plan(&hash_agg_op, &mut inputs, 1).unwrap();
+
+        // Build input batch and set on scan
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(map_array)]).unwrap();
+        if let Some(scan) = scans.get_mut(0) {
+            use crate::execution::operators::InputBatch;
+            let columns: Vec<Arc<dyn Array>> = batch.columns().iter().cloned().collect();
+            let input_batch = InputBatch::new(columns, Some(batch.num_rows()));
+            scan.set_input_batch(input_batch);
+        }
+
+        // Execute plan
+        let task_ctx = ctx.task_ctx();
+        let mut stream = exec_plan.native_plan.execute(0, task_ctx).unwrap();
+
+        // Create a runtime to drive non-blocking polls and draining
+        let runtime = Runtime::new().unwrap();
+
+        // Force one poll to consume the data batch from ScanStream
+        runtime.block_on(async {
+            let _ = poll!(stream.next());
+        });
+
+        // Signal EOF to finalize aggregation
+        if let Some(scan) = scans.get_mut(0) {
+            use crate::execution::operators::InputBatch;
+            scan.set_input_batch(InputBatch::EOF);
+        }
+
+        // Drain stream to completion and collect results
+        let result_batches = runtime.block_on(async move {
+            let mut collected = Vec::new();
+            while let Some(item) = stream.next().await {
+                match item {
+                    Ok(batch) => {
+                        println!("Received batch: {:?}", batch);
+                        collected.push(batch);
+                    }
+                    Err(e) => panic!("stream error: {e:?}"),
+                }
+            }
+            collected
+        });
+
+        assert!(
+            !result_batches.is_empty(),
+            "Should have at least one result batch"
+        );
+
+        // Verify expected groups and counts directly
+        use arrow::array::{Int64Array, ListArray, StringArray, StructArray};
+        use std::collections::HashMap;
+
+        // Build actual map: "k=v,k=v" -> count
+        let mut actual: HashMap<String, i64> = HashMap::new();
+        for rb in &result_batches {
+            let groups = rb.column(0).as_any().downcast_ref::<ListArray>().unwrap();
+            let counts = rb.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
+
+            for row in 0..rb.num_rows() {
+                // Extract key/value pairs for this group and canonicalize
+                let sub = groups.value(row);
+                let sa = sub.as_any().downcast_ref::<StructArray>().unwrap();
+                let keys = sa.column(0).as_any().downcast_ref::<StringArray>().unwrap();
+                let vals = sa.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
+
+                let mut pairs: Vec<(String, i64)> = (0..sa.len())
+                    .filter(|&i| sa.is_valid(i))
+                    .map(|i| (keys.value(i).to_string(), vals.value(i)))
+                    .collect();
+                pairs.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
+                let key = pairs
+                    .into_iter()
+                    .map(|(k, v)| format!("{k}={v}"))
+                    .collect::<Vec<_>>()
+                    .join(",");
+
+                let cnt = counts.value(row);
+                actual.insert(key, cnt);
+            }
+        }
+
+        // Expected groups
+        let mut expected: HashMap<String, i64> = HashMap::new();
+        expected.insert("a=1,b=2".to_string(), 2);
+        expected.insert("a=1".to_string(), 1);
+        expected.insert("c=3".to_string(), 1);
+
+        assert_eq!(actual, expected);
+    }
 }

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 5 additions & 1 deletion
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::hash_funcs::*;
+use crate::map_funcs::{map_canonicalize, spark_map_sort};
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
@@ -34,7 +35,6 @@ use datafusion::physical_plan::ColumnarValue;
 use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
-use crate::map_funcs::spark_map_sort;
 
 macro_rules! make_comet_scalar_udf {
     ($name:expr, $func:ident, $data_type:ident) => {{
@@ -149,6 +149,10 @@ pub fn create_comet_physical_fun(
             let func = Arc::new(spark_map_sort);
             make_comet_scalar_udf!("spark_map_sort", func, without data_type)
         }
+        "map_canonicalize" => {
+            let func = Arc::new(map_canonicalize);
+            make_comet_scalar_udf!("map_canonicalize", func, without data_type)
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",
