From 5f7a0e2f0444bd030de12d51c1b5e961404f7ac7 Mon Sep 17 00:00:00 2001 From: can wang Date: Sun, 22 Aug 2021 15:54:55 +0800 Subject: [PATCH 01/18] copy code from https://github.com/JetBrains/kotlin-spark-api/pull/67 --- .../sql/catalyst/encoders/RowEncoder.scala | 336 ++++ .../kotlinx/spark/api/SparkHelper.kt | 2 + .../kotlinx/spark/api/UDFRegister.kt | 1347 +++++++++++++++++ .../kotlinx/spark/api/UDFRegisterTest.kt | 124 ++ 4 files changed, 1809 insertions(+) create mode 100644 core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala create mode 100644 kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt create mode 100644 kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt diff --git a/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala new file mode 100644 index 00000000..8129db06 --- /dev/null +++ b/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -0,0 +1,336 @@ +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal +import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataTypeWithClass, Row} +import org.apache.spark.unsafe.types.UTF8String + +import scala.annotation.tailrec +import scala.collection.Map +import scala.reflect.ClassTag + +/** + * A factory for constructing encoders that convert external row to/from the Spark SQL + * internal binary representation. 
+ * + * The following is a mapping between Spark SQL types and its allowed external types: + * {{{ + * BooleanType -> java.lang.Boolean + * ByteType -> java.lang.Byte + * ShortType -> java.lang.Short + * IntegerType -> java.lang.Integer + * FloatType -> java.lang.Float + * DoubleType -> java.lang.Double + * StringType -> String + * DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal + * + * DateType -> java.sql.Date + * TimestampType -> java.sql.Timestamp + * + * BinaryType -> byte array + * ArrayType -> scala.collection.Seq or Array + * MapType -> scala.collection.Map + * StructType -> org.apache.spark.sql.Row + * }}} + */ +object RowEncoder { + def apply(schema: StructType): ExpressionEncoder[Row] = { + val cls = classOf[Row] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val updatedSchema = schema.copy( + fields = schema.fields.map(f => f.dataType match { + case kstw: DataTypeWithClass => f.copy(dataType = kstw.dt, nullable = kstw.nullable) + case _ => f + }) + ) + val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), updatedSchema) + val deserializer = deserializerFor(updatedSchema) + new ExpressionEncoder[Row]( + updatedSchema, + flat = false, + serializer.asInstanceOf[CreateNamedStruct].flatten, + deserializer, + ClassTag(cls)) + } + + private def serializerFor( + inputObject: Expression, + inputType: DataType): Expression = inputType match { + case dt if ScalaReflection.isNativeType(dt) => inputObject + + case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) + + case udt: UserDefinedType[_] => + val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) + val udtClass: Class[_] = if (annotation != null) { + annotation.udt() + } else { + UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { + throw new SparkException(s"${udt.userClass.getName} is not annotated with " + + "SQLUserDefinedType nor registered with UDTRegistration.}") + } + } + val obj = NewInstance( + udtClass, + Nil, + dataType = ObjectType(udtClass), false) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) + + case TimestampType => + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil, + returnNullable = false) + + case DateType => + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "fromJavaDate", + inputObject :: Nil, + returnNullable = false) + + case d: DecimalType => + CheckOverflow(StaticInvoke( + Decimal.getClass, + d, + "fromDecimal", + inputObject :: Nil, + returnNullable = false), d) + + case StringType => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + + case t@ArrayType(et, containsNull) => + et match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + StaticInvoke( + classOf[ArrayData], + t, + "toArrayData", + inputObject :: Nil, + returnNullable = false) + + case _ => MapObjects( + element => { + val value = serializerFor(ValidateExternalType(element, et), et) + if (!containsNull) { + AssertNotNull(value, Seq.empty) + } else { + value + } + }, + inputObject, + ObjectType(classOf[Object])) + } + + case t@MapType(kt, vt, valueNullable) => + val keys = + Invoke( + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) + val convertedKeys = 
serializerFor(keys, ArrayType(kt, false)) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) + val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) + + val nonNullOutput = NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = t, + propagateNull = false) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } + + case StructType(fields) => + val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => + val dataType = field.dataType + val fieldValue = serializerFor( + ValidateExternalType( + GetExternalRowField(inputObject, index, field.name), + dataType), + dataType) + val convertedField = if (field.nullable) { + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), + Literal.create(null, dataType), + fieldValue + ) + } else { + fieldValue + } + Literal(field.name) :: convertedField :: Nil + }) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } + } + + /** + * Returns the `DataType` that can be used when generating code that converts input data + * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned + * by this function can be more permissive since multiple external types may map to a single + * internal type. For example, for an input with DecimalType in external row, its external types + * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or + * `org.apache.spark.sql.types.Decimal`. + */ + def externalDataTypeForInput(dt: DataType): DataType = dt match { + case _ => dt match { + case dtwc: DataTypeWithClass => dtwc.dt + // In order to support both Decimal and java/scala BigDecimal in external row, we make this + // as java.lang.Object. + case _: DecimalType => ObjectType(classOf[Object]) + // In order to support both Array and Seq in external row, we make this as java.lang.Object. 
+ case _: ArrayType => ObjectType(classOf[Object]) + case _ => externalDataTypeFor(dt) + } + } + + @tailrec + def externalDataTypeFor(dt: DataType): DataType = dt match { + case kstw: DataTypeWithClass => externalDataTypeFor(kstw.dt) + case _ if ScalaReflection.isNativeType(dt) => dt + case TimestampType => ObjectType(classOf[java.sql.Timestamp]) + case DateType => ObjectType(classOf[java.sql.Date]) + case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) + case StringType => ObjectType(classOf[java.lang.String]) + case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) + case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) + case _: StructType => ObjectType(classOf[Row]) + case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) + case udt: UserDefinedType[_] => ObjectType(udt.userClass) + } + + private def deserializerFor(schema: StructType): Expression = { + val fields = schema.zipWithIndex.map { case (f, i) => + val dt = f.dataType match { + case p: PythonUserDefinedType => p.sqlType + case other => other + } + deserializerFor(GetColumnByOrdinal(i, dt)) + } + + CreateExternalRow(fields, schema) + } + + private def deserializerFor(input: Expression): Expression = { + deserializerFor(input, input.dataType) + } + + private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { + case dt if ScalaReflection.isNativeType(dt) => input + case kstw: DataTypeWithClass => deserializerFor(input, kstw.dt) + + case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) + + case udt: UserDefinedType[_] => + val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) + val udtClass: Class[_] = if (annotation != null) { + annotation.udt() + } else { + UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { + throw new SparkException(s"${udt.userClass.getName} is not annotated with " + + "SQLUserDefinedType nor registered with UDTRegistration.}") + } + } + val obj = NewInstance( + udtClass, + Nil, + dataType = ObjectType(udtClass)) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) + + case TimestampType => + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + input :: Nil, + returnNullable = false) + + case DateType => + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + input :: Nil, + returnNullable = false) + + case _: DecimalType => + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) + + case StringType => + Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) + + case ArrayType(et, nullable) => + val arrayData = + Invoke( + MapObjects(deserializerFor(_), input, et), + "array", + ObjectType(classOf[Array[_]]), returnNullable = false) + StaticInvoke( + scala.collection.mutable.WrappedArray.getClass, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil, + returnNullable = false) + + case MapType(kt, vt, valueNullable) => + val keyArrayType = ArrayType(kt, false) + val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) + + val valueArrayType = ArrayType(vt, valueNullable) + val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType)) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil, + returnNullable = false) + + case schema@StructType(fields) => + val convertedFields = fields.zipWithIndex.map 
{ case (f, i) =>
+        If(
+          Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
+          Literal.create(null, externalDataTypeFor(f.dataType match {
+            case kstw: DataTypeWithClass => kstw.dt
+            case o => o
+          })),
+          deserializerFor(GetStructField(input, i)))
+      }
+      If(IsNull(input),
+        Literal.create(null, externalDataTypeFor(input.dataType)),
+        CreateExternalRow(convertedFields, schema))
+  }
+}
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
index d5d6aa2c..3ef0b177 100644
--- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
@@ -20,6 +20,7 @@
 package org.jetbrains.kotlinx.spark.api
 
 import org.apache.spark.sql.SparkSession.Builder
+import org.apache.spark.sql.UDFRegistration
 import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
 
 /**
@@ -78,4 +79,5 @@ inline class KSparkSession(val spark: SparkSession) {
     inline fun <reified T> List<T>.toDS() = toDS(spark)
     inline fun <reified T> Array<T>.toDS() = spark.dsOf(*this)
     inline fun <reified T> dsOf(vararg arg: T) = spark.dsOf(*arg)
+    val udf: UDFRegistration get() = spark.udf()
 }
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
new file mode 100644
index 00000000..f091491a
--- /dev/null
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
@@ -0,0 +1,1347 @@
+@file:Suppress("DuplicatedCode")
+
+package org.jetbrains.kotlinx.spark.api
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.UDFRegistration
+import org.apache.spark.sql.api.java.*
+import org.apache.spark.sql.functions
+import scala.collection.mutable.WrappedArray
+import kotlin.reflect.KClass
+import kotlin.reflect.full.isSubclassOf
+import kotlin.reflect.typeOf
+
+/**
+ * Checks if [this] is of a valid type for a UDF, otherwise it throws a [TypeOfUDFParameterNotSupportedException]
+ */
+@PublishedApi
+internal fun KClass<*>.checkForValidType(parameterName: String) {
+    if (this == String::class || isSubclassOf(WrappedArray::class)) return // Most of the time we need strings or WrappedArrays
+    if (isSubclassOf(Iterable::class) || java.isArray
+        || isSubclassOf(Map::class) || isSubclassOf(Array::class)
+        || isSubclassOf(ByteArray::class) || isSubclassOf(CharArray::class)
+        || isSubclassOf(ShortArray::class) || isSubclassOf(IntArray::class)
+        || isSubclassOf(LongArray::class) || isSubclassOf(FloatArray::class)
+        || isSubclassOf(DoubleArray::class) || isSubclassOf(BooleanArray::class)
+    ) {
+        throw TypeOfUDFParameterNotSupportedException(this, parameterName)
+    }
+}
+
+/**
+ * An exception thrown when the UDF is generated with illegal types for the parameters
+ */
+class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName: String) : IllegalArgumentException(
+    "Parameter $parameterName is a subclass of ${kClass.qualifiedName}. If you need to process an array use ${WrappedArray::class.qualifiedName}."
+)
+
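+/*
+ * A minimal usage sketch, mirroring [UDFRegisterTest] below (the UDF name, the view name,
+ * and the `asIterable` bridge from scala's Iterable are illustrative and not part of this file):
+ *
+ *     withSpark {
+ *         udf.register<WrappedArray<String>, String>("stringArrayMerger") {
+ *             it.asIterable().joinToString(" ")
+ *         }
+ *         spark.sql("select stringArrayMerger(textArray) from sentences").show()
+ *     }
+ */
+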
+/**
+ * A wrapper for a UDF with 0 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper0(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(): Column {
+        return functions.callUDF(udfName)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified R> UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 {
+    register(name, UDF0(func), schema(typeOf<R>()))
+    return UDFWrapper0(name)
+}
+
+/**
+ * A wrapper for a UDF with 1 argument.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper1(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(param0: Column): Column {
+        return functions.callUDF(udfName, param0)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified R> UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 {
+    T0::class.checkForValidType("T0")
+    register(name, UDF1(func), schema(typeOf<R>()))
+    return UDFWrapper1(name)
+}
+
+/**
+ * A wrapper for a UDF with 2 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper2(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(param0: Column, param1: Column): Column {
+        return functions.callUDF(udfName, param0, param1)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1) -> R
+): UDFWrapper2 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    register(name, UDF2(func), schema(typeOf<R>()))
+    return UDFWrapper2(name)
+}
+
+/**
+ * A wrapper for a UDF with 3 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper3(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(param0: Column, param1: Column, param2: Column): Column {
+        return functions.callUDF(udfName, param0, param1, param2)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2) -> R
+): UDFWrapper3 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    register(name, UDF3(func), schema(typeOf<R>()))
+    return UDFWrapper3(name)
+}
+
+/**
+ * A wrapper for a UDF with 4 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper4(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column): Column {
+        return functions.callUDF(udfName, param0, param1, param2, param3)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3) -> R
+): UDFWrapper4 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    register(name, UDF4(func), schema(typeOf<R>()))
+    return UDFWrapper4(name)
+}
+
+/**
+ * A wrapper for a UDF with 5 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper5(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column): Column {
+        return functions.callUDF(udfName, param0, param1, param2, param3, param4)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4) -> R
+): UDFWrapper5 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    register(name, UDF5(func), schema(typeOf<R>()))
+    return UDFWrapper5(name)
+}
+
+/**
+ * A wrapper for a UDF with 6 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper6(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column
+    ): Column {
+        return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5) -> R
+): UDFWrapper6 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    register(name, UDF6(func), schema(typeOf<R>()))
+    return UDFWrapper6(name)
+}
+
+/**
+ * A wrapper for a UDF with 7 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper7(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column
+    ): Column {
+        return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6) -> R
+): UDFWrapper7 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    register(name, UDF7(func), schema(typeOf<R>()))
+    return UDFWrapper7(name)
+}
+
+/**
+ * A wrapper for a UDF with 8 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper8(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column
+    ): Column {
+        return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7) -> R
+): UDFWrapper8 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    register(name, UDF8(func), schema(typeOf<R>()))
+    return UDFWrapper8(name)
+}
+
+/**
+ * A wrapper for a UDF with 9 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper9(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column
+    ): Column {
+        return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8) -> R
+): UDFWrapper9 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    register(name, UDF9(func), schema(typeOf<R>()))
+    return UDFWrapper9(name)
+}
+
+/**
+ * A wrapper for a UDF with 10 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper10(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) -> R
+): UDFWrapper10 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    register(name, UDF10(func), schema(typeOf<R>()))
+    return UDFWrapper10(name)
+}
+
+/**
+ * A wrapper for a UDF with 11 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper11(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) -> R
+): UDFWrapper11 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    register(name, UDF11(func), schema(typeOf<R>()))
+    return UDFWrapper11(name)
+}
+
+/**
+ * A wrapper for a UDF with 12 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper12(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) -> R
+): UDFWrapper12 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    register(name, UDF12(func), schema(typeOf<R>()))
+    return UDFWrapper12(name)
+}
+
+/**
+ * A wrapper for a UDF with 13 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper13(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) -> R
+): UDFWrapper13 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    register(name, UDF13(func), schema(typeOf<R>()))
+    return UDFWrapper13(name)
+}
+
+/**
+ * A wrapper for a UDF with 14 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper14(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) -> R
+): UDFWrapper14 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    register(name, UDF14(func), schema(typeOf<R>()))
+    return UDFWrapper14(name)
+}
+
+/**
+ * A wrapper for a UDF with 15 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper15(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) -> R
+): UDFWrapper15 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    register(name, UDF15(func), schema(typeOf<R>()))
+    return UDFWrapper15(name)
+}
+
+/**
+ * A wrapper for a UDF with 16 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper16(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14,
+            param15
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15) -> R
+): UDFWrapper16 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    register(name, UDF16(func), schema(typeOf<R>()))
+    return UDFWrapper16(name)
+}
+
+/**
+ * A wrapper for a UDF with 17 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper17(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14,
+            param15,
+            param16
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> R
+): UDFWrapper17 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    register(name, UDF17(func), schema(typeOf<R>()))
+    return UDFWrapper17(name)
+}
+
+/**
+ * A wrapper for a UDF with 18 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper18(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14,
+            param15,
+            param16,
+            param17
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17) -> R
+): UDFWrapper18 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    register(name, UDF18(func), schema(typeOf<R>()))
+    return UDFWrapper18(name)
+}
+
+/**
+ * A wrapper for a UDF with 19 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper19(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14,
+            param15,
+            param16,
+            param17,
+            param18
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18) -> R
+): UDFWrapper19 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    register(name, UDF19(func), schema(typeOf<R>()))
+    return UDFWrapper19(name)
+}
+
+/**
+ * A wrapper for a UDF with 20 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper20(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14,
+            param15,
+            param16,
+            param17,
+            param18,
+            param19
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19) -> R
+): UDFWrapper20 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    register(name, UDF20(func), schema(typeOf<R>()))
+    return UDFWrapper20(name)
+}
+
+/**
+ * A wrapper for a UDF with 21 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper21(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column,
+        param20: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14,
+            param15,
+            param16,
+            param17,
+            param18,
+            param19,
+            param20
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified T20, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20) -> R
+): UDFWrapper21 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    T20::class.checkForValidType("T20")
+    register(name, UDF21(func), schema(typeOf<R>()))
+    return UDFWrapper21(name)
+}
+
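+// Spark's Java UDF interfaces ([UDF0] through [UDF22]) stop at 22 parameters, the same
+// arity ceiling as Scala's function types, so the 22-argument wrapper below is the last one.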
+/**
+ * A wrapper for a UDF with 22 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper22(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column,
+        param20: Column,
+        param21: Column
+    ): Column {
+        return functions.callUDF(
+            udfName,
+            param0,
+            param1,
+            param2,
+            param3,
+            param4,
+            param5,
+            param6,
+            param7,
+            param8,
+            param9,
+            param10,
+            param11,
+            param12,
+            param13,
+            param14,
+            param15,
+            param16,
+            param17,
+            param18,
+            param19,
+            param20,
+            param21
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified T20, reified T21, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21) -> R
+): UDFWrapper22 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    T20::class.checkForValidType("T20")
+    T21::class.checkForValidType("T21")
+    register(name, UDF22(func), schema(typeOf<R>()))
+    return UDFWrapper22(name)
+}
\ No newline at end of file
diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
new file mode 100644
index 00000000..0d2e8759
--- /dev/null
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -0,0 +1,124 @@
+package org.jetbrains.kotlinx.spark.api
+
+import io.kotest.core.spec.style.ShouldSpec
+import org.apache.spark.sql.RowFactory
+import org.apache.spark.sql.types.DataTypes
+import org.junit.jupiter.api.assertThrows
+import scala.collection.JavaConversions
+import scala.collection.mutable.WrappedArray
+
+private fun <T> scala.collection.Iterable<T>.asIterable(): Iterable<T> = JavaConversions.asJavaIterable(this)
+
+class UDFRegisterTest : ShouldSpec({
+    context("org.jetbrains.kotlinx.spark.api.UDFRegister") {
+        context("the function checkForValidType") {
+            val invalidTypes = listOf(
+                Array::class,
+                Iterable::class,
+                List::class,
+                MutableList::class,
+                ByteArray::class,
+                CharArray::class,
+                ShortArray::class,
+                IntArray::class,
+                LongArray::class,
+                FloatArray::class,
+                DoubleArray::class,
+                BooleanArray::class,
+                Map::class,
+                MutableMap::class,
+                Set::class,
+                MutableSet::class,
+                arrayOf("")::class,
+                listOf("")::class,
+                setOf("")::class,
+                mapOf("" to "")::class,
+                mutableListOf("")::class,
+                mutableSetOf("")::class,
+                mutableMapOf("" to "")::class,
+            )
+            invalidTypes.forEachIndexed { index, invalidType ->
+                should("$index: throw an ${TypeOfUDFParameterNotSupportedException::class.simpleName} when encountering ${invalidType.qualifiedName}") {
+                    assertThrows<TypeOfUDFParameterNotSupportedException> {
+                        invalidType.checkForValidType("test")
+                    }
+                }
+            }
+        }
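+        // Rationale (sketch): Spark hands an ArrayType column to a JVM UDF as a
+        // scala.collection.mutable.WrappedArray, never as a kotlin.Array or a
+        // kotlin.collections type, which is why every type above is rejected up front.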
+
+        context("the register-function") {
+            withSpark {
+
+                should("fail when using a simple kotlin.Array") {
+                    assertThrows<TypeOfUDFParameterNotSupportedException> {
+                        udf.register("shouldFail") { array: Array<String> ->
+                            array.joinToString(" ")
+                        }
+                    }
+                }
+
+                should("succeed when using a WrappedArray") {
+                    udf.register("shouldSucceed") { array: WrappedArray<String> ->
+                        array.asIterable().joinToString(" ")
+                    }
+                }
+            }
+        }
+
+        context("calling the UDF-Wrapper") {
+            withSpark(logLevel = SparkLogLevel.DEBUG) {
+                should("succeed when using the right number of arguments") {
+                    val schema = DataTypes.createStructType(
+                        listOf(
+                            DataTypes.createStructField(
+                                "textArray",
+                                DataTypes.createArrayType(DataTypes.StringType),
+                                false
+                            ),
+                            DataTypes.createStructField("id", DataTypes.StringType, false)
+                        )
+                    )
+
+                    val rows = listOf(
+                        RowFactory.create(arrayOf("a", "b", "c"), "1"),
+                        RowFactory.create(arrayOf("d", "e", "f"), "2"),
+                        RowFactory.create(arrayOf("g", "h", "i"), "3"),
+                    )
+
+                    val testData = spark.createDataFrame(rows, schema)
+
+                    val stringArrayMerger = udf.register<WrappedArray<String>, String>("stringArrayMerger") {
+                        it.asIterable().joinToString(" ")
+                    }
+
+                    val newData = testData.withColumn("text", stringArrayMerger(testData.col("textArray")))
+
+                    newData.select("text").collectAsList().zip(newData.select("textArray").collectAsList())
+                        .forEach { (text, textArray) ->
+                            assert(text.getString(0) == textArray.getList<String>(0).joinToString(" "))
+                        }
+                }
+//                should("also work with datasets") {
+//                    val ds = listOf("a" to 1, "b" to 2).toDS()
+//                    val stringIntDiff = udf.register<String, Int, Int>("stringIntDiff") { a, b ->
+//                        c(a[0].toInt() - b)
+//                    }
+//                    val lst = ds.withColumn("new", stringIntDiff(ds.col("first"), ds.col("second")))
+//                        .select("new")
+//                        .collectAsList()
+//
+////                    val result = spark.sql("select stringIntDiff(first, second) from test1").`as`<Int>().collectAsList()
+////                    expect(result).asExpect().contains.inOrder.only.values(96, 96)
+//                }
+                should("also work with datasets") {
+                    listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
+                    udf.register<String, Int, Int>("stringIntDiff") { a, b ->
+                        a[0].toInt() - b
+                    }
+                    spark.sql("select stringIntDiff(first, second) from test1").show()
+                }
+            }
+        }
+    }
+})
\ No newline at end of file

From 2c161cd4b2a0951d2afdd6f1e340d8490fabd872 Mon Sep 17 00:00:00 2001
From: can wang
Date: Mon, 23 Aug 2021 23:44:22 +0800
Subject: [PATCH 02/18] remove hacked RowEncoder.scala

---
 .../sql/catalyst/encoders/RowEncoder.scala    | 336 ------------------
 1 file changed, 336 deletions(-)
 delete mode 100644 core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

diff --git a/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
deleted file mode 100644
index 8129db06..00000000
--- a/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ /dev/null
@@ -1,336 +0,0 @@
-package org.apache.spark.sql.catalyst.encoders
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
-import 
org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataTypeWithClass, Row} -import org.apache.spark.unsafe.types.UTF8String - -import scala.annotation.tailrec -import scala.collection.Map -import scala.reflect.ClassTag - -/** - * A factory for constructing encoders that convert external row to/from the Spark SQL - * internal binary representation. - * - * The following is a mapping between Spark SQL types and its allowed external types: - * {{{ - * BooleanType -> java.lang.Boolean - * ByteType -> java.lang.Byte - * ShortType -> java.lang.Short - * IntegerType -> java.lang.Integer - * FloatType -> java.lang.Float - * DoubleType -> java.lang.Double - * StringType -> String - * DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal - * - * DateType -> java.sql.Date - * TimestampType -> java.sql.Timestamp - * - * BinaryType -> byte array - * ArrayType -> scala.collection.Seq or Array - * MapType -> scala.collection.Map - * StructType -> org.apache.spark.sql.Row - * }}} - */ -object RowEncoder { - def apply(schema: StructType): ExpressionEncoder[Row] = { - val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val updatedSchema = schema.copy( - fields = schema.fields.map(f => f.dataType match { - case kstw: DataTypeWithClass => f.copy(dataType = kstw.dt, nullable = kstw.nullable) - case _ => f - }) - ) - val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), updatedSchema) - val deserializer = deserializerFor(updatedSchema) - new ExpressionEncoder[Row]( - updatedSchema, - flat = false, - serializer.asInstanceOf[CreateNamedStruct].flatten, - deserializer, - ClassTag(cls)) - } - - private def serializerFor( - inputObject: Expression, - inputType: DataType): Expression = inputType match { - case dt if ScalaReflection.isNativeType(dt) => inputObject - - case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) - - case udt: UserDefinedType[_] => - val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) - val udtClass: Class[_] = if (annotation != null) { - annotation.udt() - } else { - UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { - throw new SparkException(s"${udt.userClass.getName} is not annotated with " + - "SQLUserDefinedType nor registered with UDTRegistration.}") - } - } - val obj = NewInstance( - udtClass, - Nil, - dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) - - case TimestampType => - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil, - returnNullable = false) - - case DateType => - StaticInvoke( - DateTimeUtils.getClass, - DateType, - "fromJavaDate", - inputObject :: Nil, - returnNullable = false) - - case d: DecimalType => - CheckOverflow(StaticInvoke( - Decimal.getClass, - d, - "fromDecimal", - inputObject :: Nil, - returnNullable = false), d) - - case StringType => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil, - returnNullable = false) - - case t@ArrayType(et, containsNull) => - et match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => - StaticInvoke( - classOf[ArrayData], - t, - "toArrayData", - inputObject :: Nil, - returnNullable = false) - - case _ => MapObjects( - element => { - val value = serializerFor(ValidateExternalType(element, et), et) - if (!containsNull) { - AssertNotNull(value, Seq.empty) - } else { - value - } - }, 
- inputObject, - ObjectType(classOf[Object])) - } - - case t@MapType(kt, vt, valueNullable) => - val keys = - Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), - returnNullable = false), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) - val convertedKeys = serializerFor(keys, ArrayType(kt, false)) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), - returnNullable = false), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) - val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) - - val nonNullOutput = NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = t, - propagateNull = false) - - if (inputObject.nullable) { - If(IsNull(inputObject), - Literal.create(null, inputType), - nonNullOutput) - } else { - nonNullOutput - } - - case StructType(fields) => - val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => - val dataType = field.dataType - val fieldValue = serializerFor( - ValidateExternalType( - GetExternalRowField(inputObject, index, field.name), - dataType), - dataType) - val convertedField = if (field.nullable) { - If( - Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), - Literal.create(null, dataType), - fieldValue - ) - } else { - fieldValue - } - Literal(field.name) :: convertedField :: Nil - }) - - if (inputObject.nullable) { - If(IsNull(inputObject), - Literal.create(null, inputType), - nonNullOutput) - } else { - nonNullOutput - } - } - - /** - * Returns the `DataType` that can be used when generating code that converts input data - * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned - * by this function can be more permissive since multiple external types may map to a single - * internal type. For example, for an input with DecimalType in external row, its external types - * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or - * `org.apache.spark.sql.types.Decimal`. - */ - def externalDataTypeForInput(dt: DataType): DataType = dt match { - case _ => dt match { - case dtwc: DataTypeWithClass => dtwc.dt - // In order to support both Decimal and java/scala BigDecimal in external row, we make this - // as java.lang.Object. - case _: DecimalType => ObjectType(classOf[Object]) - // In order to support both Array and Seq in external row, we make this as java.lang.Object. 
- case _: ArrayType => ObjectType(classOf[Object]) - case _ => externalDataTypeFor(dt) - } - } - - @tailrec - def externalDataTypeFor(dt: DataType): DataType = dt match { - case kstw: DataTypeWithClass => externalDataTypeFor(kstw.dt) - case _ if ScalaReflection.isNativeType(dt) => dt - case TimestampType => ObjectType(classOf[java.sql.Timestamp]) - case DateType => ObjectType(classOf[java.sql.Date]) - case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) - case StringType => ObjectType(classOf[java.lang.String]) - case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) - case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) - case _: StructType => ObjectType(classOf[Row]) - case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) - case udt: UserDefinedType[_] => ObjectType(udt.userClass) - } - - private def deserializerFor(schema: StructType): Expression = { - val fields = schema.zipWithIndex.map { case (f, i) => - val dt = f.dataType match { - case p: PythonUserDefinedType => p.sqlType - case other => other - } - deserializerFor(GetColumnByOrdinal(i, dt)) - } - - CreateExternalRow(fields, schema) - } - - private def deserializerFor(input: Expression): Expression = { - deserializerFor(input, input.dataType) - } - - private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { - case dt if ScalaReflection.isNativeType(dt) => input - case kstw: DataTypeWithClass => deserializerFor(input, kstw.dt) - - case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) - - case udt: UserDefinedType[_] => - val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) - val udtClass: Class[_] = if (annotation != null) { - annotation.udt() - } else { - UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { - throw new SparkException(s"${udt.userClass.getName} is not annotated with " + - "SQLUserDefinedType nor registered with UDTRegistration.}") - } - } - val obj = NewInstance( - udtClass, - Nil, - dataType = ObjectType(udtClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) - - case TimestampType => - StaticInvoke( - DateTimeUtils.getClass, - ObjectType(classOf[java.sql.Timestamp]), - "toJavaTimestamp", - input :: Nil, - returnNullable = false) - - case DateType => - StaticInvoke( - DateTimeUtils.getClass, - ObjectType(classOf[java.sql.Date]), - "toJavaDate", - input :: Nil, - returnNullable = false) - - case _: DecimalType => - Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), - returnNullable = false) - - case StringType => - Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) - - case ArrayType(et, nullable) => - val arrayData = - Invoke( - MapObjects(deserializerFor(_), input, et), - "array", - ObjectType(classOf[Array[_]]), returnNullable = false) - StaticInvoke( - scala.collection.mutable.WrappedArray.getClass, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil, - returnNullable = false) - - case MapType(kt, vt, valueNullable) => - val keyArrayType = ArrayType(kt, false) - val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) - - val valueArrayType = ArrayType(vt, valueNullable) - val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType)) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil, - returnNullable = false) - - case schema@StructType(fields) => - val convertedFields = fields.zipWithIndex.map 
{ case (f, i) => - If( - Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), - Literal.create(null, externalDataTypeFor(f.dataType match { - case kstw: DataTypeWithClass => kstw.dt - case o => o - })), - deserializerFor(GetStructField(input, i))) - } - If(IsNull(input), - Literal.create(null, externalDataTypeFor(input.dataType)), - CreateExternalRow(convertedFields, schema)) - } -} From 8bbea9a280f4205489f4fd5ace98d3f559ef46f9 Mon Sep 17 00:00:00 2001 From: can wang Date: Mon, 23 Aug 2021 23:54:40 +0800 Subject: [PATCH 03/18] replace all schema(typeOf()) to DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json()) --- .../kotlinx/spark/api/UDFRegister.kt | 48 ++++++++++--------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt index f091491a..2b9e3e54 100644 --- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt +++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt @@ -3,9 +3,11 @@ package org.jetbrains.kotlinx.spark.api import org.apache.spark.sql.Column +import org.apache.spark.sql.DataTypeWithClass import org.apache.spark.sql.UDFRegistration import org.apache.spark.sql.api.java.* import org.apache.spark.sql.functions +import org.apache.spark.sql.types.DataType import scala.collection.mutable.WrappedArray import kotlin.reflect.KClass import kotlin.reflect.full.isSubclassOf @@ -53,7 +55,7 @@ class UDFWrapper0(private val udfName: String) { */ @OptIn(ExperimentalStdlibApi::class) inline fun UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 { - register(name, UDF0(func), schema(typeOf())) + register(name, UDF0(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper0(name) } @@ -76,7 +78,7 @@ class UDFWrapper1(private val udfName: String) { @OptIn(ExperimentalStdlibApi::class) inline fun UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 { T0::class.checkForValidType("T0") - register(name, UDF1(func), schema(typeOf())) + register(name, UDF1(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper1(name) } @@ -103,7 +105,7 @@ inline fun UDFRegistration.register( ): UDFWrapper2 { T0::class.checkForValidType("T0") T1::class.checkForValidType("T1") - register(name, UDF2(func), schema(typeOf())) + register(name, UDF2(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper2(name) } @@ -131,7 +133,7 @@ inline fun UDFRegistration.regis T0::class.checkForValidType("T0") T1::class.checkForValidType("T1") T2::class.checkForValidType("T2") - register(name, UDF3(func), schema(typeOf())) + register(name, UDF3(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper3(name) } @@ -160,7 +162,7 @@ inline fun UDFRegist T1::class.checkForValidType("T1") T2::class.checkForValidType("T2") T3::class.checkForValidType("T3") - register(name, UDF4(func), schema(typeOf())) + register(name, UDF4(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper4(name) } @@ -190,7 +192,7 @@ inline fun ())) + register(name, UDF5(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper5(name) } @@ -228,7 +230,7 @@ inline fun ())) + register(name, UDF6(func), 
DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper6(name) } @@ -268,7 +270,7 @@ inline fun ())) + register(name, UDF7(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper7(name) } @@ -310,7 +312,7 @@ inline fun ())) + register(name, UDF8(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper8(name) } @@ -354,7 +356,7 @@ inline fun ())) + register(name, UDF9(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper9(name) } @@ -412,7 +414,7 @@ inline fun ())) + register(name, UDF10(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper10(name) } @@ -473,7 +475,7 @@ inline fun ())) + register(name, UDF11(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper11(name) } @@ -537,7 +539,7 @@ inline fun ())) + register(name, UDF12(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper12(name) } @@ -604,7 +606,7 @@ inline fun ())) + register(name, UDF13(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper13(name) } @@ -674,7 +676,7 @@ inline fun ())) + register(name, UDF14(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper14(name) } @@ -747,7 +749,7 @@ inline fun ())) + register(name, UDF15(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper15(name) } @@ -823,7 +825,7 @@ inline fun ())) + register(name, UDF16(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper16(name) } @@ -902,7 +904,7 @@ inline fun ())) + register(name, UDF17(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper17(name) } @@ -984,7 +986,7 @@ inline fun ())) + register(name, UDF18(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper18(name) } @@ -1069,7 +1071,7 @@ inline fun ())) + register(name, UDF19(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper19(name) } @@ -1157,7 +1159,7 @@ inline fun ())) + register(name, UDF20(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper20(name) } @@ -1248,7 +1250,7 @@ inline fun ())) + register(name, UDF21(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper21(name) } @@ -1342,6 +1344,6 @@ inline fun ())) + register(name, UDF22(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) return UDFWrapper22(name) } \ No newline at end of file From a4bbbf910fe032dc0dd79ebfec3075e4d0678f7d Mon Sep 17 00:00:00 2001 From: can wang Date: Mon, 23 Aug 2021 23:55:43 +0800 Subject: [PATCH 04/18] add return udf data class test --- .../kotlinx/spark/api/UDFRegisterTest.kt | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt index 0d2e8759..470a4af3 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -120,5 +120,23 @@ class UDFRegisterTest : ShouldSpec({ } } 
 }
+
+//    context("udf return data class") {
+//        withSpark(logLevel = SparkLogLevel.DEBUG) {
+//            should("return NormalClass") {
+//                listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test2")
+//                udf.register<String, Int, NormalClass>("toNormalClass") { a, b ->
+//                    NormalClass(a, b)
+//                }
+//                spark.sql("select toNormalClass(first, second) from test2").show()
+//            }
+//        }
+//    }
+
 }
-})
\ No newline at end of file
+})
+
+data class NormalClass(
+    val name: String,
+    val age: Int
+)

From 59cec6b85e643a28fbaaf371e02fa22868f39c57 Mon Sep 17 00:00:00 2001
From: can wang
Date: Tue, 24 Aug 2021 00:04:22 +0800
Subject: [PATCH 05/18] change test

---
 .../kotlinx/spark/api/UDFRegisterTest.kt      | 30 ++++++++++----------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
index 470a4af3..17e34a1c 100644
--- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -1,6 +1,7 @@
 package org.jetbrains.kotlinx.spark.api
 
 import io.kotest.core.spec.style.ShouldSpec
+import io.kotest.matchers.shouldBe
 import org.apache.spark.sql.RowFactory
 import org.apache.spark.sql.types.DataTypes
 import org.junit.jupiter.api.assertThrows
@@ -62,6 +63,15 @@ class UDFRegisterTest : ShouldSpec({
                 array.asIterable().joinToString(" ")
             }
         }
+
+        should("succeed when using three type udf and as result to udf return type") {
+            listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
+            udf.register<String, Int, Int>("stringIntDiff") { a, b ->
+                a[0].toInt() - b
+            }
+            val result = spark.sql("select stringIntDiff(first, second) from test1").`as`<Int>().collectAsList()
+            result shouldBe listOf(96, 96)
+        }
     }
 }
 
@@ -98,26 +108,6 @@ class UDFRegisterTest : ShouldSpec({
                     assert(text.getString(0) == textArray.getList(0).joinToString(" "))
                 }
         }
-//        should("also work with datasets") {
-//            val ds = listOf("a" to 1, "b" to 2).toDS()
-//            val stringIntDiff = udf.register<String, Int, Arity1<Int>>("stringIntDiff") { a, b ->
-//                c(a[0].toInt() - b)
-//            }
-//            val lst = ds.withColumn("new", stringIntDiff(ds.col("first"), ds.col("second")))
-//                .select("new")
-//                .collectAsList()
-//
-////            val result = spark.sql("select stringIntDiff(first, second) from test1").`as`<Int>().collectAsList()
-////            expect(result).asExpect().contains.inOrder.only.values(96, 96)
-//        }
-        should("also work with datasets") {
-            listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
-            udf.register<String, Int, Int>("stringIntDiff") { a, b ->
-                a[0].toInt() - b
-            }
-            spark.sql("select stringIntDiff(first, second) from test1").show()
-
-        }
     }
 }

From 460350f0b9011517795558e151ec1886d03f41be Mon Sep 17 00:00:00 2001
From: can wang
Date: Tue, 24 Aug 2021 00:39:33 +0800
Subject: [PATCH 06/18] add in dataset test for calling the UDF-Wrapper

---
 .../kotlinx/spark/api/UDFRegisterTest.kt      | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
index 17e34a1c..c2b51f2d 100644
--- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -2,6 +2,7 @@ package org.jetbrains.kotlinx.spark.api
 
 import
io.kotest.core.spec.style.ShouldSpec import io.kotest.matchers.shouldBe +import org.apache.spark.sql.Dataset import org.apache.spark.sql.RowFactory import org.apache.spark.sql.types.DataTypes import org.junit.jupiter.api.assertThrows @@ -108,6 +109,25 @@ class UDFRegisterTest : ShouldSpec({ assert(text.getString(0) == textArray.getList(0).joinToString(" ")) } } + + + should("succeed in dataset") { + val dataset: Dataset = listOf(NormalClass("a", 10), NormalClass("b", 20)).toDS() + + val udfWrapper = udf.register("nameConcatAge") { name, age -> + "$name-$age" + } + + val collectAsList = dataset.withColumn( + "nameAndAge", + udfWrapper(dataset.col("name"), dataset.col("age")) + ) + .select("nameAndAge") + .collectAsList() + + collectAsList[0][0] shouldBe "a-10" + collectAsList[1][0] shouldBe "b-20" + } } } From f2506949418cb284d872f0ed78095ebf6d4a683c Mon Sep 17 00:00:00 2001 From: can wang Date: Tue, 24 Aug 2021 10:28:14 +0800 Subject: [PATCH 07/18] add the same exception link --- .../kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt index c2b51f2d..660ba165 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -131,6 +131,7 @@ class UDFRegisterTest : ShouldSpec({ } } + // get the same exception with: https://forums.databricks.com/questions/13361/how-do-i-create-a-udf-in-java-which-return-a-compl.html // context("udf return data class") { // withSpark(logLevel = SparkLogLevel.DEBUG) { // should("return NormalClass") { From 42b5aa3221af57749b55c841166140601d815ace Mon Sep 17 00:00:00 2001 From: can wang Date: Tue, 24 Aug 2021 10:38:49 +0800 Subject: [PATCH 08/18] refactor unWrapper --- .../kotlinx/spark/api/UDFRegister.kt | 53 +++++++++++-------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt index 2b9e3e54..111e6231 100644 --- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt +++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt @@ -13,6 +13,13 @@ import kotlin.reflect.KClass import kotlin.reflect.full.isSubclassOf import kotlin.reflect.typeOf +fun DataType.unWrapper(): DataType { + return when (this) { + is DataTypeWithClass -> DataType.fromJson(dt().json()) + else -> this + } +} + /** * Checks if [this] is of a valid type for an UDF, otherwise it throws a [TypeOfUDFParameterNotSupportedException] */ @@ -55,7 +62,7 @@ class UDFWrapper0(private val udfName: String) { */ @OptIn(ExperimentalStdlibApi::class) inline fun UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 { - register(name, UDF0(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) + register(name, UDF0(func), schema(typeOf()).unWrapper()) return UDFWrapper0(name) } @@ -78,7 +85,7 @@ class UDFWrapper1(private val udfName: String) { @OptIn(ExperimentalStdlibApi::class) inline fun UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 { T0::class.checkForValidType("T0") - register(name, UDF1(func), DataType.fromJson((schema(typeOf()) 
as DataTypeWithClass).dt().json())) + register(name, UDF1(func), schema(typeOf()).unWrapper()) return UDFWrapper1(name) } @@ -105,7 +112,7 @@ inline fun UDFRegistration.register( ): UDFWrapper2 { T0::class.checkForValidType("T0") T1::class.checkForValidType("T1") - register(name, UDF2(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) + register(name, UDF2(func), schema(typeOf()).unWrapper()) return UDFWrapper2(name) } @@ -133,7 +140,7 @@ inline fun UDFRegistration.regis T0::class.checkForValidType("T0") T1::class.checkForValidType("T1") T2::class.checkForValidType("T2") - register(name, UDF3(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) + register(name, UDF3(func), schema(typeOf()).unWrapper()) return UDFWrapper3(name) } @@ -162,7 +169,7 @@ inline fun UDFRegist T1::class.checkForValidType("T1") T2::class.checkForValidType("T2") T3::class.checkForValidType("T3") - register(name, UDF4(func), DataType.fromJson((schema(typeOf()) as DataTypeWithClass).dt().json())) + register(name, UDF4(func), schema(typeOf()).unWrapper()) return UDFWrapper4(name) } @@ -192,7 +199,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF5(func), schema(typeOf()).unWrapper()) return UDFWrapper5(name) } @@ -230,7 +237,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF6(func), schema(typeOf()).unWrapper()) return UDFWrapper6(name) } @@ -270,7 +277,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF7(func), schema(typeOf()).unWrapper()) return UDFWrapper7(name) } @@ -312,7 +319,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF8(func), schema(typeOf()).unWrapper()) return UDFWrapper8(name) } @@ -356,7 +363,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF9(func), schema(typeOf()).unWrapper()) return UDFWrapper9(name) } @@ -414,7 +421,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF10(func), schema(typeOf()).unWrapper()) return UDFWrapper10(name) } @@ -475,7 +482,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF11(func), schema(typeOf()).unWrapper()) return UDFWrapper11(name) } @@ -539,7 +546,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF12(func), schema(typeOf()).unWrapper()) return UDFWrapper12(name) } @@ -606,7 +613,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF13(func), schema(typeOf()).unWrapper()) return UDFWrapper13(name) } @@ -676,7 +683,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF14(func), schema(typeOf()).unWrapper()) return UDFWrapper14(name) } @@ -749,7 +756,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF15(func), schema(typeOf()).unWrapper()) return UDFWrapper15(name) } @@ -825,7 +832,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF16(func), schema(typeOf()).unWrapper()) return UDFWrapper16(name) } @@ -904,7 +911,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF17(func), schema(typeOf()).unWrapper()) return UDFWrapper17(name) } @@ -986,7 +993,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF18(func), schema(typeOf()).unWrapper()) return UDFWrapper18(name) } @@ -1071,7 +1078,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF19(func), schema(typeOf()).unWrapper()) return UDFWrapper19(name) } @@ -1159,7 +1166,7 @@ inline fun ()) as 
DataTypeWithClass).dt().json())) + register(name, UDF20(func), schema(typeOf()).unWrapper()) return UDFWrapper20(name) } @@ -1250,7 +1257,7 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF21(func), schema(typeOf()).unWrapper()) return UDFWrapper21(name) } @@ -1344,6 +1351,6 @@ inline fun ()) as DataTypeWithClass).dt().json())) + register(name, UDF22(func), schema(typeOf()).unWrapper()) return UDFWrapper22(name) } \ No newline at end of file From 456a29e3041b8cdf42b035b04af097066edd6fcf Mon Sep 17 00:00:00 2001 From: can wang Date: Wed, 25 Aug 2021 21:28:30 +0800 Subject: [PATCH 09/18] add test for udf return a List --- .../org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt index 660ba165..6bfefe3f 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -65,6 +65,15 @@ class UDFRegisterTest : ShouldSpec({ } } + should("succeed when return a List") { + udf.register>("StringToIntList") { a -> + a.asIterable().map { it.toInt() } + } + + val result = spark.sql("select StringToIntList('ab')").`as`>().collectAsList() + result shouldBe listOf(listOf(97, 98)) + } + should("succeed when using three type udf and as result to udf return type") { listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1") udf.register("stringIntDiff") { a, b -> From 6fe2c6b36dd05582acd71ee4c34b2cd892b667dc Mon Sep 17 00:00:00 2001 From: can wang Date: Wed, 1 Sep 2021 20:27:47 +0800 Subject: [PATCH 10/18] make the test simpler --- .../kotlinx/spark/api/UDFRegisterTest.kt | 25 +++---------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt index 6bfefe3f..6185a12a 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -87,33 +87,16 @@ class UDFRegisterTest : ShouldSpec({ context("calling the UDF-Wrapper") { withSpark(logLevel = SparkLogLevel.DEBUG) { - should("succeed when using the right number of arguments") { - val schema = DataTypes.createStructType( - listOf( - DataTypes.createStructField( - "textArray", - DataTypes.createArrayType(DataTypes.StringType), - false - ), - DataTypes.createStructField("id", DataTypes.StringType, false) - ) - ) - - val rows = listOf( - RowFactory.create(arrayOf("a", "b", "c"), "1"), - RowFactory.create(arrayOf("d", "e", "f"), "2"), - RowFactory.create(arrayOf("g", "h", "i"), "3"), - ) - - val testData = spark.createDataFrame(rows, schema) + should("succeed call UDF-Wrapper in withColumn") { val stringArrayMerger = udf.register, String>("stringArrayMerger") { it.asIterable().joinToString(" ") } - val newData = testData.withColumn("text", stringArrayMerger(testData.col("textArray"))) + val testData = dsOf(listOf("a", "b")) + val newData = testData.withColumn("text", stringArrayMerger(testData.col("value"))) - newData.select("text").collectAsList().zip(newData.select("textArray").collectAsList()) + 
newData.select("text").collectAsList().zip(newData.select("value").collectAsList()) .forEach { (text, textArray) -> assert(text.getString(0) == textArray.getList(0).joinToString(" ")) } From e77c58d593274f974b7ded3442dc4c5e881e518b Mon Sep 17 00:00:00 2001 From: can wang Date: Wed, 1 Sep 2021 20:37:04 +0800 Subject: [PATCH 11/18] add License --- .../kotlinx/spark/api/UDFRegister.kt | 21 ++++++++++++++++++- .../kotlinx/spark/api/UDFRegisterTest.kt | 21 ++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt index 111e6231..6dc19d58 100644 --- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt +++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt @@ -1,3 +1,22 @@ +/*- + * =LICENSE= + * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12) + * ---------- + * Copyright (C) 2019 - 2021 JetBrains + * ---------- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * =LICENSEEND= + */ @file:Suppress("DuplicatedCode") package org.jetbrains.kotlinx.spark.api @@ -1353,4 +1372,4 @@ inline fun ()).unWrapper()) return UDFWrapper22(name) -} \ No newline at end of file +} diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt index 6185a12a..fca04d46 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -1,3 +1,22 @@ +/*- + * =LICENSE= + * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12) + * ---------- + * Copyright (C) 2019 - 2021 JetBrains + * ---------- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * =LICENSEEND=
+ */
 package org.jetbrains.kotlinx.spark.api
 
 import io.kotest.core.spec.style.ShouldSpec
@@ -142,4 +161,4 @@ class UDFRegisterTest : ShouldSpec({
 data class NormalClass(
     val name: String,
     val age: Int
-)
\ No newline at end of file
+)

From 12ea1bdb4c556c897ff5ab5d855ef9f58ddedf04 Mon Sep 17 00:00:00 2001
From: can wang
Date: Wed, 1 Sep 2021 20:47:55 +0800
Subject: [PATCH 12/18] add UDFRegister for 3.0

---
 .../kotlinx/spark/api/SparkHelper.kt          |    2 +
 .../kotlinx/spark/api/UDFRegister.kt          | 1375 +++++++++++++++++
 .../kotlinx/spark/api/UDFRegisterTest.kt      |  164 ++
 3 files changed, 1541 insertions(+)
 create mode 100644 kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
 create mode 100644 kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt

diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
index d5d6aa2c..3ef0b177 100644
--- a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
@@ -20,6 +20,7 @@
 package org.jetbrains.kotlinx.spark.api
 
 import org.apache.spark.sql.SparkSession.Builder
+import org.apache.spark.sql.UDFRegistration
 import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
 
 /**
@@ -78,4 +79,5 @@ inline class KSparkSession(val spark: SparkSession) {
     inline fun <reified T> List<T>.toDS() = toDS(spark)
     inline fun <reified T> Array<T>.toDS() = spark.dsOf(*this)
     inline fun <reified T> dsOf(vararg arg: T) = spark.dsOf(*arg)
+    val udf: UDFRegistration get() = spark.udf()
 }
diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
new file mode 100644
index 00000000..6dc19d58
--- /dev/null
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
@@ -0,0 +1,1375 @@
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+@file:Suppress("DuplicatedCode")
+
+package org.jetbrains.kotlinx.spark.api
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.DataTypeWithClass
+import org.apache.spark.sql.UDFRegistration
+import org.apache.spark.sql.api.java.*
+import org.apache.spark.sql.functions
+import org.apache.spark.sql.types.DataType
+import scala.collection.mutable.WrappedArray
+import kotlin.reflect.KClass
+import kotlin.reflect.full.isSubclassOf
+import kotlin.reflect.typeOf
+
+fun DataType.unWrapper(): DataType {
+    return when (this) {
+        is DataTypeWithClass -> DataType.fromJson(dt().json())
+        else -> this
+    }
+}
+
+/**
+ * Checks if [this] is of a valid type for a UDF, otherwise it throws a [TypeOfUDFParameterNotSupportedException]
+ */
+@PublishedApi
+internal fun KClass<*>.checkForValidType(parameterName: String) {
+    if (this == String::class || isSubclassOf(WrappedArray::class)) return // Most of the time we need strings or WrappedArrays
+    if (isSubclassOf(Iterable::class) || java.isArray
+        || isSubclassOf(Map::class) || isSubclassOf(Array::class)
+        || isSubclassOf(ByteArray::class) || isSubclassOf(CharArray::class)
+        || isSubclassOf(ShortArray::class) || isSubclassOf(IntArray::class)
+        || isSubclassOf(LongArray::class) || isSubclassOf(FloatArray::class)
+        || isSubclassOf(DoubleArray::class) || isSubclassOf(BooleanArray::class)
+    ) {
+        throw TypeOfUDFParameterNotSupportedException(this, parameterName)
+    }
+}
+
+/**
+ * An exception thrown when the UDF is generated with illegal types for the parameters
+ */
+class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName: String) : IllegalArgumentException(
+    "Parameter $parameterName is subclass of ${kClass.qualifiedName}. If you need to process an array use ${WrappedArray::class.qualifiedName}."
+)
+
+/**
+ * A wrapper for a UDF with 0 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper0(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(): Column {
+        return functions.callUDF(udfName)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified R> UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 {
+    register(name, UDF0(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper0(name)
+}
+
+/**
+ * A wrapper for a UDF with 1 argument.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper1(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(param0: Column): Column {
+        return functions.callUDF(udfName, param0)
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified R> UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 {
+    T0::class.checkForValidType("T0")
+    register(name, UDF1(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper1(name)
+}
+
+/**
+ * A wrapper for a UDF with 2 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper2(private val udfName: String) {
+    /**
+     * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
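+     *
+     * An illustrative sketch only (assumes a `udf: UDFRegistration` in scope and a
+     * dataset with `name` and `age` columns, mirroring this PR's tests):
+     * ```kotlin
+     * val nameConcatAge = udf.register<String, Int, String>("nameConcatAge") { name, age -> "$name-$age" }
+     * dataset.withColumn("nameAndAge", nameConcatAge(dataset.col("name"), dataset.col("age")))
+     * ```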
+ */ + operator fun invoke(param0: Column, param1: Column): Column { + return functions.callUDF(udfName, param0, param1) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1) -> R +): UDFWrapper2 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + register(name, UDF2(func), schema(typeOf()).unWrapper()) + return UDFWrapper2(name) +} + +/** + * A wrapper for an UDF with 3 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper3(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column): Column { + return functions.callUDF(udfName, param0, param1, param2) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2) -> R +): UDFWrapper3 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + register(name, UDF3(func), schema(typeOf()).unWrapper()) + return UDFWrapper3(name) +} + +/** + * A wrapper for an UDF with 4 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper4(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3) -> R +): UDFWrapper4 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + register(name, UDF4(func), schema(typeOf()).unWrapper()) + return UDFWrapper4(name) +} + +/** + * A wrapper for an UDF with 5 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper5(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4) -> R +): UDFWrapper5 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + register(name, UDF5(func), schema(typeOf()).unWrapper()) + return UDFWrapper5(name) +} + +/** + * A wrapper for an UDF with 6 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper6(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
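+ *
+ * The parameters are plain [Column]s, so constants can be mixed in with
+ * [functions.lit]. A sketch (the names `f6` and `df` are placeholders, not part of this API):
+ * ```kotlin
+ * val f6 = udf.register<Int, Int, Int, Int, Int, Int, Int>("f6") { a, b, c, d, e, f -> a + b + c + d + e + f }
+ * df.withColumn("sum", f6(df.col("a"), df.col("b"), df.col("c"), df.col("d"), df.col("e"), functions.lit(1)))
+ * ```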
+ */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column + ): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5) -> R +): UDFWrapper6 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + register(name, UDF6(func), schema(typeOf()).unWrapper()) + return UDFWrapper6(name) +} + +/** + * A wrapper for an UDF with 7 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper7(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column + ): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6) -> R +): UDFWrapper7 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + register(name, UDF7(func), schema(typeOf()).unWrapper()) + return UDFWrapper7(name) +} + +/** + * A wrapper for an UDF with 8 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper8(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column + ): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7) -> R +): UDFWrapper8 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + register(name, UDF8(func), schema(typeOf()).unWrapper()) + return UDFWrapper8(name) +} + +/** + * A wrapper for an UDF with 9 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper9(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
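+ *
+ * Because the wrapper merely forwards to [functions.callUDF], a UDF registered under
+ * [udfName] can just as well be called by name from SQL. A sketch with placeholder
+ * names (`f9`, `tempView`):
+ * ```kotlin
+ * spark.sql("select f9(c0, c1, c2, c3, c4, c5, c6, c7, c8) from tempView")
+ * ```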
+ */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column + ): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8) -> R +): UDFWrapper9 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + register(name, UDF9(func), schema(typeOf()).unWrapper()) + return UDFWrapper9(name) +} + +/** + * A wrapper for an UDF with 10 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper10(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) -> R +): UDFWrapper10 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + register(name, UDF10(func), schema(typeOf()).unWrapper()) + return UDFWrapper10(name) +} + +/** + * A wrapper for an UDF with 11 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper11(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
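+ *
+ * Array-like arguments must be declared as [WrappedArray] on the Kotlin side; other
+ * collection and array types are rejected by [checkForValidType]. A 1-ary sketch
+ * mirroring the `stringArrayMerger` test in this PR:
+ * ```kotlin
+ * val merger = udf.register<WrappedArray<String>, String>("merger") { it.asIterable().joinToString(" ") }
+ * ```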
+ */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) -> R +): UDFWrapper11 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + register(name, UDF11(func), schema(typeOf()).unWrapper()) + return UDFWrapper11(name) +} + +/** + * A wrapper for an UDF with 12 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper12(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) -> R +): UDFWrapper12 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + register(name, UDF12(func), schema(typeOf()).unWrapper()) + return UDFWrapper12(name) +} + +/** + * A wrapper for an UDF with 13 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper13(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
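+ *
+ * Note that the `register` overloads carry `@OptIn(ExperimentalStdlibApi::class)` only
+ * because they rely on `kotlin.reflect.typeOf`, which was still experimental in the
+ * Kotlin version targeted here; callers need no opt-in of their own.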
+ */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column, + param12: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11, + param12 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) -> R +): UDFWrapper13 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + register(name, UDF13(func), schema(typeOf()).unWrapper()) + return UDFWrapper13(name) +} + +/** + * A wrapper for an UDF with 14 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper14(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column, + param12: Column, + param13: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11, + param12, + param13 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) -> R +): UDFWrapper14 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + register(name, UDF14(func), schema(typeOf()).unWrapper()) + return UDFWrapper14(name) +} + +/** + * A wrapper for an UDF with 15 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper15(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
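+ *
+ * On registration the return type is computed as `schema(typeOf<R>()).unWrapper()`:
+ * the Kotlin-specific [DataTypeWithClass] wrapper is stripped via a JSON round-trip so
+ * that Spark receives a plain [DataType]. A sketch:
+ * ```kotlin
+ * val dt = schema(typeOf<List<Int>>()).unWrapper() // plain ArrayType, no wrapper class attached
+ * ```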
+ */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column, + param12: Column, + param13: Column, + param14: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11, + param12, + param13, + param14 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) -> R +): UDFWrapper15 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + register(name, UDF15(func), schema(typeOf()).unWrapper()) + return UDFWrapper15(name) +} + +/** + * A wrapper for an UDF with 16 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper16(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column, + param12: Column, + param13: Column, + param14: Column, + param15: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11, + param12, + param13, + param14, + param15 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15) -> R +): UDFWrapper16 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + register(name, UDF16(func), schema(typeOf()).unWrapper()) + return UDFWrapper16(name) +} + +/** + * A wrapper for an UDF with 17 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper17(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
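+ *
+ * The wrapper plays a role similar to Spark's own `UserDefinedFunction` handle: it only
+ * remembers the registered name, so it is cheap to create and can be used in column
+ * expressions as well as in SQL.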
+ */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column, + param12: Column, + param13: Column, + param14: Column, + param15: Column, + param16: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11, + param12, + param13, + param14, + param15, + param16 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> R +): UDFWrapper17 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + register(name, UDF17(func), schema(typeOf()).unWrapper()) + return UDFWrapper17(name) +} + +/** + * A wrapper for an UDF with 18 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper18(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column, + param12: Column, + param13: Column, + param14: Column, + param15: Column, + param16: Column, + param17: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11, + param12, + param13, + param14, + param15, + param16, + param17 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17) -> R +): UDFWrapper18 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + T17::class.checkForValidType("T17") + register(name, UDF18(func), schema(typeOf()).unWrapper()) + return UDFWrapper18(name) +} + +/** + * A wrapper for an UDF with 19 arguments. 
+ * @property udfName the name of the UDF + */ +class UDFWrapper19(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke( + param0: Column, + param1: Column, + param2: Column, + param3: Column, + param4: Column, + param5: Column, + param6: Column, + param7: Column, + param8: Column, + param9: Column, + param10: Column, + param11: Column, + param12: Column, + param13: Column, + param14: Column, + param15: Column, + param16: Column, + param17: Column, + param18: Column + ): Column { + return functions.callUDF( + udfName, + param0, + param1, + param2, + param3, + param4, + param5, + param6, + param7, + param8, + param9, + param10, + param11, + param12, + param13, + param14, + param15, + param16, + param17, + param18 + ) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register( + name: String, + noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18) -> R +): UDFWrapper19 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + T17::class.checkForValidType("T17") + T18::class.checkForValidType("T18") + register(name, UDF19(func), schema(typeOf()).unWrapper()) + return UDFWrapper19(name) +} + +/** + * A wrapper for an UDF with 20 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper20(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
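+ *
+ * Note: as the commented-out test in UDFRegisterTest shows, returning a Kotlin data
+ * class from a UDF is not supported here; registration compiles, but the query fails
+ * at runtime with the same error as the Java case linked from that test.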
+     */
+    operator fun invoke(
+        param0: Column, param1: Column, param2: Column, param3: Column, param4: Column,
+        param5: Column, param6: Column, param7: Column, param8: Column, param9: Column,
+        param10: Column, param11: Column, param12: Column, param13: Column, param14: Column,
+        param15: Column, param16: Column, param17: Column, param18: Column, param19: Column
+    ): Column {
+        return functions.callUDF(
+            udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8,
+            param9, param10, param11, param12, param13, param14, param15, param16, param17,
+            param18, param19
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this].
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6,
+        reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13,
+        reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19) -> R
+): UDFWrapper20 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    register(name, UDF20(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper20(name)
+}
+
+/**
+ * A wrapper for a UDF with 21 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper21(private val udfName: String) {
+    /**
+     * Calls [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column, param1: Column, param2: Column, param3: Column, param4: Column,
+        param5: Column, param6: Column, param7: Column, param8: Column, param9: Column,
+        param10: Column, param11: Column, param12: Column, param13: Column, param14: Column,
+        param15: Column, param16: Column, param17: Column, param18: Column, param19: Column,
+        param20: Column
+    ): Column {
+        return functions.callUDF(
+            udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8,
+            param9, param10, param11, param12, param13, param14, param15, param16, param17,
+            param18, param19, param20
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this].
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6,
+        reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13,
+        reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified T20,
+        reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20) -> R
+): UDFWrapper21 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    T20::class.checkForValidType("T20")
+    register(name, UDF21(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper21(name)
+}
+
+/**
+ * A wrapper for a UDF with 22 arguments.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper22(private val udfName: String) {
+    /**
+     * Calls [functions.callUDF] for the UDF with the [udfName] and the given columns.
+     */
+    operator fun invoke(
+        param0: Column, param1: Column, param2: Column, param3: Column, param4: Column,
+        param5: Column, param6: Column, param7: Column, param8: Column, param9: Column,
+        param10: Column, param11: Column, param12: Column, param13: Column, param14: Column,
+        param15: Column, param16: Column, param17: Column, param18: Column, param19: Column,
+        param20: Column, param21: Column
+    ): Column {
+        return functions.callUDF(
+            udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8,
+            param9, param10, param11, param12, param13, param14, param15, param16, param17,
+            param18, param19, param20, param21
+        )
+    }
+}
+
+/**
+ * Registers the [func] with its [name] in [this].
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6,
+        reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13,
+        reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified T20,
+        reified T21, reified R> UDFRegistration.register(
+    name: String,
+    noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21) -> R
+): UDFWrapper22 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    T20::class.checkForValidType("T20")
+    T21::class.checkForValidType("T21")
+    register(name, UDF22(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper22(name)
+}
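Note: the overloads above complete the series up to UDF22, the highest arity Spark's UDF interfaces support. As a quick orientation, here is a minimal usage sketch of this API; it is not part of the patch, and the UDF name and data are illustrative only:

    withSpark {
        // The reified type arguments let register() validate each parameter type
        // via checkForValidType before delegating to Spark's UDFRegistration.
        val plus = udf.register<Int, Int, Int>("plus") { a, b -> a + b }

        // A Dataset of Kotlin Pairs has the columns "first" and "second".
        val df = listOf(1 to 2, 3 to 4).toDS().toDF()
        // The returned UDFWrapper2 makes the registered UDF callable on Columns.
        df.withColumn("sum", plus(df.col("first"), df.col("second"))).show()
    }
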
diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
new file mode 100644
index 00000000..fca04d46
--- /dev/null
+++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -0,0 +1,164 @@
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+package org.jetbrains.kotlinx.spark.api
+
+import io.kotest.core.spec.style.ShouldSpec
+import io.kotest.matchers.shouldBe
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.RowFactory
+import org.apache.spark.sql.types.DataTypes
+import org.junit.jupiter.api.assertThrows
+import scala.collection.JavaConversions
+import scala.collection.mutable.WrappedArray
+
+private fun <T> scala.collection.Iterable<T>.asIterable(): Iterable<T> = JavaConversions.asJavaIterable(this)
+
+class UDFRegisterTest : ShouldSpec({
+    context("org.jetbrains.kotlinx.spark.api.UDFRegister") {
+        context("the function checkForValidType") {
+            val invalidTypes = listOf(
+                Array::class,
+                Iterable::class,
+                List::class,
+                MutableList::class,
+                ByteArray::class,
+                CharArray::class,
+                ShortArray::class,
+                IntArray::class,
+                LongArray::class,
+                FloatArray::class,
+                DoubleArray::class,
+                BooleanArray::class,
+                Map::class,
+                MutableMap::class,
+                Set::class,
+                MutableSet::class,
+                arrayOf("")::class,
+                listOf("")::class,
+                setOf("")::class,
+                mapOf("" to "")::class,
+                mutableListOf("")::class,
+                mutableSetOf("")::class,
+                mutableMapOf("" to "")::class,
+            )
+            invalidTypes.forEachIndexed { index, invalidType ->
+                should("$index: throw an ${TypeOfUDFParameterNotSupportedException::class.simpleName} when encountering ${invalidType.qualifiedName}") {
+                    assertThrows<TypeOfUDFParameterNotSupportedException> {
+                        invalidType.checkForValidType("test")
+                    }
+                }
+            }
+        }
+
+        context("the register-function") {
+            withSpark {
+
+                should("fail when using a simple kotlin.Array") {
+                    assertThrows<TypeOfUDFParameterNotSupportedException> {
+                        udf.register("shouldFail") { array: Array<String> ->
+                            array.joinToString(" ")
+                        }
+                    }
+                }
+
+                should("succeed when using a WrappedArray") {
+                    udf.register("shouldSucceed") { array: WrappedArray<String> ->
+                        array.asIterable().joinToString(" ")
+                    }
+                }
+
+                should("succeed when returning a List") {
+                    udf.register<String, List<Int>>("StringToIntList") { a ->
+                        a.asIterable().map { it.toInt() }
+                    }
+
+                    val result = spark.sql("select StringToIntList('ab')").`as`<List<Int>>().collectAsList()
+                    result shouldBe listOf(listOf(97, 98))
+                }
+
+                should("succeed when using a UDF with three type parameters") {
+                    listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
+                    udf.register<String, Int, Int>("stringIntDiff") { a, b ->
+                        a[0].toInt() - b
+                    }
+                    val result = spark.sql("select stringIntDiff(first, second) from test1").`as`<Int>().collectAsList()
+                    result shouldBe listOf(96, 96)
+                }
+            }
+        }
+
+        context("calling the UDF-Wrapper") {
+            withSpark(logLevel = SparkLogLevel.DEBUG) {
+                should("succeed when calling the UDF-Wrapper in withColumn") {
+
+                    val stringArrayMerger = udf.register<WrappedArray<String>, String>("stringArrayMerger") {
+                        it.asIterable().joinToString(" ")
+                    }
+
+                    val testData = dsOf(listOf("a", "b"))
+                    val newData = testData.withColumn("text", stringArrayMerger(testData.col("value")))
+
+                    newData.select("text").collectAsList().zip(newData.select("value").collectAsList())
+                        .forEach { (text, textArray) ->
+                            assert(text.getString(0) == textArray.getList<String>(0).joinToString(" "))
+                        }
+                }
+
+
+                should("succeed when used on a dataset") {
+                    val dataset: Dataset<NormalClass> = listOf(NormalClass("a", 10), NormalClass("b", 20)).toDS()
+
+                    val udfWrapper = udf.register<String, Int, String>("nameConcatAge") { name, age ->
+                        "$name-$age"
+                    }
+
+                    val collectAsList = dataset.withColumn(
+                        "nameAndAge",
+                        udfWrapper(dataset.col("name"), dataset.col("age"))
+                    )
+                        .select("nameAndAge")
+                        .collectAsList()
+
+                    collectAsList[0][0] shouldBe "a-10"
+                    collectAsList[1][0] shouldBe "b-20"
+                }
+            }
+        }
+
+        // This currently fails with the same exception as: https://forums.databricks.com/questions/13361/how-do-i-create-a-udf-in-java-which-return-a-compl.html
+//        context("udf return data class") {
+//            withSpark(logLevel = SparkLogLevel.DEBUG) {
+//                should("return NormalClass") {
+//                    listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test2")
+//                    udf.register<String, Int, NormalClass>("toNormalClass") { a, b ->
+//                        NormalClass(a, b)
+//                    }
+//                    spark.sql("select toNormalClass(first, second) from test2").show()
+//                }
+//            }
+//        }
+
+    }
+})
+
+data class NormalClass(
+    val name: String,
+    val age: Int
+)

From 02169bdf355eced176055045d001bbd2d030bc0b Mon Sep 17 00:00:00 2001
From: can wang
Date: Wed, 1 Sep 2021 20:49:13 +0800
Subject: [PATCH 13/18] remove unused imports

---
 .../kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 2 --
 .../kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 2 --
 2 files changed, 4 deletions(-)

diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
index fca04d46..36ed32d6 100644
--- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -22,8 +22,6 @@ package org.jetbrains.kotlinx.spark.api
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.RowFactory
-import org.apache.spark.sql.types.DataTypes
 import org.junit.jupiter.api.assertThrows
 import scala.collection.JavaConversions
 import scala.collection.mutable.WrappedArray
diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
index fca04d46..36ed32d6 100644
--- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
+++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -22,8 +22,6 @@ package org.jetbrains.kotlinx.spark.api
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.RowFactory
-import org.apache.spark.sql.types.DataTypes
 import org.junit.jupiter.api.assertThrows
 import scala.collection.JavaConversions
 import scala.collection.mutable.WrappedArray
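Note: the next patch replaces the deprecated Char.toInt() with Char.code. Both return the character's UTF-16 code unit as an Int, so the expected values in the tests (97 for 'a', 98 for 'b') do not change. A one-line illustration, not part of the patch:

    // 'a'.code and 'a'.toInt() both evaluate to 97; toInt() on Char is deprecated since Kotlin 1.5.
    check('a'.code == 97 && 'a'.code == 'a'.toInt())
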
From e72aceecd17178f0bf53c280182e892ede688ccf Mon Sep 17 00:00:00 2001
From: can wang
Date: Wed, 1 Sep 2021 20:51:21 +0800
Subject: [PATCH 14/18] resolve deprecated method calls

---
 .../org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 4 ++--
 .../org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
index 36ed32d6..9afa5a65 100644
--- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -84,7 +84,7 @@ class UDFRegisterTest : ShouldSpec({
 
                 should("succeed when returning a List") {
                     udf.register<String, List<Int>>("StringToIntList") { a ->
-                        a.asIterable().map { it.toInt() }
+                        a.asIterable().map { it.code }
                    }
 
                    val result = spark.sql("select StringToIntList('ab')").`as`<List<Int>>().collectAsList()
@@ -94,7 +94,7 @@ class UDFRegisterTest : ShouldSpec({
                 should("succeed when using a UDF with three type parameters") {
                     listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
                     udf.register<String, Int, Int>("stringIntDiff") { a, b ->
-                        a[0].toInt() - b
+                        a[0].code - b
                     }
                     val result = spark.sql("select stringIntDiff(first, second) from test1").`as`<Int>().collectAsList()
                     result shouldBe listOf(96, 96)
diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
index 36ed32d6..aff54044 100644
--- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
+++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -23,10 +23,10 @@
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
 import org.apache.spark.sql.Dataset
 import org.junit.jupiter.api.assertThrows
-import scala.collection.JavaConversions
+import scala.collection.JavaConverters
 import scala.collection.mutable.WrappedArray
 
-private fun <T> scala.collection.Iterable<T>.asIterable(): Iterable<T> = JavaConversions.asJavaIterable(this)
+private fun <T> scala.collection.Iterable<T>.asIterable(): Iterable<T> = JavaConverters.asJavaIterable(this)
 
 class UDFRegisterTest : ShouldSpec({
@@ -84,7 +84,7 @@ class UDFRegisterTest : ShouldSpec({
 
                 should("succeed when returning a List") {
                     udf.register<String, List<Int>>("StringToIntList") { a ->
-                        a.asIterable().map { it.toInt() }
+                        a.asIterable().map { it.code }
                     }
 
                     val result = spark.sql("select StringToIntList('ab')").`as`<List<Int>>().collectAsList()
@@ -94,7 +94,7 @@ class UDFRegisterTest : ShouldSpec({
                 should("succeed when using a UDF with three type parameters") {
                     listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
                     udf.register<String, Int, Int>("stringIntDiff") { a, b ->
-                        a[0].toInt() - b
+                        a[0].code - b
                     }
                     val result = spark.sql("select stringIntDiff(first, second) from test1").`as`<Int>().collectAsList()
                     result shouldBe listOf(96, 96)

From 3398210c384c7ce7b0d05659de113dd8af3ec6d9 Mon Sep 17 00:00:00 2001
From: can wang
Date: Sun, 5 Sep 2021 13:40:42 +0800
Subject: [PATCH 15/18] [experimental] add CatalystTypeConverters.scala,
 hacked to support UDFs returning data classes

---
 .../sql/catalyst/CatalystTypeConverters.scala | 480 ++++++++++++++++++
 1 file changed, 480 insertions(+)
 create mode 100644 core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

diff --git a/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
new file mode 100644
index 00000000..7def589a
--- /dev/null
+++ b/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -0,0 +1,480 @@
+package org.apache.spark.sql.catalyst
+
+import kotlin.jvm.JvmClassMappingKt
+import kotlin.reflect.{KClass, KFunction, KProperty1}
+import kotlin.reflect.full.KClasses
+
+import java.lang.{Iterable => JavaIterable}
+import java.math.{BigDecimal => JavaBigDecimal}
+import java.math.{BigInteger => JavaBigInteger}
+import java.sql.{Date, Timestamp}
+import java.time.{Instant, LocalDate}
+import java.util.{Map => JavaMap}
+import javax.annotation.Nullable
+import scala.language.existentials
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util._
+import
org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Functions to convert Scala types to Catalyst types and vice versa. + */ +object CatalystTypeConverters { + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + + import scala.collection.Map + + private[sql] def isPrimitive(dataType: DataType): Boolean = { + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case _ => false + } + } + + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { + val converter = dataType match { + case udt: UserDefinedType[_] => UDTConverter(udt) + case arrayType: ArrayType => ArrayConverter(arrayType.elementType) + case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) + case structType: StructType => StructConverter(structType) + case StringType => StringConverter + case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter + case DateType => DateConverter + case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter + case TimestampType => TimestampConverter + case dt: DecimalType => new DecimalConverter(dt) + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType => IntConverter + case LongType => LongConverter + case FloatType => FloatConverter + case DoubleType => DoubleConverter + case dataType: DataType => IdentityConverter(dataType) + } + converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] + } + + /** + * Converts a Scala type to its Catalyst equivalent (and vice versa). + * + * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst. + * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala. + * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type. + */ + private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType] + extends Serializable { + + /** + * Converts a Scala type to its Catalyst equivalent while automatically handling nulls + * and Options. + */ + final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { + if (maybeScalaValue == null) { + null.asInstanceOf[CatalystType] + } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) { + val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]] + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + } else { + toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType]) + } + } + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + */ + final def toScala(row: InternalRow, column: Int): ScalaOutputType = { + if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) + } + + /** + * Convert a Catalyst value to its Scala equivalent. + */ + def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType + + /** + * Converts a Scala value to its Catalyst equivalent. + * + * @param scalaValue the Scala value, guaranteed not to be null. + * @return the Catalyst value. 
+ */ + protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + * This method will only be called on non-null columns. + */ + protected def toScalaImpl(row: InternalRow, column: Int): ScalaOutputType + } + + private case class IdentityConverter(dataType: DataType) + extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = scalaValue + + override def toScala(catalystValue: Any): Any = catalystValue + + override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType) + } + + private case class UDTConverter[A >: Null]( + udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] { + // toCatalyst (it calls toCatalystImpl) will do null check. + override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue) + + override def toScala(catalystValue: Any): A = { + if (catalystValue == null) null else udt.deserialize(catalystValue) + } + + override def toScalaImpl(row: InternalRow, column: Int): A = + toScala(row.get(column, udt.sqlType)) + } + + /** Converter for arrays, sequences, and Java iterables. */ + private case class ArrayConverter( + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] { + + private[this] val elementConverter = getConverterForType(elementType) + + override def toCatalystImpl(scalaValue: Any): ArrayData = { + scalaValue match { + case a: Array[_] => + new GenericArrayData(a.map(elementConverter.toCatalyst)) + case s: Seq[_] => + new GenericArrayData(s.map(elementConverter.toCatalyst).toArray) + case i: JavaIterable[_] => + val iter = i.iterator + val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any] + while (iter.hasNext) { + val item = iter.next() + convertedIterable += elementConverter.toCatalyst(item) + } + new GenericArrayData(convertedIterable.toArray) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to an array of ${elementType.catalogString}") + } + } + + override def toScala(catalystValue: ArrayData): Seq[Any] = { + if (catalystValue == null) { + null + } else if (isPrimitive(elementType)) { + catalystValue.toArray[Any](elementType) + } else { + val result = new Array[Any](catalystValue.numElements()) + catalystValue.foreach(elementType, (i, e) => { + result(i) = elementConverter.toScala(e) + }) + result + } + } + + override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = + toScala(row.getArray(column)) + } + + private case class MapConverter( + keyType: DataType, + valueType: DataType) + extends CatalystTypeConverter[Any, Map[Any, Any], MapData] { + + private[this] val keyConverter = getConverterForType(keyType) + private[this] val valueConverter = getConverterForType(valueType) + + override def toCatalystImpl(scalaValue: Any): MapData = { + val keyFunction = (k: Any) => keyConverter.toCatalyst(k) + val valueFunction = (k: Any) => valueConverter.toCatalyst(k) + + scalaValue match { + case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) + case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + "cannot be converted to a map type with " + + s"key type (${keyType.catalogString}) and value type 
(${valueType.catalogString})") + } + } + + override def toScala(catalystValue: MapData): Map[Any, Any] = { + if (catalystValue == null) { + null + } else { + val keys = catalystValue.keyArray().toArray[Any](keyType) + val values = catalystValue.valueArray().toArray[Any](valueType) + val convertedKeys = + if (isPrimitive(keyType)) keys else keys.map(keyConverter.toScala) + val convertedValues = + if (isPrimitive(valueType)) values else values.map(valueConverter.toScala) + + convertedKeys.zip(convertedValues).toMap + } + } + + override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = + toScala(row.getMap(column)) + } + + private case class StructConverter( + structType: StructType) extends CatalystTypeConverter[Any, Row, InternalRow] { + + private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } + + override def toCatalystImpl(scalaValue: Any): InternalRow = scalaValue match { + case row: Row => + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toCatalyst(row(idx)) + idx += 1 + } + new GenericInternalRow(ar) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(iter.next()) + idx += 1 + } + new GenericInternalRow(ar) + + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${structType.catalogString}") + } + + override def toScala(row: InternalRow): Row = { + if (row == null) { + null + } else { + val ar = new Array[Any](row.numFields) + var idx = 0 + while (idx < row.numFields) { + ar(idx) = converters(idx).toScala(row, idx) + idx += 1 + } + new GenericRowWithSchema(ar, structType) + } + } + + override def toScalaImpl(row: InternalRow, column: Int): Row = + toScala(row.getStruct(column, structType.size)) + } + + private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { + override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { + case str: String => UTF8String.fromString(str) + case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the string type") + } + + override def toScala(catalystValue: UTF8String): String = + if (catalystValue == null) null else catalystValue.toString + + override def toScalaImpl(row: InternalRow, column: Int): String = + row.getUTF8String(column).toString + } + + private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { + override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue) + + override def toScala(catalystValue: Any): Date = + if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + + override def toScalaImpl(row: InternalRow, column: Int): Date = + DateTimeUtils.toJavaDate(row.getInt(column)) + } + + private object LocalDateConverter extends CatalystTypeConverter[LocalDate, LocalDate, Any] { + override def toCatalystImpl(scalaValue: LocalDate): Int = { + DateTimeUtils.localDateToDays(scalaValue) + } + + override def toScala(catalystValue: Any): LocalDate = { + if (catalystValue == null) null + else DateTimeUtils.daysToLocalDate(catalystValue.asInstanceOf[Int]) + } + + override def 
toScalaImpl(row: InternalRow, column: Int): LocalDate = + DateTimeUtils.daysToLocalDate(row.getInt(column)) + } + + private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { + override def toCatalystImpl(scalaValue: Timestamp): Long = + DateTimeUtils.fromJavaTimestamp(scalaValue) + + override def toScala(catalystValue: Any): Timestamp = + if (catalystValue == null) null + else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) + + override def toScalaImpl(row: InternalRow, column: Int): Timestamp = + DateTimeUtils.toJavaTimestamp(row.getLong(column)) + } + + private object InstantConverter extends CatalystTypeConverter[Instant, Instant, Any] { + override def toCatalystImpl(scalaValue: Instant): Long = + DateTimeUtils.instantToMicros(scalaValue) + + override def toScala(catalystValue: Any): Instant = + if (catalystValue == null) null + else DateTimeUtils.microsToInstant(catalystValue.asInstanceOf[Long]) + + override def toScalaImpl(row: InternalRow, column: Int): Instant = + DateTimeUtils.microsToInstant(row.getLong(column)) + } + + private class DecimalConverter(dataType: DecimalType) + extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + + private val nullOnOverflow = !SQLConf.get.ansiEnabled + + override def toCatalystImpl(scalaValue: Any): Decimal = { + val decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: JavaBigInteger => Decimal(d) + case d: Decimal => d + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${dataType.catalogString}") + } + decimal.toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow) + } + + override def toScala(catalystValue: Decimal): JavaBigDecimal = { + if (catalystValue == null) null + else catalystValue.toJavaBigDecimal + } + + override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = + row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal + } + + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { + final override def toScala(catalystValue: Any): Any = catalystValue + + final override def toCatalystImpl(scalaValue: T): Any = scalaValue + } + + private object BooleanConverter extends PrimitiveConverter[Boolean] { + override def toScalaImpl(row: InternalRow, column: Int): Boolean = row.getBoolean(column) + } + + private object ByteConverter extends PrimitiveConverter[Byte] { + override def toScalaImpl(row: InternalRow, column: Int): Byte = row.getByte(column) + } + + private object ShortConverter extends PrimitiveConverter[Short] { + override def toScalaImpl(row: InternalRow, column: Int): Short = row.getShort(column) + } + + private object IntConverter extends PrimitiveConverter[Int] { + override def toScalaImpl(row: InternalRow, column: Int): Int = row.getInt(column) + } + + private object LongConverter extends PrimitiveConverter[Long] { + override def toScalaImpl(row: InternalRow, column: Int): Long = row.getLong(column) + } + + private object FloatConverter extends PrimitiveConverter[Float] { + override def toScalaImpl(row: InternalRow, column: Int): Float = row.getFloat(column) + } + + private object DoubleConverter extends PrimitiveConverter[Double] { + override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column) + } + + /** + * Creates a converter function that will convert Scala objects to 
the specified Catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + def createToCatalystConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + // Although the `else` branch here is capable of handling inbound conversion of primitives, + // we add some special-case handling for those types here. The motivation for this relates to + // Java method invocation costs: if we have rows that consist entirely of primitive columns, + // then returning the same conversion function for all of the columns means that the call site + // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in + // a measurable performance impact. Note that this optimization will be unnecessary if we + // use code generation to construct Scala Row -> Catalyst Row converters. + def convert(maybeScalaValue: Any): Any = { + if (maybeScalaValue.isInstanceOf[Option[Any]]) { + maybeScalaValue.asInstanceOf[Option[Any]].orNull + } else { + maybeScalaValue + } + } + + convert + } else { + getConverterForType(dataType).toCatalyst + } + } + + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + def createToScalaConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + identity + } else { + getConverterForType(dataType).toScala + } + } + + /** + * Converts Scala objects to Catalyst rows / types. + * + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. + */ + def convertToCatalyst(a: Any): Any = a match { + case s: String => StringConverter.toCatalyst(s) + case d: Date => DateConverter.toCatalyst(d) + case ld: LocalDate => LocalDateConverter.toCatalyst(ld) + case t: Timestamp => TimestampConverter.toCatalyst(t) + case i: Instant => InstantConverter.toCatalyst(i) + case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d) + case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d) + case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) + case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) + case map: Map[_, _] => + ArrayBasedMapData( + map, + (key: Any) => convertToCatalyst(key), + (value: Any) => convertToCatalyst(value)) + case other => other + } + + /** + * Converts Catalyst types used internally in rows to standard Scala types + * This method is slow, and for batch conversion you should be using converter + * produced by createToScalaConverter. 
+ */
+  def convertToScala(catalystValue: Any, dataType: DataType): Any = {
+    createToScalaConverter(dataType)(catalystValue)
+  }
+}

From fc85d475ed579573b67d81393b4bd2ed01224ca0 Mon Sep 17 00:00:00 2001
From: can wang
Date: Sun, 5 Sep 2021 13:41:09 +0800
Subject: [PATCH 16/18] [experimental] support UDFs returning data classes

---
 core/3.0/pom_2.12.xml                         |  5 ++-
 .../sql/catalyst/CatalystTypeConverters.scala | 12 +++++++
 .../kotlinx/spark/api/UDFRegisterTest.kt      | 32 ++++++++++---------
 3 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/core/3.0/pom_2.12.xml b/core/3.0/pom_2.12.xml
index b06ca56e..c3c2e972 100644
--- a/core/3.0/pom_2.12.xml
+++ b/core/3.0/pom_2.12.xml
@@ -18,7 +18,10 @@
             <artifactId>scala-library</artifactId>
             <version>${scala.version}</version>
         </dependency>
-
+        <dependency>
+            <groupId>org.jetbrains.kotlin</groupId>
+            <artifactId>kotlin-reflect</artifactId>
+        </dependency>
 
         <dependency>
             <groupId>org.apache.spark</groupId>
diff --git a/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 7def589a..f618cf04 100644
--- a/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -253,6 +253,18 @@ object CatalystTypeConverters {
       }
       new GenericInternalRow(ar)
 
+      case ktDataClass: Any if JvmClassMappingKt.getKotlinClass(ktDataClass.getClass).isData =>
+        import scala.collection.JavaConverters._
+        val klass: KClass[Any] = JvmClassMappingKt.getKotlinClass(ktDataClass.getClass).asInstanceOf[KClass[Any]]
+        val iter: Iterator[KProperty1[Any, _]] = KClasses.getDeclaredMemberProperties(klass).iterator().asScala
+        val ar = new Array[Any](structType.size)
+        var idx = 0
+        while (idx < structType.size) {
+          ar(idx) = converters(idx).toCatalyst(iter.next().get(ktDataClass))
+          idx += 1
+        }
+        new GenericInternalRow(ar)
+
       case other => throw new IllegalArgumentException(
         s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " +
           s"cannot be converted to ${structType.catalogString}")
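Note on the hunk above: KClasses.getDeclaredMemberProperties does not yield properties in constructor order; in practice kotlin-reflect returns them sorted alphabetically by name. That assumption would explain why the test diff below reorders NormalClass to (age, name) and constructs it as NormalClass(b, a). A small Kotlin sketch of the observed (not contractually guaranteed) ordering:

    import kotlin.reflect.full.declaredMemberProperties

    data class Person(val name: String, val age: Int)

    fun main() {
        // Prints [age, name]: alphabetical order, not the declaration order (name, age).
        println(Person::class.declaredMemberProperties.map { it.name })
    }
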
udf.register("toNormalClass") { a, b -> + NormalClass(b, a) + } + spark.sql("select toNormalClass(first, second) from test2").show() + } + } + } } }) data class NormalClass( - val name: String, - val age: Int + val age: Int, + val name: String ) From e98f6d1ec9d2a333fed310f9e8fe4ceb5caad2d4 Mon Sep 17 00:00:00 2001 From: can wang Date: Sun, 5 Sep 2021 13:51:43 +0800 Subject: [PATCH 17/18] fix code inspection issue --- .../kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 2 ++ .../kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt | 2 ++ 2 files changed, 4 insertions(+) diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt index 9afa5a65..044ec399 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -26,8 +26,10 @@ import org.junit.jupiter.api.assertThrows import scala.collection.JavaConversions import scala.collection.mutable.WrappedArray +@Suppress("unused") private fun scala.collection.Iterable.asIterable(): Iterable = JavaConversions.asJavaIterable(this) +@Suppress("unused") class UDFRegisterTest : ShouldSpec({ context("org.jetbrains.kotlinx.spark.api.UDFRegister") { context("the function checkForValidType") { diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt index 84bbd55e..1926be42 100644 --- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt +++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -26,8 +26,10 @@ import org.junit.jupiter.api.assertThrows import scala.collection.JavaConverters import scala.collection.mutable.WrappedArray +@Suppress("unused") private fun scala.collection.Iterable.asIterable(): Iterable = JavaConverters.asJavaIterable(this) +@Suppress("unused") class UDFRegisterTest : ShouldSpec({ context("org.jetbrains.kotlinx.spark.api.UDFRegister") { context("the function checkForValidType") { From 56cd2807840676e1531831dc3c6395f5d8bd3d2e Mon Sep 17 00:00:00 2001 From: Pasha Finkelshteyn Date: Sun, 5 Sep 2021 23:14:39 +0300 Subject: [PATCH 18/18] Adds suppre unused --- .../main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt index 6dc19d58..a6130f55 100644 --- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt +++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt @@ -17,7 +17,7 @@ * limitations under the License. * =LICENSEEND= */ -@file:Suppress("DuplicatedCode") +@file:Suppress("DuplicatedCode", "unused") package org.jetbrains.kotlinx.spark.api