Skip to content

Commit 314c191

Browse files
committed
Integrating map_sort with serdes
1 parent 7fc8b3f commit 314c191

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use datafusion::physical_plan::ColumnarValue;
3434
use std::any::Any;
3535
use std::fmt::Debug;
3636
use std::sync::Arc;
37+
use crate::map_funcs::spark_map_sort;
3738

3839
macro_rules! make_comet_scalar_udf {
3940
($name:expr, $func:ident, $data_type:ident) => {{
@@ -144,6 +145,10 @@ pub fn create_comet_physical_fun(
144145
let fail_on_error = fail_on_error.unwrap_or(false);
145146
make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)
146147
}
148+
"map_sort" => {
149+
let func = Arc::new(spark_map_sort);
150+
make_comet_scalar_udf!("spark_map_sort", func, without data_type)
151+
}
147152
_ => registry.udf(fun_name).map_err(|e| {
148153
DataFusionError::Execution(format!(
149154
"Function {fun_name} not found in the registry: {e}",

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ import org.apache.spark.sql.types._
4848
import org.apache.spark.unsafe.types.UTF8String
4949

5050
import org.apache.comet.CometConf
51-
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
51+
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, isSpark40Plus, withInfo}
5252
import org.apache.comet.expressions._
5353
import org.apache.comet.objectstore.NativeConfig
5454
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
@@ -128,6 +128,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
128128
classOf[MapValues] -> CometMapValues,
129129
classOf[MapFromArrays] -> CometMapFromArrays,
130130
classOf[GetMapValue] -> CometMapExtract,
131+
classOf[MapSort] -> CometMapSort,
131132
classOf[GreaterThan] -> CometGreaterThan,
132133
classOf[GreaterThanOrEqual] -> CometGreaterThanOrEqual,
133134
classOf[LessThan] -> CometLessThan,
@@ -1953,10 +1954,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
19531954

19541955
if (groupingExpressions.exists(expr =>
19551956
expr.dataType match {
1956-
case _: MapType => true
1957+
case _: MapType if !isSpark40Plus => true
19571958
case _ => false
19581959
})) {
1959-
withInfo(op, "Grouping on map types is not supported")
1960+
withInfo(op, "Grouping on map types is not supported below Spark 4.0")
19601961
return None
19611962
}
19621963

spark/src/main/scala/org/apache/comet/serde/maps.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,21 @@ object CometMapFromArrays extends CometExpressionSerde {
9494
optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*)
9595
}
9696
}
97+
98+
object CometMapSort extends CometExpressionSerde {

  /**
   * Converts a Spark `MapSort` expression into a protobuf scalar function call to the
   * native `map_sort` function.
   *
   * @param expr
   *   the Spark expression to convert; expected to be a `MapSort`
   * @param inputs
   *   the input attributes available when binding child expressions
   * @param binding
   *   whether child expressions should be bound to the inputs
   * @return
   *   the serialized expression, or None if the child expression could not be converted
   */
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val mapSortExpr = expr.asInstanceOf[MapSort]
    val childExpr = exprToProtoInternal(mapSortExpr.child, inputs, binding)
    // map_sort only reorders entries, so the return type is the child's map type.
    val returnType = mapSortExpr.child.dataType

    val mapSortScalarExpr =
      scalarFunctionExprToProtoWithReturnType("map_sort", returnType, childExpr)
    optExprWithInfo(mapSortScalarExpr, expr, expr.children: _*)
  }
}

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.sql.functions.{avg, count_distinct, sum}
3131
import org.apache.spark.sql.internal.SQLConf
3232

3333
import org.apache.comet.CometConf
34+
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
3435
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
3536

3637
/**
@@ -1515,4 +1516,36 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15151516
sparkPlan.collect { case s: CometHashAggregateExec => s }.size
15161517
}
15171518

1519+
test("groupby with map column") {
  // Grouping on map types requires Spark's MapSort normalization, only available in 4.0+.
  assume(isSpark40Plus, "Groupby on map type is supported in Spark 4.0 and beyond")
  withParquetTable(
    Seq(
      // Rows 1 and 2 hold the same map with different key order; grouping must
      // normalize (sort) map keys so they land in the same group.
      (1, Map("a" -> 1, "b" -> 2)),
      (2, Map("b" -> 2, "a" -> 1)),
      (3, Map("a" -> 5, "b" -> 6))),
    "tbl") {
    withSQLConf(
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_EXEC_ENABLED.key -> "true",
      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true",
      CometConf.COMET_SHUFFLE_MODE.key -> "auto",
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) {
      val query = sql("SELECT count(*) AS testing FROM tbl group by _2")
      // Compare Comet's result against vanilla Spark's for the same query.
      checkSparkAnswer(query)
    }
  }
}
1550+
15181551
}

0 commit comments

Comments
 (0)