diff --git a/core/2.4/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala b/core/2.4/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala index 7f0e6c87..89c2eab9 100644 --- a/core/2.4/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala +++ b/core/2.4/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala @@ -19,8 +19,9 @@ */ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.{KotlinReflection, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.types.{DataType, Metadata, StructField, StructType} +import org.apache.spark.sql.types.{DataType, Metadata, StructField, StructType, UserDefinedType} trait DataTypeWithClass { @@ -167,6 +168,8 @@ case class KSimpleTypeWrapper(dt: DataType, cls: Class[_], nullable: Boolean) ex override def defaultSize: Int = dt.defaultSize + override def toString: String = s"KSTW(${dt.toString})" + override private[spark] def asNullable = dt.asNullable } diff --git a/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala new file mode 100644 index 00000000..d4ae2a83 --- /dev/null +++ b/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -0,0 +1,336 @@ +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal +import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataTypeWithClass, KSimpleTypeWrapper, Row} +import 
org.apache.spark.unsafe.types.UTF8String + +import scala.annotation.tailrec +import scala.collection.Map +import scala.reflect.ClassTag + +/** + * A factory for constructing encoders that convert external row to/from the Spark SQL + * internal binary representation. + * + * The following is a mapping between Spark SQL types and its allowed external types: + * {{{ + * BooleanType -> java.lang.Boolean + * ByteType -> java.lang.Byte + * ShortType -> java.lang.Short + * IntegerType -> java.lang.Integer + * FloatType -> java.lang.Float + * DoubleType -> java.lang.Double + * StringType -> String + * DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal + * + * DateType -> java.sql.Date + * TimestampType -> java.sql.Timestamp + * + * BinaryType -> byte array + * ArrayType -> scala.collection.Seq or Array + * MapType -> scala.collection.Map + * StructType -> org.apache.spark.sql.Row + * }}} + */ +object RowEncoder { + def apply(schema: StructType): ExpressionEncoder[Row] = { + val cls = classOf[Row] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val updatedSchema = schema.copy( + fields = schema.fields.map(f => f.dataType match { + case kstw: DataTypeWithClass => f.copy(dataType = kstw.dt, nullable = kstw.nullable) + case _ => f + }) + ) + val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), updatedSchema) + val deserializer = deserializerFor(updatedSchema) + new ExpressionEncoder[Row]( + updatedSchema, + flat = false, + serializer.asInstanceOf[CreateNamedStruct].flatten, + deserializer, + ClassTag(cls)) + } + + private def serializerFor( + inputObject: Expression, + inputType: DataType): Expression = inputType match { + case dt if ScalaReflection.isNativeType(dt) => inputObject + + case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) + + case udt: UserDefinedType[_] => + val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) + val udtClass: Class[_] = if 
(annotation != null) { + annotation.udt() + } else { + UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { + throw new SparkException(s"${udt.userClass.getName} is not annotated with " + + "SQLUserDefinedType nor registered with UDTRegistration.}") + } + } + val obj = NewInstance( + udtClass, + Nil, + dataType = ObjectType(udtClass), false) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) + + case TimestampType => + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil, + returnNullable = false) + + case DateType => + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "fromJavaDate", + inputObject :: Nil, + returnNullable = false) + + case d: DecimalType => + CheckOverflow(StaticInvoke( + Decimal.getClass, + d, + "fromDecimal", + inputObject :: Nil, + returnNullable = false), d) + + case StringType => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + + case t@ArrayType(et, containsNull) => + et match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + StaticInvoke( + classOf[ArrayData], + t, + "toArrayData", + inputObject :: Nil, + returnNullable = false) + + case _ => MapObjects( + element => { + val value = serializerFor(ValidateExternalType(element, et), et) + if (!containsNull) { + AssertNotNull(value, Seq.empty) + } else { + value + } + }, + inputObject, + ObjectType(classOf[Object])) + } + + case t@MapType(kt, vt, valueNullable) => + val keys = + Invoke( + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) + val convertedKeys = serializerFor(keys, ArrayType(kt, false)) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), + 
"toSeq", + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) + val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) + + val nonNullOutput = NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = t, + propagateNull = false) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } + + case StructType(fields) => + val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => + val dataType = field.dataType + val fieldValue = serializerFor( + ValidateExternalType( + GetExternalRowField(inputObject, index, field.name), + dataType), + dataType) + val convertedField = if (field.nullable) { + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), + Literal.create(null, dataType), + fieldValue + ) + } else { + fieldValue + } + Literal(field.name) :: convertedField :: Nil + }) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput + } + } + + /** + * Returns the `DataType` that can be used when generating code that converts input data + * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned + * by this function can be more permissive since multiple external types may map to a single + * internal type. For example, for an input with DecimalType in external row, its external types + * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or + * `org.apache.spark.sql.types.Decimal`. + */ + def externalDataTypeForInput(dt: DataType): DataType = dt match { + case _ => dt match { + case dtwc: DataTypeWithClass => dtwc.dt + // In order to support both Decimal and java/scala BigDecimal in external row, we make this + // as java.lang.Object. 
+ case _: DecimalType => ObjectType(classOf[Object]) + // In order to support both Array and Seq in external row, we make this as java.lang.Object. + case _: ArrayType => ObjectType(classOf[Object]) + case _ => externalDataTypeFor(dt) + } + } + + @tailrec + def externalDataTypeFor(dt: DataType): DataType = dt match { + case kstw: DataTypeWithClass => externalDataTypeFor(kstw.dt) + case _ if ScalaReflection.isNativeType(dt) => dt + case TimestampType => ObjectType(classOf[java.sql.Timestamp]) + case DateType => ObjectType(classOf[java.sql.Date]) + case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) + case StringType => ObjectType(classOf[java.lang.String]) + case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) + case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) + case _: StructType => ObjectType(classOf[Row]) + case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) + case udt: UserDefinedType[_] => ObjectType(udt.userClass) + } + + private def deserializerFor(schema: StructType): Expression = { + val fields = schema.zipWithIndex.map { case (f, i) => + val dt = f.dataType match { + case p: PythonUserDefinedType => p.sqlType + case other => other + } + deserializerFor(GetColumnByOrdinal(i, dt)) + } + + CreateExternalRow(fields, schema) + } + + private def deserializerFor(input: Expression): Expression = { + deserializerFor(input, input.dataType) + } + + private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { + case dt if ScalaReflection.isNativeType(dt) => input + case kstw: DataTypeWithClass => deserializerFor(input, kstw.dt) + + case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) + + case udt: UserDefinedType[_] => + val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) + val udtClass: Class[_] = if (annotation != null) { + annotation.udt() + } else { + UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse { + throw new 
SparkException(s"${udt.userClass.getName} is not annotated with " + + "SQLUserDefinedType nor registered with UDTRegistration.}") + } + } + val obj = NewInstance( + udtClass, + Nil, + dataType = ObjectType(udtClass)) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) + + case TimestampType => + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + input :: Nil, + returnNullable = false) + + case DateType => + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + input :: Nil, + returnNullable = false) + + case _: DecimalType => + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) + + case StringType => + Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) + + case ArrayType(et, nullable) => + val arrayData = + Invoke( + MapObjects(deserializerFor(_), input, et), + "array", + ObjectType(classOf[Array[_]]), returnNullable = false) + StaticInvoke( + scala.collection.mutable.WrappedArray.getClass, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil, + returnNullable = false) + + case MapType(kt, vt, valueNullable) => + val keyArrayType = ArrayType(kt, false) + val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) + + val valueArrayType = ArrayType(vt, valueNullable) + val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType)) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil, + returnNullable = false) + + case schema@StructType(fields) => + val convertedFields = fields.zipWithIndex.map { case (f, i) => + If( + Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, externalDataTypeFor(f.dataType match { + case kstw: DataTypeWithClass => kstw.dt + case o => o + })), + deserializerFor(GetStructField(input, i))) + } + If(IsNull(input), + 
Literal.create(null, externalDataTypeFor(input.dataType)), + CreateExternalRow(convertedFields, schema)) + } +} diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt index 08ffee12..7f488e2c 100644 --- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt +++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt @@ -20,6 +20,7 @@ package org.jetbrains.kotlinx.spark.api import org.apache.spark.sql.SparkSession.Builder +import org.apache.spark.sql.UDFRegistration import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR /** @@ -78,4 +79,6 @@ inline class KSparkSession(val spark: SparkSession) { inline fun List.toDS() = toDS(spark) inline fun Array.toDS() = spark.dsOf(*this) inline fun dsOf(vararg arg: T) = spark.dsOf(*arg) + val udf: UDFRegistration get() = spark.udf() + } diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt new file mode 100644 index 00000000..e07a0deb --- /dev/null +++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt @@ -0,0 +1,797 @@ +@file:Suppress("DuplicatedCode") + +package org.jetbrains.kotlinx.spark.api + +import org.apache.spark.sql.Column +import org.apache.spark.sql.UDFRegistration +import org.apache.spark.sql.api.java.* +import org.apache.spark.sql.functions +import scala.collection.mutable.WrappedArray +import kotlin.reflect.KClass +import kotlin.reflect.full.isSubclassOf +import kotlin.reflect.typeOf + +/** + * Checks if [this] is of a valid type for an UDF, otherwise it throws a [TypeOfUDFParameterNotSupportedException] + */ +@PublishedApi +internal fun KClass<*>.checkForValidType(parameterName: String) { + if (this == String::class || isSubclassOf(WrappedArray::class)) return // 
Most of the time we need strings or WrappedArrays + if (isSubclassOf(Iterable::class) || java.isArray + || isSubclassOf(Map::class) || isSubclassOf(Array::class) + || isSubclassOf(ByteArray::class) || isSubclassOf(CharArray::class) + || isSubclassOf(ShortArray::class) || isSubclassOf(IntArray::class) + || isSubclassOf(LongArray::class) || isSubclassOf(FloatArray::class) + || isSubclassOf(DoubleArray::class) || isSubclassOf(BooleanArray::class) + ) { + throw TypeOfUDFParameterNotSupportedException(this, parameterName) + } +} + +/** + * An exception thrown when the UDF is generated with illegal types for the parameters + */ +class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName: String) : IllegalArgumentException( + "Parameter $parameterName is subclass of ${kClass.qualifiedName}. If you need to process an array use ${WrappedArray::class.qualifiedName}." +) + +/** + * A wrapper for an UDF with 0 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper0(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(): Column { + return functions.callUDF(udfName) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 { + register(name, UDF0(func), schema(typeOf())) + return UDFWrapper0(name) +} + +/** + * A wrapper for an UDF with 1 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper1(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column): Column { + return functions.callUDF(udfName, param0) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 { + T0::class.checkForValidType("T0") + register(name, UDF1(func), schema(typeOf())) + return UDFWrapper1(name) +} + +/** + * A wrapper for an UDF with 2 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper2(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column): Column { + return functions.callUDF(udfName, param0, param1) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1) -> R): UDFWrapper2 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + register(name, UDF2(func), schema(typeOf())) + return UDFWrapper2(name) +} + +/** + * A wrapper for an UDF with 3 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper3(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column): Column { + return functions.callUDF(udfName, param0, param1, param2) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2) -> R): UDFWrapper3 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + register(name, UDF3(func), schema(typeOf())) + return UDFWrapper3(name) +} + +/** + * A wrapper for an UDF with 4 arguments. 
+ * @property udfName the name of the UDF + */ +class UDFWrapper4(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3) -> R): UDFWrapper4 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + register(name, UDF4(func), schema(typeOf())) + return UDFWrapper4(name) +} + +/** + * A wrapper for an UDF with 5 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper5(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4) -> R): UDFWrapper5 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + register(name, UDF5(func), schema(typeOf())) + return UDFWrapper5(name) +} + +/** + * A wrapper for an UDF with 6 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper6(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5) -> R): UDFWrapper6 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + register(name, UDF6(func), schema(typeOf())) + return UDFWrapper6(name) +} + +/** + * A wrapper for an UDF with 7 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper7(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6) -> R): UDFWrapper7 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + register(name, UDF7(func), schema(typeOf())) + return UDFWrapper7(name) +} + +/** + * A wrapper for an UDF with 8 arguments. 
+ * @property udfName the name of the UDF + */ +class UDFWrapper8(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. + */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7) -> R): UDFWrapper8 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + register(name, UDF8(func), schema(typeOf())) + return UDFWrapper8(name) +} + +/** + * A wrapper for an UDF with 9 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper9(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8) -> R): UDFWrapper9 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + register(name, UDF9(func), schema(typeOf())) + return UDFWrapper9(name) +} + +/** + * A wrapper for an UDF with 10 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper10(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) -> R): UDFWrapper10 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + register(name, UDF10(func), schema(typeOf())) + return UDFWrapper10(name) +} + +/** + * A wrapper for an UDF with 11 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper11(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) -> R): UDFWrapper11 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + register(name, UDF11(func), schema(typeOf())) + return UDFWrapper11(name) +} + +/** + * A wrapper for an UDF with 12 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper12(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) -> R): UDFWrapper12 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + register(name, UDF12(func), schema(typeOf())) + return UDFWrapper12(name) +} + +/** + * A wrapper for an UDF with 13 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper13(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) -> R): UDFWrapper13 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + register(name, UDF13(func), schema(typeOf())) + return UDFWrapper13(name) +} + +/** + * A wrapper for an UDF with 14 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper14(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) -> R): UDFWrapper14 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + register(name, UDF14(func), schema(typeOf())) + return UDFWrapper14(name) +} + +/** + * A wrapper for an UDF with 15 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper15(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14) + } +} + +/** + * Registers the [func] with its [name] in [this] + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) -> R): UDFWrapper15 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + register(name, UDF15(func), schema(typeOf())) + return UDFWrapper15(name) +} + +/** + * A wrapper for an UDF with 16 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper16(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column, param15: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14, param15) + } +} + +/** + * Checks every parameter type with checkForValidType, then registers the [func] (wrapped in UDF16) with its [name] in [this] and returns a [UDFWrapper16] for that name. + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15) -> R): UDFWrapper16 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + register(name, UDF16(func), schema(typeOf())) + return UDFWrapper16(name) +} + +/** + * A wrapper for an UDF with 17 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper17(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column, param15: Column, param16: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14, param15, param16) + } +} + +/** + * Checks every parameter type with checkForValidType, then registers the [func] (wrapped in UDF17) with its [name] in [this] and returns a [UDFWrapper17] for that name. + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> R): UDFWrapper17 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + register(name, UDF17(func), schema(typeOf())) + return UDFWrapper17(name) +} + +/** + * A wrapper for an UDF with 18 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper18(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column, param15: Column, param16: Column, param17: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14, param15, param16, param17) + } +} + +/** + * Checks every parameter type with checkForValidType, then registers the [func] (wrapped in UDF18) with its [name] in [this] and returns a [UDFWrapper18] for that name. + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17) -> R): UDFWrapper18 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + T17::class.checkForValidType("T17") + register(name, UDF18(func), schema(typeOf())) + return UDFWrapper18(name) +} + +/** + * A wrapper for an UDF with 19 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper19(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column, param15: Column, param16: Column, param17: Column, param18: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14, param15, param16, param17, param18) + } +} + +/** + * Checks every parameter type with checkForValidType, then registers the [func] (wrapped in UDF19) with its [name] in [this] and returns a [UDFWrapper19] for that name. + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18) -> R): UDFWrapper19 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + T17::class.checkForValidType("T17") + T18::class.checkForValidType("T18") + register(name, UDF19(func), schema(typeOf())) + return UDFWrapper19(name) +} + +/** + * A wrapper for an UDF with 20 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper20(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column, param15: Column, param16: Column, param17: Column, param18: Column, param19: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14, param15, param16, param17, param18, param19) + } +} + +/** + * Checks every parameter type with checkForValidType, then registers the [func] (wrapped in UDF20) with its [name] in [this] and returns a [UDFWrapper20] for that name. + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19) -> R): UDFWrapper20 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + T17::class.checkForValidType("T17") + T18::class.checkForValidType("T18") + T19::class.checkForValidType("T19") + register(name, UDF20(func), schema(typeOf())) + return UDFWrapper20(name) +} + +/** + * A wrapper for an UDF with 21 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper21(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column, param15: Column, param16: Column, param17: Column, param18: Column, param19: Column, param20: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14, param15, param16, param17, param18, param19, param20) + } +} + +/** + * Checks every parameter type with checkForValidType, then registers the [func] (wrapped in UDF21) with its [name] in [this] and returns a [UDFWrapper21] for that name. + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20) -> R): UDFWrapper21 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + T17::class.checkForValidType("T17") + T18::class.checkForValidType("T18") + T19::class.checkForValidType("T19") + T20::class.checkForValidType("T20") + register(name, UDF21(func), schema(typeOf())) + return UDFWrapper21(name) +} + +/** + * A wrapper for an UDF with 22 arguments. + * @property udfName the name of the UDF + */ +class UDFWrapper22(private val udfName: String) { + /** + * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns. 
+ */ + operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column, param5: Column, param6: Column, param7: Column, param8: Column, param9: Column, param10: Column, param11: Column, param12: Column, param13: Column, param14: Column, param15: Column, param16: Column, param17: Column, param18: Column, param19: Column, param20: Column, param21: Column): Column { + return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8, param9, param10, param11, param12, param13, param14, param15, param16, param17, param18, param19, param20, param21) + } +} + +/** + * Checks every parameter type with checkForValidType, then registers the [func] (wrapped in UDF22) with its [name] in [this] and returns a [UDFWrapper22] for that name. + */ +@OptIn(ExperimentalStdlibApi::class) +inline fun UDFRegistration.register(name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21) -> R): UDFWrapper22 { + T0::class.checkForValidType("T0") + T1::class.checkForValidType("T1") + T2::class.checkForValidType("T2") + T3::class.checkForValidType("T3") + T4::class.checkForValidType("T4") + T5::class.checkForValidType("T5") + T6::class.checkForValidType("T6") + T7::class.checkForValidType("T7") + T8::class.checkForValidType("T8") + T9::class.checkForValidType("T9") + T10::class.checkForValidType("T10") + T11::class.checkForValidType("T11") + T12::class.checkForValidType("T12") + T13::class.checkForValidType("T13") + T14::class.checkForValidType("T14") + T15::class.checkForValidType("T15") + T16::class.checkForValidType("T16") + T17::class.checkForValidType("T17") + T18::class.checkForValidType("T18") + T19::class.checkForValidType("T19") + T20::class.checkForValidType("T20") + T21::class.checkForValidType("T21") + register(name, UDF22(func), schema(typeOf())) + return UDFWrapper22(name) +} + + diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt 
b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt new file mode 100644 index 00000000..6e32278f --- /dev/null +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt @@ -0,0 +1,122 @@ +package org.jetbrains.kotlinx.spark.api + +import ch.tutteli.atrium.api.fluent.en_GB.* +import ch.tutteli.atrium.verbs.expect +import ch.tutteli.atrium.domain.builders.migration.asExpect +import io.kotest.core.spec.style.ShouldSpec +import org.apache.spark.sql.RowFactory +import org.apache.spark.sql.types.DataTypes +import org.junit.jupiter.api.assertThrows +import scala.collection.JavaConversions +import scala.collection.mutable.WrappedArray +/** Converts a Scala Iterable to a Kotlin [Iterable] via [JavaConversions.asJavaIterable]; the specs below use it to call joinToString on a WrappedArray. */ +private fun scala.collection.Iterable.asIterable(): Iterable = JavaConversions.asJavaIterable(this) +/** Specs for UDF registration: checkForValidType on array/collection classes, register() with a kotlin.Array versus a WrappedArray parameter, and calling a registered UDF through its wrapper and through SQL. */ +class UDFRegisterTest : ShouldSpec({ + context("org.jetbrains.kotlinx.spark.api.UDFRegister") { + context("the function checkForValidType") { + val invalidTypes = listOf( + Array::class, + Iterable::class, + List::class, + MutableList::class, + ByteArray::class, + CharArray::class, + ShortArray::class, + IntArray::class, + LongArray::class, + FloatArray::class, + DoubleArray::class, + BooleanArray::class, + Map::class, + MutableMap::class, + Set::class, + MutableSet::class, + arrayOf("")::class, + listOf("")::class, + setOf("")::class, + mapOf("" to "")::class, + mutableListOf("")::class, + mutableSetOf("")::class, + mutableMapOf("" to "")::class, + ) + invalidTypes.forEachIndexed { index, invalidType -> + should("$index: throw an ${TypeOfUDFParameterNotSupportedException::class.simpleName} when encountering ${invalidType.qualifiedName}") { + assertThrows { + invalidType.checkForValidType("test") + } + } + } + } + + context("the register-function") { + withSpark { + + should("fail when using a simple kotlin.Array") { + assertThrows { + udf.register("shouldFail") { array: Array -> + array.joinToString(" ") + } + } + } + + should("succeed when using a WrappedArray") 
{ + udf.register("shouldSucceed") { array: WrappedArray -> + array.asIterable().joinToString(" ") + } + } + } + } + + context("calling the UDF-Wrapper") { + withSpark(logLevel = SparkLogLevel.DEBUG) { + should("succeed when using the right number of arguments") { + val schema = DataTypes.createStructType( + listOf( + DataTypes.createStructField("textArray", DataTypes.createArrayType(DataTypes.StringType), false), + DataTypes.createStructField("id", DataTypes.StringType, false) + ) + ) + + val rows = listOf( + RowFactory.create(arrayOf("a", "b", "c"), "1"), + RowFactory.create(arrayOf("d", "e", "f"), "2"), + RowFactory.create(arrayOf("g", "h", "i"), "3"), + ) + + val testData = spark.createDataFrame(rows, schema) + + val stringArrayMerger = udf.register, String>("stringArrayMerger") { + it.asIterable().joinToString(" ") + } + + val newData = testData.withColumn("text", stringArrayMerger(testData.col("textArray"))) + + newData.select("text").collectAsList().zip(newData.select("textArray").collectAsList()).forEach { (text, textArray) -> + assert(text.getString(0) == textArray.getList(0).joinToString(" ")) + } + } +// should("also work with datasets") { +// val ds = listOf("a" to 1, "b" to 2).toDS() +// val stringIntDiff = udf.register>("stringIntDiff") { a, b -> +// c(a[0].toInt() - b) +// } +// val lst = ds.withColumn("new", stringIntDiff(ds.col("first"), ds.col("second"))) +// .select("new") +// .collectAsList() +// +//// val result = spark.sql("select stringIntDiff(first, second) from test1").`as`().collectAsList() +//// expect(result).asExpect().contains.inOrder.only.values(96, 96) +// } + should("also work with datasets") { + listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1") + udf.register("stringIntDiff") { a, b -> + a[0].toInt() - b + } + spark.sql("select stringIntDiff(first, second) from test1").show() + /* NOTE(review): the query result is only printed via show(); this spec asserts nothing about the output. */ + } + } + } + } +}) \ No newline at end of file