diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala
index bf3d790b86c0..256486658efc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 
 // A builder to create `Expression` from function information.
 trait FunctionExpressionBuilder {
+  // `name` and `clazz` are the name and the provided class of the user-defined function,
+  // respectively. `input` contains the children of `ScalaUDAF` or `ScalaAggregator`.
   def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 6f0ed23e1022..c173a33f1a95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1920,4 +1920,9 @@ object QueryExecutionErrors {
       s". To solve this try to set $maxDynamicPartitionsKey" +
       s" to at least $numWrittenParts.")
   }
+
+  def registerFunctionWithoutParameterlessConstructorError(className: String): Throwable = {
+    new RuntimeException(s"Cannot register the aggregate function '$className' because it " +
+      "does not provide a parameterless constructor")
+  }
 }
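The new error is raised from the constructor probe in `SparkUDFExpressionBuilder` below. Reduced to its essence, the eligibility check is just the following (the wrapper object and method name here are hypothetical, for illustration only):

```scala
object ConstructorCheck {
  // A class is eligible for CREATE FUNCTION registration as an Aggregator
  // only if it exposes a public zero-argument constructor, since the builder
  // must instantiate it reflectively via newInstance().
  def hasParameterlessConstructor(clazz: Class[_]): Boolean =
    clazz.getConstructors.exists(_.getParameterCount == 0)
}
```

Classes like `MyDoubleSum(test: Boolean)` in the test file below fail this check and hit the new error path.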
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 1e60cb8b1db2..e28adc520db1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -16,17 +16,24 @@
  */
 package org.apache.spark.sql.internal
 
+import java.io.Serializable
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe
+
 import org.apache.spark.annotation.Unstable
 import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
 import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF}
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
@@ -34,7 +41,8 @@ import org.apache.spark.sql.execution.command.CommandCheck
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog}
 import org.apache.spark.sql.execution.streaming.ResolveWriteToStream
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
+import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
+import org.apache.spark.sql.functions.udaf
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
@@ -411,6 +419,38 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder {
           name, expr.inputTypes.size.toString, input.size)
       }
       expr
+    } else if (classOf[Aggregator[_, _, _]].isAssignableFrom(clazz)) {
+      val noParameterConstructor = clazz.getConstructors.find(_.getParameterCount == 0)
+      if (noParameterConstructor.isEmpty) {
+        throw QueryExecutionErrors.registerFunctionWithoutParameterlessConstructorError(
+          clazz.getCanonicalName)
+      }
+      val aggregator =
+        noParameterConstructor.get.newInstance().asInstanceOf[Aggregator[Serializable, Any, Any]]
+
+      // Construct the input encoder
+      val mirror = universe.runtimeMirror(clazz.getClassLoader)
+      val classType = mirror.classSymbol(clazz)
+      val baseClassType = universe.typeOf[Aggregator[Serializable, Any, Any]].typeSymbol.asClass
+      val baseType = universe.internal.thisType(classType).baseType(baseClassType)
+      val tpe = baseType.typeArgs.head
+      val serializer = ScalaReflection.serializerForType(tpe)
+      val deserializer = ScalaReflection.deserializerForType(tpe)
+      val cls = mirror.runtimeClass(tpe)
+      val inputEncoder =
+        new ExpressionEncoder[Serializable](serializer, deserializer, ClassTag(cls))
+
+      val udf: UserDefinedFunction = udaf[Serializable, Any, Any](aggregator, inputEncoder)
+      assert(udf.isInstanceOf[UserDefinedAggregator[_, _, _]])
+      val udfAgg: UserDefinedAggregator[_, _, _] = udf.asInstanceOf[UserDefinedAggregator[_, _, _]]
+
+      val expr = udfAgg.scalaAggregator(input)
+      // Check input argument size
+      if (expr.inputTypes.size != input.size) {
+        throw QueryCompilationErrors.invalidFunctionArgumentsError(
+          name, expr.inputTypes.size.toString, input.size)
+      }
+      expr
     } else {
       throw QueryCompilationErrors.noHandlerForUDAFError(clazz.getCanonicalName)
     }
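The heart of the new branch is the runtime-reflection step that recovers the `IN` type argument of the user's `Aggregator[IN, BUF, OUT]`, so that an input `ExpressionEncoder` can be built without an implicit `TypeTag`. A standalone sketch of just that step, using the same `scala.reflect` calls as the hunk above (the wrapper object and method name are hypothetical):

```scala
import java.io.Serializable

import scala.reflect.runtime.universe

import org.apache.spark.sql.expressions.Aggregator

object AggregatorInputType {
  // Returns the scala-reflect Type bound to the IN parameter of a concrete
  // Aggregator[IN, BUF, OUT] subclass, given only its runtime Class.
  def inputTypeOf(clazz: Class[_]): universe.Type = {
    val mirror = universe.runtimeMirror(clazz.getClassLoader)
    // View the concrete class as an instance of the Aggregator base class
    // and read off its first type argument.
    val classType = mirror.classSymbol(clazz)
    val baseClass = universe.typeOf[Aggregator[Serializable, Any, Any]].typeSymbol.asClass
    universe.internal.thisType(classType).baseType(baseClass).typeArgs.head
  }
}
```

For `MyDoubleAverage` in the test file below, `inputTypeOf(classOf[MyDoubleAverage])` yields `java.lang.Double`, from which `ScalaReflection.serializerForType` and `deserializerForType` produce the encoder.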
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
index 0374d98feb6e..8e8eb2e8793f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
@@ -17,5 +17,22 @@ CREATE FUNCTION udaf1 AS 'test.non.existent.udaf';
 
 SELECT default.udaf1(int_col1) as udaf1 from t1;
 
+CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage';
+
+SELECT default.myDoubleAverage(int_col1) as my_avg from t1;
+
+SELECT default.myDoubleAverage(int_col1, 3) as my_avg from t1;
+
+CREATE FUNCTION myDoubleAverage2 AS 'test.org.apache.spark.sql.MyDoubleAverage';
+
+SELECT default.myDoubleAverage2(int_col1) as my_avg from t1;
+
+CREATE FUNCTION MyDoubleSum AS 'org.apache.spark.sql.MyDoubleSum';
+
+SELECT default.MyDoubleSum(int_col1) as my_sum from t1;
+
 DROP FUNCTION myDoubleAvg;
 
 DROP FUNCTION udaf1;
+
+DROP FUNCTION myDoubleAverage;
+
+DROP FUNCTION myDoubleAverage2;
+
+DROP FUNCTION MyDoubleSum;
diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
index 9f4229a11b65..b36f3949a61f 100644
--- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 8
+-- Number of queries: 18
 
 
 -- !query
@@ -54,6 +54,65 @@ org.apache.spark.sql.AnalysisException
 Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7
 
 
+-- !query
+CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT default.myDoubleAverage(int_col1) as my_avg from t1
+-- !query schema
+struct<my_avg:double>
+-- !query output
+102.5
+
+
+-- !query
+SELECT default.myDoubleAverage(int_col1, 3) as my_avg from t1
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Invalid number of arguments for function default.myDoubleAverage. Expected: 1; Found: 2; line 1 pos 7
+
+
+-- !query
+CREATE FUNCTION myDoubleAverage2 AS 'test.org.apache.spark.sql.MyDoubleAverage'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT default.myDoubleAverage2(int_col1) as my_avg from t1
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+Can not load class 'test.org.apache.spark.sql.MyDoubleAverage' when registering the function 'default.myDoubleAverage2', please make sure it is on the classpath; line 1 pos 7
+
+
+-- !query
+CREATE FUNCTION MyDoubleSum AS 'org.apache.spark.sql.MyDoubleSum'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT default.MyDoubleSum(int_col1) as my_sum from t1
+-- !query schema
+struct<>
+-- !query output
+java.lang.RuntimeException
+Cannot register the aggregate function 'org.apache.spark.sql.MyDoubleSum' because it does not provide a parameterless constructor
+
+
 -- !query
 DROP FUNCTION myDoubleAvg
 -- !query schema
@@ -68,3 +127,27 @@ DROP FUNCTION udaf1
 struct<>
 -- !query output
 
+
+
+-- !query
+DROP FUNCTION myDoubleAverage
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP FUNCTION myDoubleAverage2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP FUNCTION MyDoubleSum
+-- !query schema
+struct<>
+-- !query output
+
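The golden file above can also be reproduced interactively. A rough Scala equivalent, assuming an active `spark` session, the test classes on the classpath, and a hypothetical function name `my_double_avg`:

```scala
// Illustrative only: the same scenario driven directly through SparkSession.sql.
spark.sql("CREATE FUNCTION my_double_avg AS 'org.apache.spark.sql.MyDoubleAverage'")
// MyDoubleAverage adds 100.0 to the average of its inputs,
// so this yields 102.5 for the values 1, 2, 3, 4.
spark.sql("SELECT default.my_double_avg(col1) FROM VALUES (1), (2), (3), (4) AS t(col1)").show()
```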
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index d100cad89fcc..e0e4dc51cfd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import java.lang.{Double => jlDouble}
 import java.math.BigDecimal
 import java.sql.Timestamp
 import java.time.{Instant, LocalDate}
@@ -50,6 +51,24 @@ private case class FunctionResult(f1: String, f2: String)
 private case class LocalDateInstantType(date: LocalDate, instant: Instant)
 private case class TimestampInstantType(t: Timestamp, instant: Instant)
 
+class MyDoubleAverage extends Aggregator[jlDouble, (Double, Long), jlDouble] {
+  def zero: (Double, Long) = (0.0, 0L)
+  def reduce(b: (Double, Long), a: jlDouble): (Double, Long) = {
+    if (a != null) (b._1 + a, b._2 + 1L) else b
+  }
+  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
+    (b1._1 + b2._1, b1._2 + b2._2)
+  def finish(r: (Double, Long)): jlDouble =
+    if (r._2 > 0L) 100.0 + (r._1 / r._2.toDouble) else null
+  def bufferEncoder: Encoder[(Double, Long)] =
+    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
+  def outputEncoder: Encoder[jlDouble] = Encoders.DOUBLE
+}
+
+class MyDoubleSum(test: Boolean) extends MyDoubleAverage {
+  override def finish(r: (Double, Long)): jlDouble = if (r._2 > 0L) r._1 else null
+}
+
 class UDFSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
 
@@ -848,6 +867,37 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-37018: Spark SQL should support create function with Aggregator") {
+    val avgFuncClass = "org.apache.spark.sql.MyDoubleAverage"
+    val avgFunction = "my_avg"
+    val sumFuncClass = "org.apache.spark.sql.MyDoubleSum"
+    val sumFunction = "my_sum"
+    withTempDatabase { dbName =>
+      withUserDefinedFunction(
+        s"default.$avgFunction" -> false,
+        s"default.$sumFunction" -> false,
+        s"$dbName.$avgFunction" -> false,
+        s"$dbName.$sumFunction" -> false,
+        avgFunction -> true,
+        sumFunction -> true) {
+        // create a function in the default database
+        sql("USE DEFAULT")
+        sql(s"CREATE FUNCTION $avgFunction AS '$avgFuncClass'")
+        sql(s"CREATE FUNCTION $sumFunction AS '$sumFuncClass'")
+        // create a view using a function in the 'default' database
+        withView("v1") {
+          sql(s"CREATE VIEW v1 AS SELECT $avgFunction(col1) AS func FROM VALUES (1), (2), (3)")
+          checkAnswer(sql("SELECT * FROM v1"), Seq(Row(102.0)))
+
+          val e = intercept[RuntimeException] {
+            sql(s"CREATE VIEW v2 AS SELECT $sumFunction(col1) AS func FROM VALUES (1), (2), (3)")
+          }
+          assert(e.getMessage.contains("does not provide a parameterless constructor"))
+        }
+      }
+    }
+  }
+
   test("SPARK-35674: using java.time.LocalDateTime in UDF") {
     // Regular case
     val input = Seq(java.time.LocalDateTime.parse("2021-01-01T00:00:00")).toDF("dateTime")
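For contrast with the SQL path added in this change, the same `Aggregator` has long been registrable through the typed API, where `functions.udaf` derives the input encoder from an implicit `TypeTag` instead of the runtime reflection in `SparkUDFExpressionBuilder`. A sketch, assuming an active `spark` session and `MyDoubleAverage` in scope (the function name `my_avg` is arbitrary):

```scala
import org.apache.spark.sql.functions.udaf

// Pre-existing typed registration path: the java.lang.Double input encoder
// comes from the implicit TypeTag, no constructor probing needed.
spark.udf.register("my_avg", udaf(new MyDoubleAverage))
// Expect 101.5: 100.0 plus the average of 1 and 2.
spark.sql("SELECT my_avg(col1) FROM VALUES (1), (2) AS t(col1)").show()
```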