@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions.Expression

// A builder to create `Expression` from function information.
trait FunctionExpressionBuilder {
// `name` and `clazz` are the name and the provided class of the user-defined function, respectively.
// `input` holds the child expressions of the resulting `ScalaUDAF` or `ScalaAggregator`.
def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression
}

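A minimal sketch (not part of this diff) of what an implementation of this trait can look like: it accepts only classes extending UserDefinedAggregateFunction and wraps them in ScalaUDAF. The class name is hypothetical; SparkUDFExpressionBuilder further down is the real implementation, which this PR extends to also handle Aggregator.

import org.apache.spark.sql.catalyst.catalog.FunctionExpressionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction

// Hypothetical builder: turns a registered UDAF class into a Catalyst aggregate expression.
class UdafOnlyExpressionBuilder extends FunctionExpressionBuilder {
  override def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression = {
    require(classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz),
      s"$name: ${clazz.getCanonicalName} is not a UserDefinedAggregateFunction")
    // Instantiate the user-provided class and wrap it so the planner can treat it as an aggregate.
    val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction]
    ScalaUDAF(input, udaf)
  }
}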
@@ -1920,4 +1920,9 @@ object QueryExecutionErrors {
s". To solve this try to set $maxDynamicPartitionsKey" +
s" to at least $numWrittenParts.")
}

def registerFunctionWithoutParameterlessConstructorError(className: String): Throwable = {
new RuntimeException(s"Registering an aggregate function with '$className' which does not " +
"provide a parameterless constructor is not supported")
}
}
@@ -16,25 +16,33 @@
*/
package org.apache.spark.sql.internal

import java.io.Serializable

import scala.reflect.ClassTag
import scala.reflect.runtime.universe

import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF}
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog}
import org.apache.spark.sql.execution.streaming.ResolveWriteToStream
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager

@@ -411,6 +419,38 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder {
name, expr.inputTypes.size.toString, input.size)
}
expr
} else if (classOf[Aggregator[_, _, _]].isAssignableFrom(clazz)) {
val noParameterConstructor = clazz.getConstructors.find(_.getParameterCount == 0)
if (noParameterConstructor.isEmpty) {
throw QueryExecutionErrors.registerFunctionWithoutParameterlessConstructorError(
clazz.getCanonicalName)
}
val aggregator =
noParameterConstructor.get.newInstance().asInstanceOf[Aggregator[Serializable, Any, Any]]

// Construct the input encoder
val mirror = universe.runtimeMirror(clazz.getClassLoader)
val classType = mirror.classSymbol(clazz)
val baseClassType = universe.typeOf[Aggregator[Serializable, Any, Any]].typeSymbol.asClass
val baseType = universe.internal.thisType(classType).baseType(baseClassType)
val tpe = baseType.typeArgs.head
val serializer = ScalaReflection.serializerForType(tpe)
val deserializer = ScalaReflection.deserializerForType(tpe)
val cls = mirror.runtimeClass(tpe)
val inputEncoder =
new ExpressionEncoder[Serializable](serializer, deserializer, ClassTag(cls))

val udf: UserDefinedFunction = udaf[Serializable, Any, Any](aggregator, inputEncoder)
assert(udf.isInstanceOf[UserDefinedAggregator[_, _, _]])
val udfAgg: UserDefinedAggregator[_, _, _] = udf.asInstanceOf[UserDefinedAggregator[_, _, _]]

val expr = udfAgg.scalaAggregator(input)
// Check input argument size
if (expr.inputTypes.size != input.size) {
throw QueryCompilationErrors.invalidFunctionArgumentsError(
name, expr.inputTypes.size.toString, input.size)
}
expr
} else {
throw QueryCompilationErrors.noHandlerForUDAFError(clazz.getCanonicalName)
}
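For reference, a minimal sketch (not part of this diff) of the existing programmatic path that this builder mirrors: functions.udaf wraps an Aggregator into a UserDefinedAggregator, with the input encoder supplied by the caller instead of being derived via runtime reflection as above. The aggregator and names below are illustrative only.

import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

// Illustrative sum-of-longs Aggregator.
object LongSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(r: Long): Long = r
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object RegisterAggregatorExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udaf-example").getOrCreate()
    // The caller supplies the input encoder here; the SQL path above derives it reflectively.
    spark.udf.register("long_sum", functions.udaf(LongSum, Encoders.scalaLong))
    spark.sql("SELECT long_sum(col1) FROM VALUES (1L), (2L), (3L)").show()
    spark.stop()
  }
}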
17 changes: 17 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/udaf.sql
@@ -17,5 +17,22 @@ CREATE FUNCTION udaf1 AS 'test.non.existent.udaf';

SELECT default.udaf1(int_col1) as udaf1 from t1;

CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage';

SELECT default.myDoubleAverage(int_col1) as my_avg from t1;

SELECT default.myDoubleAverage(int_col1, 3) as my_avg from t1;

CREATE FUNCTION myDoubleAverage2 AS 'test.org.apache.spark.sql.MyDoubleAverage';

SELECT default.myDoubleAverage2(int_col1) as my_avg from t1;

CREATE FUNCTION MyDoubleSum AS 'org.apache.spark.sql.MyDoubleSum';

SELECT default.MyDoubleSum(int_col1) as my_sum from t1;

DROP FUNCTION myDoubleAvg;
DROP FUNCTION udaf1;
DROP FUNCTION myDoubleAverage;
DROP FUNCTION myDoubleAverage2;
DROP FUNCTION MyDoubleSum;
85 changes: 84 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/udaf.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 8
-- Number of queries: 18


-- !query
@@ -54,6 +54,65 @@ org.apache.spark.sql.AnalysisException
Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7


-- !query
CREATE FUNCTION myDoubleAverage AS 'org.apache.spark.sql.MyDoubleAverage'
-- !query schema
struct<>
-- !query output



-- !query
SELECT default.myDoubleAverage(int_col1) as my_avg from t1
-- !query schema
struct<my_avg:double>
-- !query output
102.5


-- !query
SELECT default.myDoubleAverage(int_col1, 3) as my_avg from t1
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
Invalid number of arguments for function default.myDoubleAverage. Expected: 1; Found: 2; line 1 pos 7


-- !query
CREATE FUNCTION myDoubleAverage2 AS 'test.org.apache.spark.sql.MyDoubleAverage'
-- !query schema
struct<>
-- !query output



-- !query
SELECT default.myDoubleAverage2(int_col1) as my_avg from t1
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
Can not load class 'test.org.apache.spark.sql.MyDoubleAverage' when registering the function 'default.myDoubleAverage2', please make sure it is on the classpath; line 1 pos 7


-- !query
CREATE FUNCTION MyDoubleSum AS 'org.apache.spark.sql.MyDoubleSum'
-- !query schema
struct<>
-- !query output



-- !query
SELECT default.MyDoubleSum(int_col1) as my_sum from t1
-- !query schema
struct<>
-- !query output
java.lang.RuntimeException
Registering an aggregate function with 'org.apache.spark.sql.MyDoubleSum' which does not provide a parameterless constructor is not supported


-- !query
DROP FUNCTION myDoubleAvg
-- !query schema
@@ -68,3 +127,27 @@ DROP FUNCTION udaf1
struct<>
-- !query output



-- !query
DROP FUNCTION myDoubleAverage
-- !query schema
struct<>
-- !query output



-- !query
DROP FUNCTION myDoubleAverage2
-- !query schema
struct<>
-- !query output



-- !query
DROP FUNCTION MyDoubleSum
-- !query schema
struct<>
-- !query output

50 changes: 50 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql

import java.lang.{Double => jlDouble}
import java.math.BigDecimal
import java.sql.Timestamp
import java.time.{Instant, LocalDate}
@@ -50,6 +51,24 @@ private case class FunctionResult(f1: String, f2: String)
private case class LocalDateInstantType(date: LocalDate, instant: Instant)
private case class TimestampInstantType(t: Timestamp, instant: Instant)

class MyDoubleAverage extends Aggregator[jlDouble, (Double, Long), jlDouble] {
Contributor:

There are a few more things I'd like to test:

  1. What if the class has type parameters? e.g. class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long]. What happens if we register it as a SQL UDAF?
  2. Let's test negative cases, e.g., wrong number of parameters, wrong parameter types.

Contributor Author:

  1. The type parameter will be IN; we need to find another way (see the sketch after this file's diff).

def zero: (Double, Long) = (0.0, 0L)
def reduce(b: (Double, Long), a: jlDouble): (Double, Long) = {
if (a != null) (b._1 + a, b._2 + 1L) else b
}
def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
(b1._1 + b2._1, b1._2 + b2._2)
def finish(r: (Double, Long)): jlDouble =
if (r._2 > 0L) 100.0 + (r._1 / r._2.toDouble) else null
def bufferEncoder: Encoder[(Double, Long)] =
Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
def outputEncoder: Encoder[jlDouble] = Encoders.DOUBLE
}

class MyDoubleSum(test: Boolean) extends MyDoubleAverage {
override def finish(r: (Double, Long)): jlDouble = if (r._2 > 0L) r._1 else null
}

class UDFSuite extends QueryTest with SharedSparkSession {
import testImplicits._

@@ -848,6 +867,37 @@ class UDFSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-37018: Spark SQL should support create function with Aggregator") {
val avgFuncClass = "org.apache.spark.sql.MyDoubleAverage"
val avgFunction = "my_avg"
val sumFuncClass = "org.apache.spark.sql.MyDoubleSum"
val sumFunction = "my_sum"
withTempDatabase { dbName =>
withUserDefinedFunction(
s"default.$avgFunction" -> false,
s"default.$sumFunction" -> false,
s"$dbName.$avgFunction" -> false,
s"$dbName.$sumFunction" -> false,
avgFunction -> true,
sumFunction -> true) {
// create a function in default database
sql("USE DEFAULT")
sql(s"CREATE FUNCTION $avgFunction AS '$avgFuncClass'")
sql(s"CREATE FUNCTION $sumFunction AS '$sumFuncClass'")
// create a view using a function in 'default' database
withView("v1") {
sql(s"CREATE VIEW v1 AS SELECT $avgFunction(col1) AS func FROM VALUES (1), (2), (3)")
checkAnswer(sql("SELECT * FROM v1"), Seq(Row(102.0)))

val e = intercept[RuntimeException] {
sql(s"CREATE VIEW v2 AS SELECT $sumFunction(col1) AS func FROM VALUES (1), (2), (3)")
}
assert(e.getMessage.contains("does not provide a parameterless constructor is not supported"))
}
}
}
}

test("SPARK-35674: using java.time.LocalDateTime in UDF") {
// Regular case
val input = Seq(java.time.LocalDateTime.parse("2021-01-01T00:00:00")).toDF("dateTime")
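Picking up the review thread above about type parameters: a minimal sketch, assuming the reflection logic of SparkUDFExpressionBuilder shown earlier, of why a generic Aggregator is problematic for CREATE FUNCTION. TypedCount is hypothetical and not part of this PR.

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical generic Aggregator from the review discussion.
class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: IN): Long = if (f(a) != null) b + 1L else b
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(r: Long): Long = r
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// If TypedCount were registered via CREATE FUNCTION, the builder would reject it twice over:
// it has no parameterless constructor (triggering the new error above), and even with one the
// first Aggregator type argument would resolve to the abstract type IN, so ScalaReflection
// could not derive an input encoder for it.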