diff --git a/core/3.0/pom_2.12.xml b/core/3.0/pom_2.12.xml
index b06ca56e..c3c2e972 100644
--- a/core/3.0/pom_2.12.xml
+++ b/core/3.0/pom_2.12.xml
@@ -18,7 +18,10 @@
scala-library
${scala.version}
-
+        <dependency>
+            <groupId>org.jetbrains.kotlin</groupId>
+            <artifactId>kotlin-reflect</artifactId>
+        </dependency>
org.apache.spark
diff --git a/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
new file mode 100644
index 00000000..f618cf04
--- /dev/null
+++ b/core/3.0/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -0,0 +1,492 @@
+package org.apache.spark.sql.catalyst
+
+import kotlin.jvm.JvmClassMappingKt
+import kotlin.reflect.{KClass, KFunction, KProperty1}
+import kotlin.reflect.full.KClasses
+
+import java.lang.{Iterable => JavaIterable}
+import java.math.{BigDecimal => JavaBigDecimal}
+import java.math.{BigInteger => JavaBigInteger}
+import java.sql.{Date, Timestamp}
+import java.time.{Instant, LocalDate}
+import java.util.{Map => JavaMap}
+import javax.annotation.Nullable
+import scala.language.existentials
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Functions to convert Scala types to Catalyst types and vice versa.
+ */
+object CatalystTypeConverters {
+ // The Predef.Map is scala.collection.immutable.Map.
+ // Since the map values can be mutable, we explicitly import scala.collection.Map here.
+
+ import scala.collection.Map
+
+ /**
+  * Reports whether `dataType` is one of the Catalyst types backed by a JVM primitive:
+  * boolean, byte, short, int, long, float or double.
+  */
+ private[sql] def isPrimitive(dataType: DataType): Boolean = dataType match {
+   case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+     true
+   case _ =>
+     false
+ }
+
+ /**
+  * Picks the converter that handles values of `dataType`.
+  *
+  * User-defined types are matched first, followed by the container types and then the scalar
+  * types. For dates and timestamps the choice depends on `SQLConf.get.datetimeJava8ApiEnabled`.
+  * Any remaining type is served by [[IdentityConverter]].
+  */
+ private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
+   // Only consulted by the date/timestamp guards below.
+   def java8TimeApi: Boolean = SQLConf.get.datetimeJava8ApiEnabled
+
+   val selected = dataType match {
+     case udt: UserDefinedType[_] =>
+       UDTConverter(udt)
+     case array: ArrayType =>
+       ArrayConverter(array.elementType)
+     case map: MapType =>
+       MapConverter(map.keyType, map.valueType)
+     case struct: StructType =>
+       StructConverter(struct)
+     case StringType =>
+       StringConverter
+     case DateType if java8TimeApi =>
+       LocalDateConverter
+     case DateType =>
+       DateConverter
+     case TimestampType if java8TimeApi =>
+       InstantConverter
+     case TimestampType =>
+       TimestampConverter
+     case decimal: DecimalType =>
+       new DecimalConverter(decimal)
+     case BooleanType =>
+       BooleanConverter
+     case ByteType =>
+       ByteConverter
+     case ShortType =>
+       ShortConverter
+     case IntegerType =>
+       IntConverter
+     case LongType =>
+       LongConverter
+     case FloatType =>
+       FloatConverter
+     case DoubleType =>
+       DoubleConverter
+     case other: DataType =>
+       IdentityConverter(other)
+   }
+   // Callers work with untyped values, so the converter's type parameters are erased here.
+   selected.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
+ }
+
+ /**
+  * Two-way bridge between the external (Scala) representation of a value and its internal
+  * Catalyst representation.
+  *
+  * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst.
+  * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala.
+  * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type.
+  */
+ private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType]
+   extends Serializable {
+
+   /**
+    * Converts a possibly-null, possibly `Option`-wrapped Scala value to Catalyst.
+    *
+    * `null` and `None` become `null`; `Some(x)` is unwrapped and `x` is handed to
+    * [[toCatalystImpl]]; any other value is handed to [[toCatalystImpl]] directly.
+    */
+   final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = maybeScalaValue match {
+     case null =>
+       null.asInstanceOf[CatalystType]
+     case None =>
+       null.asInstanceOf[CatalystType]
+     case Some(wrapped) =>
+       toCatalystImpl(wrapped.asInstanceOf[ScalaInputType])
+     case plain =>
+       toCatalystImpl(plain.asInstanceOf[ScalaInputType])
+   }
+
+   /**
+    * Reads column `column` of `row` as a Scala value; a null column yields `null` without
+    * calling [[toScalaImpl]].
+    */
+   final def toScala(row: InternalRow, column: Int): ScalaOutputType = {
+     if (!row.isNullAt(column)) {
+       toScalaImpl(row, column)
+     } else {
+       null.asInstanceOf[ScalaOutputType]
+     }
+   }
+
+   /**
+    * Turns a standalone Catalyst value (which may be null) into its Scala equivalent.
+    */
+   def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType
+
+   /**
+    * Type-specific Scala-to-Catalyst conversion.
+    *
+    * @param scalaValue the Scala value, guaranteed not to be null.
+    * @return the Catalyst value.
+    */
+   protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType
+
+   /**
+    * Type-specific read of column `column` from `row`.
+    * This method will only be called on non-null columns.
+    */
+   protected def toScalaImpl(row: InternalRow, column: Int): ScalaOutputType
+ }
+
+ /** Converter that hands values through unchanged in both directions. */
+ private case class IdentityConverter(dataType: DataType)
+   extends CatalystTypeConverter[Any, Any, Any] {
+
+   override def toScala(catalystValue: Any): Any = catalystValue
+
+   override def toCatalystImpl(scalaValue: Any): Any = scalaValue
+
+   // The column is read generically, using this converter's own data type.
+   override def toScalaImpl(row: InternalRow, column: Int): Any = {
+     row.get(column, dataType)
+   }
+ }
+
+ /** Converter that delegates to the `serialize`/`deserialize` pair of a user-defined type. */
+ private case class UDTConverter[A >: Null](
+     udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] {
+
+   // Null handling happens in toCatalyst before this method is reached.
+   override def toCatalystImpl(scalaValue: A): Any = {
+     udt.serialize(scalaValue)
+   }
+
+   override def toScala(catalystValue: Any): A = catalystValue match {
+     case null => null
+     case datum => udt.deserialize(datum)
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): A = {
+     val stored = row.get(column, udt.sqlType)
+     toScala(stored)
+   }
+ }
+
+ /**
+  * Converter for arrays, sequences, and Java iterables.
+  *
+  * Elements are converted one by one with the converter selected for `elementType`.
+  */
+ private case class ArrayConverter(
+     elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {
+
+   private[this] val elementConverter = getConverterForType(elementType)
+
+   override def toCatalystImpl(scalaValue: Any): ArrayData = {
+     val convertedElements: Array[Any] = scalaValue match {
+       case a: Array[_] =>
+         a.map(elementConverter.toCatalyst)
+       case s: Seq[_] =>
+         s.map(elementConverter.toCatalyst).toArray
+       case i: JavaIterable[_] =>
+         // Java iterables are drained through their iterator into a growable buffer.
+         val buffer = scala.collection.mutable.ArrayBuffer.empty[Any]
+         val source = i.iterator
+         while (source.hasNext) {
+           buffer += elementConverter.toCatalyst(source.next())
+         }
+         buffer.toArray
+       case other => throw new IllegalArgumentException(
+         s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+           + s"cannot be converted to an array of ${elementType.catalogString}")
+     }
+     new GenericArrayData(convertedElements)
+   }
+
+   override def toScala(catalystValue: ArrayData): Seq[Any] = catalystValue match {
+     case null =>
+       null
+     case data if isPrimitive(elementType) =>
+       // Primitive elements need no per-element conversion.
+       data.toArray[Any](elementType)
+     case data =>
+       val out = new Array[Any](data.numElements())
+       data.foreach(elementType, (position, element) => {
+         out(position) = elementConverter.toScala(element)
+       })
+       out
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = {
+     val stored = row.getArray(column)
+     toScala(stored)
+   }
+ }
+
+ /**
+  * Converter for Scala and Java maps.
+  *
+  * Keys and values are converted independently with the converters selected for `keyType`
+  * and `valueType`.
+  */
+ private case class MapConverter(
+     keyType: DataType,
+     valueType: DataType)
+   extends CatalystTypeConverter[Any, Map[Any, Any], MapData] {
+
+   private[this] val keyConverter = getConverterForType(keyType)
+   private[this] val valueConverter = getConverterForType(valueType)
+
+   override def toCatalystImpl(scalaValue: Any): MapData = {
+     val convertKey = (key: Any) => keyConverter.toCatalyst(key)
+     val convertValue = (value: Any) => valueConverter.toCatalyst(value)
+
+     scalaValue match {
+       case scalaMap: Map[_, _] =>
+         ArrayBasedMapData(scalaMap, convertKey, convertValue)
+       case javaMap: JavaMap[_, _] =>
+         ArrayBasedMapData(javaMap, convertKey, convertValue)
+       case other => throw new IllegalArgumentException(
+         s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+           + "cannot be converted to a map type with "
+           + s"key type (${keyType.catalogString}) and value type (${valueType.catalogString})")
+     }
+   }
+
+   override def toScala(catalystValue: MapData): Map[Any, Any] = catalystValue match {
+     case null =>
+       null
+     case data =>
+       val rawKeys = data.keyArray().toArray[Any](keyType)
+       val rawValues = data.valueArray().toArray[Any](valueType)
+       // Primitive keys/values are already in their external form.
+       val scalaKeys =
+         if (!isPrimitive(keyType)) rawKeys.map(keyConverter.toScala) else rawKeys
+       val scalaValues =
+         if (!isPrimitive(valueType)) rawValues.map(valueConverter.toScala) else rawValues
+
+       scalaKeys.zip(scalaValues).toMap
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = {
+     val stored = row.getMap(column)
+     toScala(stored)
+   }
+ }
+
+ /**
+  * Converter for struct values.
+  *
+  * Accepts a [[Row]], a Scala `Product`, or an instance of a Kotlin data class, and produces an
+  * `InternalRow` whose cells are converted with the per-field converters of `structType`.
+  */
+ private case class StructConverter(
+     structType: StructType) extends CatalystTypeConverter[Any, Row, InternalRow] {
+
+   private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) }
+
+   /** Fills a row of `width` cells, converting the value supplied for each position. */
+   private[this] def buildRow(width: Int)(fieldAt: Int => Any): InternalRow = {
+     val cells = new Array[Any](width)
+     var pos = 0
+     while (pos < width) {
+       cells(pos) = converters(pos).toCatalyst(fieldAt(pos))
+       pos += 1
+     }
+     new GenericInternalRow(cells)
+   }
+
+   override def toCatalystImpl(scalaValue: Any): InternalRow = scalaValue match {
+     case row: Row =>
+       // The output width follows the incoming row, not the schema.
+       buildRow(row.size)(position => row(position))
+
+     case p: Product =>
+       val fields = p.productIterator
+       buildRow(structType.size)(_ => fields.next())
+
+     case ktDataClass: Any if JvmClassMappingKt.getKotlinClass(ktDataClass.getClass).isData =>
+       import scala.collection.JavaConverters._
+       val klass: KClass[Any] =
+         JvmClassMappingKt.getKotlinClass(ktDataClass.getClass).asInstanceOf[KClass[Any]]
+       // NOTE(review): values are paired with schema fields purely by the iteration order of
+       // getDeclaredMemberProperties - confirm that this order matches structType's field order.
+       val properties: Iterator[KProperty1[Any, _]] =
+         KClasses.getDeclaredMemberProperties(klass).iterator().asScala
+       buildRow(structType.size)(_ => properties.next().get(ktDataClass))
+
+     case other => throw new IllegalArgumentException(
+       s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+         + s"cannot be converted to ${structType.catalogString}")
+   }
+
+   override def toScala(row: InternalRow): Row = row match {
+     case null =>
+       null
+     case present =>
+       val cells = new Array[Any](present.numFields)
+       var pos = 0
+       while (pos < present.numFields) {
+         cells(pos) = converters(pos).toScala(present, pos)
+         pos += 1
+       }
+       new GenericRowWithSchema(cells, structType)
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): Row = {
+     val stored = row.getStruct(column, structType.size)
+     toScala(stored)
+   }
+ }
+
+ /** Converter between external strings or characters and Catalyst `UTF8String` values. */
+ private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {
+   override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
+     case utf8: UTF8String =>
+       utf8
+     case str: String =>
+       UTF8String.fromString(str)
+     case chr: Char =>
+       UTF8String.fromString(chr.toString)
+     case other => throw new IllegalArgumentException(
+       s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+         + s"cannot be converted to the string type")
+   }
+
+   override def toScala(catalystValue: UTF8String): String = catalystValue match {
+     case null => null
+     case utf8 => utf8.toString
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): String = {
+     val stored = row.getUTF8String(column)
+     stored.toString
+   }
+ }
+
+ /** Converter between `java.sql.Date` and the `Int` that Catalyst stores for a date. */
+ private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
+   override def toCatalystImpl(scalaValue: Date): Int = {
+     DateTimeUtils.fromJavaDate(scalaValue)
+   }
+
+   override def toScala(catalystValue: Any): Date = catalystValue match {
+     case null => null
+     case days => DateTimeUtils.toJavaDate(days.asInstanceOf[Int])
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): Date = {
+     val days = row.getInt(column)
+     DateTimeUtils.toJavaDate(days)
+   }
+ }
+
+ /** Converter between `java.time.LocalDate` and the `Int` that Catalyst stores for a date. */
+ private object LocalDateConverter extends CatalystTypeConverter[LocalDate, LocalDate, Any] {
+   override def toCatalystImpl(scalaValue: LocalDate): Int =
+     DateTimeUtils.localDateToDays(scalaValue)
+
+   override def toScala(catalystValue: Any): LocalDate = catalystValue match {
+     case null => null
+     case days => DateTimeUtils.daysToLocalDate(days.asInstanceOf[Int])
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): LocalDate = {
+     val days = row.getInt(column)
+     DateTimeUtils.daysToLocalDate(days)
+   }
+ }
+
+ /** Converter between `java.sql.Timestamp` and the `Long` that Catalyst stores for a timestamp. */
+ private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
+   override def toCatalystImpl(scalaValue: Timestamp): Long = {
+     DateTimeUtils.fromJavaTimestamp(scalaValue)
+   }
+
+   override def toScala(catalystValue: Any): Timestamp = catalystValue match {
+     case null => null
+     case stored => DateTimeUtils.toJavaTimestamp(stored.asInstanceOf[Long])
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): Timestamp = {
+     val stored = row.getLong(column)
+     DateTimeUtils.toJavaTimestamp(stored)
+   }
+ }
+
+ /** Converter between `java.time.Instant` and the `Long` that Catalyst stores for a timestamp. */
+ private object InstantConverter extends CatalystTypeConverter[Instant, Instant, Any] {
+   override def toCatalystImpl(scalaValue: Instant): Long = {
+     DateTimeUtils.instantToMicros(scalaValue)
+   }
+
+   override def toScala(catalystValue: Any): Instant = catalystValue match {
+     case null => null
+     case micros => DateTimeUtils.microsToInstant(micros.asInstanceOf[Long])
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): Instant = {
+     val micros = row.getLong(column)
+     DateTimeUtils.microsToInstant(micros)
+   }
+ }
+
+ /**
+  * Converter for decimal values of a fixed `dataType`.
+  *
+  * Accepts Scala `BigDecimal`, Java `BigDecimal`, Java `BigInteger` and Catalyst `Decimal`
+  * inputs, and adjusts them to the precision and scale of `dataType`.
+  */
+ private class DecimalConverter(dataType: DecimalType)
+   extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+
+   // Captured once at construction time from the ANSI setting of the active SQLConf.
+   private val nullOnOverflow = !SQLConf.get.ansiEnabled
+
+   /** Wraps a supported numeric input in a Catalyst `Decimal`, before any precision change. */
+   private def asDecimal(value: Any): Decimal = value match {
+     case scalaDecimal: BigDecimal => Decimal(scalaDecimal)
+     case javaDecimal: JavaBigDecimal => Decimal(javaDecimal)
+     case javaInteger: JavaBigInteger => Decimal(javaInteger)
+     case catalystDecimal: Decimal => catalystDecimal
+     case other => throw new IllegalArgumentException(
+       s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+         + s"cannot be converted to ${dataType.catalogString}")
+   }
+
+   override def toCatalystImpl(scalaValue: Any): Decimal = {
+     asDecimal(scalaValue).toPrecision(
+       dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow)
+   }
+
+   override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue match {
+     case null => null
+     case decimal => decimal.toJavaBigDecimal
+   }
+
+   override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = {
+     val stored = row.getDecimal(column, dataType.precision, dataType.scale)
+     stored.toJavaBigDecimal
+   }
+ }
+
+ /**
+  * Base for converters of primitive types, whose external and Catalyst representations are
+  * the same value. Subclasses only supply the typed column accessor.
+  */
+ private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
+   final override def toCatalystImpl(scalaValue: T): Any = scalaValue
+
+   final override def toScala(catalystValue: Any): Any = catalystValue
+ }
+
+ /** Reads boolean columns with `getBoolean`. */
+ private object BooleanConverter extends PrimitiveConverter[Boolean] {
+   override def toScalaImpl(row: InternalRow, column: Int): Boolean = {
+     row.getBoolean(column)
+   }
+ }
+
+ /** Reads byte columns with `getByte`. */
+ private object ByteConverter extends PrimitiveConverter[Byte] {
+   override def toScalaImpl(row: InternalRow, column: Int): Byte = {
+     row.getByte(column)
+   }
+ }
+
+ /** Reads short columns with `getShort`. */
+ private object ShortConverter extends PrimitiveConverter[Short] {
+   override def toScalaImpl(row: InternalRow, column: Int): Short = {
+     row.getShort(column)
+   }
+ }
+
+ /** Reads int columns with `getInt`. */
+ private object IntConverter extends PrimitiveConverter[Int] {
+   override def toScalaImpl(row: InternalRow, column: Int): Int = {
+     row.getInt(column)
+   }
+ }
+
+ /** Reads long columns with `getLong`. */
+ private object LongConverter extends PrimitiveConverter[Long] {
+   override def toScalaImpl(row: InternalRow, column: Int): Long = {
+     row.getLong(column)
+   }
+ }
+
+ /** Reads float columns with `getFloat`. */
+ private object FloatConverter extends PrimitiveConverter[Float] {
+   override def toScalaImpl(row: InternalRow, column: Int): Float = {
+     row.getFloat(column)
+   }
+ }
+
+ /** Reads double columns with `getDouble`. */
+ private object DoubleConverter extends PrimitiveConverter[Double] {
+   override def toScalaImpl(row: InternalRow, column: Int): Double = {
+     row.getDouble(column)
+   }
+ }
+
+ /**
+  * Builds a function that converts Scala objects to the Catalyst representation of `dataType`.
+  *
+  * Intended for converting many values that share one schema: obtain the function once and
+  * apply it to every value.
+  */
+ def createToCatalystConverter(dataType: DataType): Any => Any = {
+   if (!isPrimitive(dataType)) {
+     getConverterForType(dataType).toCatalyst
+   } else {
+     // The generic path above could also handle primitives. They are special-cased so that every
+     // primitive column shares the same conversion function: for rows made up entirely of
+     // primitive columns the call site then stays monomorphic instead of polymorphic, which
+     // showed a measurable effect in microbenchmarks. This becomes unnecessary if Scala Row ->
+     // Catalyst Row converters are ever produced by code generation.
+     val unwrapOption: Any => Any = {
+       case wrapped: Option[_] => wrapped.asInstanceOf[Option[Any]].orNull
+       case plain => plain
+     }
+
+     unwrapOption
+   }
+ }
+
+ /**
+  * Builds a function that converts Catalyst values of `dataType` to their Scala equivalents.
+  *
+  * Intended for converting many values that share one schema: obtain the function once and
+  * apply it to every value. Primitive types need no conversion and get the identity function.
+  */
+ def createToScalaConverter(dataType: DataType): Any => Any = {
+   if (!isPrimitive(dataType)) {
+     getConverterForType(dataType).toScala
+   } else {
+     identity
+   }
+ }
+
+ /**
+  * Converts Scala objects to Catalyst rows / types.
+  *
+  * Note: This should be called before do evaluation on Row
+  * (It does not support UDT)
+  * This is used to create an RDD or test results with correct types for Catalyst.
+  *
+  * Strings, dates, timestamps and decimals go through their dedicated converters; sequences,
+  * arrays, rows and maps are converted recursively; every other value is returned unchanged.
+  *
+  * @param a the Scala value to convert
+  * @return the Catalyst representation of `a`, or `a` itself when no case applies
+  */
+ def convertToCatalyst(a: Any): Any = a match {
+   case s: String => StringConverter.toCatalyst(s)
+   case d: Date => DateConverter.toCatalyst(d)
+   case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
+   case t: Timestamp => TimestampConverter.toCatalyst(t)
+   case i: Instant => InstantConverter.toCatalyst(i)
+   // BigDecimal counts only significant digits as precision, so a value such as 0.01 reports
+   // precision 1 with scale 2. A Catalyst decimal type needs precision >= scale, so the
+   // precision is widened to at least the scale before the type is built.
+   case d: BigDecimal =>
+     new DecimalConverter(DecimalType(math.max(d.precision, d.scale), d.scale)).toCatalyst(d)
+   case d: JavaBigDecimal =>
+     new DecimalConverter(DecimalType(math.max(d.precision, d.scale), d.scale)).toCatalyst(d)
+   case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
+   case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
+   case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
+   case map: Map[_, _] =>
+     ArrayBasedMapData(
+       map,
+       (key: Any) => convertToCatalyst(key),
+       (value: Any) => convertToCatalyst(value))
+   case other => other
+ }
+
+ /**
+  * One-off conversion of a Catalyst value of type `dataType` to its standard Scala form.
+  *
+  * A converter is built on every call, which makes this method slow; for batch conversion
+  * obtain a converter from createToScalaConverter once and reuse it.
+  */
+ def convertToScala(catalystValue: Any, dataType: DataType): Any = {
+   val toScala = createToScalaConverter(dataType)
+   toScala(catalystValue)
+ }
+}
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
index d5d6aa2c..3ef0b177 100644
--- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
@@ -20,6 +20,7 @@
package org.jetbrains.kotlinx.spark.api
import org.apache.spark.sql.SparkSession.Builder
+import org.apache.spark.sql.UDFRegistration
import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
/**
@@ -78,4 +79,5 @@ inline class KSparkSession(val spark: SparkSession) {
inline fun List.toDS() = toDS(spark)
inline fun Array.toDS() = spark.dsOf(*this)
inline fun dsOf(vararg arg: T) = spark.dsOf(*arg)
+ /** The [UDFRegistration] of the wrapped [spark] session, as returned by `spark.udf()`. */
+ val udf: UDFRegistration
+     get() = spark.udf()
}
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
new file mode 100644
index 00000000..a6130f55
--- /dev/null
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
@@ -0,0 +1,1375 @@
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+@file:Suppress("DuplicatedCode", "unused")
+
+package org.jetbrains.kotlinx.spark.api
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.DataTypeWithClass
+import org.apache.spark.sql.UDFRegistration
+import org.apache.spark.sql.api.java.*
+import org.apache.spark.sql.functions
+import org.apache.spark.sql.types.DataType
+import scala.collection.mutable.WrappedArray
+import kotlin.reflect.KClass
+import kotlin.reflect.full.isSubclassOf
+import kotlin.reflect.typeOf
+
+/**
+ * Returns a [DataType] without the [DataTypeWithClass] wrapper.
+ *
+ * A [DataTypeWithClass] is replaced by the type rebuilt from the JSON of its wrapped `dt()`;
+ * any other [DataType] is returned as is.
+ */
+fun DataType.unWrapper(): DataType =
+    if (this is DataTypeWithClass) DataType.fromJson(dt().json()) else this
+
+/**
+ * Checks if [this] is of a valid type for a UDF, otherwise it throws a [TypeOfUDFParameterNotSupportedException]
+ *
+ * [String] and subclasses of [WrappedArray] are accepted immediately. Subclasses of [Iterable]
+ * or [Map], JVM array classes, and subclasses of the Kotlin array types listed in the condition
+ * are rejected. Every other class passes the check.
+ *
+ * @param parameterName the name of the checked parameter; passed on to the exception
+ * @throws TypeOfUDFParameterNotSupportedException if [this] is one of the rejected types
+ */
+@PublishedApi
+internal fun KClass<*>.checkForValidType(parameterName: String) {
+ if (this == String::class || isSubclassOf(WrappedArray::class)) return // Most of the time we need strings or WrappedArrays
+ if (isSubclassOf(Iterable::class) || java.isArray
+ || isSubclassOf(Map::class) || isSubclassOf(Array::class)
+ || isSubclassOf(ByteArray::class) || isSubclassOf(CharArray::class)
+ || isSubclassOf(ShortArray::class) || isSubclassOf(IntArray::class)
+ || isSubclassOf(LongArray::class) || isSubclassOf(FloatArray::class)
+ || isSubclassOf(DoubleArray::class) || isSubclassOf(BooleanArray::class)
+ ) {
+ throw TypeOfUDFParameterNotSupportedException(this, parameterName)
+ }
+}
+
+/**
+ * An exception thrown when the UDF is generated with illegal types for the parameters
+ *
+ * The message names the offending parameter and the qualified name of its class, and points to
+ * [WrappedArray] as the type to use for array processing.
+ *
+ * @param kClass the rejected parameter class
+ * @param parameterName the name of the parameter whose type was rejected
+ */
+class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName: String) : IllegalArgumentException(
+ "Parameter $parameterName is subclass of ${kClass.qualifiedName}. If you need to process an array use ${WrappedArray::class.qualifiedName}."
+)
+
+/**
+ * A wrapper for a UDF with 0 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper0(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(): Column {
+ return functions.callUDF(udfName)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * [func] is wrapped in a [UDF0] and registered together with the data type produced by
+ * `schema(typeOf())` after [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper0] for [name]
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 {
+ register(name, UDF0(func), schema(typeOf()).unWrapper())
+ return UDFWrapper0(name)
+}
+
+/**
+ * A wrapper for a UDF with 1 argument.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper1(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(param0: Column): Column {
+ return functions.callUDF(udfName, param0)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * The parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF1] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper1] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects the parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 {
+ T0::class.checkForValidType("T0")
+ register(name, UDF1(func), schema(typeOf()).unWrapper())
+ return UDFWrapper1(name)
+}
+
+/**
+ * A wrapper for a UDF with 2 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper2(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(param0: Column, param1: Column): Column {
+ return functions.callUDF(udfName, param0, param1)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF2] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper2] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1) -> R
+): UDFWrapper2 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ register(name, UDF2(func), schema(typeOf()).unWrapper())
+ return UDFWrapper2(name)
+}
+
+/**
+ * A wrapper for a UDF with 3 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper3(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(param0: Column, param1: Column, param2: Column): Column {
+ return functions.callUDF(udfName, param0, param1, param2)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF3] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper3] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2) -> R
+): UDFWrapper3 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ register(name, UDF3(func), schema(typeOf()).unWrapper())
+ return UDFWrapper3(name)
+}
+
+/**
+ * A wrapper for a UDF with 4 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper4(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column): Column {
+ return functions.callUDF(udfName, param0, param1, param2, param3)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF4] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper4] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3) -> R
+): UDFWrapper4 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ register(name, UDF4(func), schema(typeOf()).unWrapper())
+ return UDFWrapper4(name)
+}
+
+/**
+ * A wrapper for a UDF with 5 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper5(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column): Column {
+ return functions.callUDF(udfName, param0, param1, param2, param3, param4)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF5] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper5] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4) -> R
+): UDFWrapper5 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ register(name, UDF5(func), schema(typeOf()).unWrapper())
+ return UDFWrapper5(name)
+}
+
+/**
+ * A wrapper for a UDF with 6 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper6(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(
+ param0: Column,
+ param1: Column,
+ param2: Column,
+ param3: Column,
+ param4: Column,
+ param5: Column
+ ): Column {
+ return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF6] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper6] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5) -> R
+): UDFWrapper6 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ register(name, UDF6(func), schema(typeOf()).unWrapper())
+ return UDFWrapper6(name)
+}
+
+/**
+ * A wrapper for a UDF with 7 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper7(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(
+ param0: Column,
+ param1: Column,
+ param2: Column,
+ param3: Column,
+ param4: Column,
+ param5: Column,
+ param6: Column
+ ): Column {
+ return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF7] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper7] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6) -> R
+): UDFWrapper7 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ register(name, UDF7(func), schema(typeOf()).unWrapper())
+ return UDFWrapper7(name)
+}
+
+/**
+ * A wrapper for a UDF with 8 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper8(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(
+ param0: Column,
+ param1: Column,
+ param2: Column,
+ param3: Column,
+ param4: Column,
+ param5: Column,
+ param6: Column,
+ param7: Column
+ ): Column {
+ return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF8] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper8] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7) -> R
+): UDFWrapper8 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ register(name, UDF8(func), schema(typeOf()).unWrapper())
+ return UDFWrapper8(name)
+}
+
+/**
+ * A wrapper for a UDF with 9 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper9(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(
+ param0: Column,
+ param1: Column,
+ param2: Column,
+ param3: Column,
+ param4: Column,
+ param5: Column,
+ param6: Column,
+ param7: Column,
+ param8: Column
+ ): Column {
+ return functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8)
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this]
+ *
+ * Each parameter class is first validated with [checkForValidType]. [func] is then wrapped in a
+ * [UDF9] and registered together with the data type produced by `schema(typeOf())` after
+ * [unWrapper] has been applied to it.
+ *
+ * @param name the name under which the UDF is registered
+ * @param func the function to expose as a UDF
+ * @return a [UDFWrapper9] for [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter class
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8) -> R
+): UDFWrapper9 {
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ register(name, UDF9(func), schema(typeOf()).unWrapper())
+ return UDFWrapper9(name)
+}
+
+/**
+ * A wrapper for a UDF with 10 arguments.
+ *
+ * It stores only the UDF's name; invoking it builds a call to that UDF by name.
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper10(private val udfName: String) {
+ /**
+ * Calls the [functions.callUDF] for the UDF with the [udfName] and the given columns.
+ *
+ * @return the [Column] returned by [functions.callUDF]
+ */
+ operator fun invoke(
+ param0: Column,
+ param1: Column,
+ param2: Column,
+ param3: Column,
+ param4: Column,
+ param5: Column,
+ param6: Column,
+ param7: Column,
+ param8: Column,
+ param9: Column
+ ): Column {
+ return functions.callUDF(
+ udfName,
+ param0,
+ param1,
+ param2,
+ param3,
+ param4,
+ param5,
+ param6,
+ param7,
+ param8,
+ param9
+ )
+ }
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF10] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper10] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) -> R
+): UDFWrapper10 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ // Register func as a UDF10 under the given name, then return a wrapper created with that name.
+ register(name, UDF10(func), schema(typeOf()).unWrapper())
+ return UDFWrapper10(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 11 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper11(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF11] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper11] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) -> R
+): UDFWrapper11 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ // Register func as a UDF11 under the given name, then return a wrapper created with that name.
+ register(name, UDF11(func), schema(typeOf()).unWrapper())
+ return UDFWrapper11(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 12 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper12(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF12] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper12] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) -> R
+): UDFWrapper12 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ // Register func as a UDF12 under the given name, then return a wrapper created with that name.
+ register(name, UDF12(func), schema(typeOf()).unWrapper())
+ return UDFWrapper12(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 13 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper13(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF13] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper13] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) -> R
+): UDFWrapper13 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ // Register func as a UDF13 under the given name, then return a wrapper created with that name.
+ register(name, UDF13(func), schema(typeOf()).unWrapper())
+ return UDFWrapper13(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 14 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper14(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF14] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper14] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) -> R
+): UDFWrapper14 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ // Register func as a UDF14 under the given name, then return a wrapper created with that name.
+ register(name, UDF14(func), schema(typeOf()).unWrapper())
+ return UDFWrapper14(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 15 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper15(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF15] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper15] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) -> R
+): UDFWrapper15 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ // Register func as a UDF15 under the given name, then return a wrapper created with that name.
+ register(name, UDF15(func), schema(typeOf()).unWrapper())
+ return UDFWrapper15(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 16 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper16(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14, param15
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF16] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper16] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15) -> R
+): UDFWrapper16 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ T15::class.checkForValidType("T15")
+ // Register func as a UDF16 under the given name, then return a wrapper created with that name.
+ register(name, UDF16(func), schema(typeOf()).unWrapper())
+ return UDFWrapper16(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 17 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper17(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14, param15,
+        param16
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF17] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper17] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> R
+): UDFWrapper17 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ T15::class.checkForValidType("T15")
+ T16::class.checkForValidType("T16")
+ // Register func as a UDF17 under the given name, then return a wrapper created with that name.
+ register(name, UDF17(func), schema(typeOf()).unWrapper())
+ return UDFWrapper17(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 18 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper18(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14, param15,
+        param16, param17
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF18] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper18] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17) -> R
+): UDFWrapper18 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ T15::class.checkForValidType("T15")
+ T16::class.checkForValidType("T16")
+ T17::class.checkForValidType("T17")
+ // Register func as a UDF18 under the given name, then return a wrapper created with that name.
+ register(name, UDF18(func), schema(typeOf()).unWrapper())
+ return UDFWrapper18(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 19 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper19(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14, param15,
+        param16, param17, param18
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF19] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper19] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18) -> R
+): UDFWrapper19 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ T15::class.checkForValidType("T15")
+ T16::class.checkForValidType("T16")
+ T17::class.checkForValidType("T17")
+ T18::class.checkForValidType("T18")
+ // Register func as a UDF19 under the given name, then return a wrapper created with that name.
+ register(name, UDF19(func), schema(typeOf()).unWrapper())
+ return UDFWrapper19(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 20 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper20(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14, param15,
+        param16, param17, param18, param19
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF20] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper20] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19) -> R
+): UDFWrapper20 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ T15::class.checkForValidType("T15")
+ T16::class.checkForValidType("T16")
+ T17::class.checkForValidType("T17")
+ T18::class.checkForValidType("T18")
+ T19::class.checkForValidType("T19")
+ // Register func as a UDF20 under the given name, then return a wrapper created with that name.
+ register(name, UDF20(func), schema(typeOf()).unWrapper())
+ return UDFWrapper20(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 21 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper21(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column,
+        param20: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14, param15,
+        param16, param17, param18, param19, param20
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF21] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper21] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20) -> R
+): UDFWrapper21 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ T15::class.checkForValidType("T15")
+ T16::class.checkForValidType("T16")
+ T17::class.checkForValidType("T17")
+ T18::class.checkForValidType("T18")
+ T19::class.checkForValidType("T19")
+ T20::class.checkForValidType("T20")
+ // Register func as a UDF21 under the given name, then return a wrapper created with that name.
+ register(name, UDF21(func), schema(typeOf()).unWrapper())
+ return UDFWrapper21(name)
+}
+
+/**
+ * Handle for calling the UDF named [udfName] with 22 column arguments.
+ *
+ * @property udfName the name passed to [functions.callUDF] on every invocation
+ */
+class UDFWrapper22(private val udfName: String) {
+    /**
+     * Returns the [Column] produced by [functions.callUDF] for [udfName] and the given columns,
+     * passed in parameter order.
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column,
+        param20: Column,
+        param21: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0, param1, param2, param3, param4, param5, param6, param7,
+        param8, param9, param10, param11, param12, param13, param14, param15,
+        param16, param17, param18, param19, param20, param21
+    )
+}
+
+/**
+ * Registers the [func] with its [name] in [this] and returns a wrapper for calling it by that name.
+ *
+ * Each parameter type is first passed to [checkForValidType]; [func] is then registered as a
+ * [UDF22] together with the data type derived via [schema] and [unWrapper].
+ *
+ * @param name the name the UDF is registered under
+ * @param func the function to register as a UDF
+ * @return a [UDFWrapper22] created with [name]
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects a parameter type
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun UDFRegistration.register(
+ name: String,
+ noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21) -> R
+): UDFWrapper22 {
+ // Check every parameter type up front, in declaration order, before anything is registered.
+ T0::class.checkForValidType("T0")
+ T1::class.checkForValidType("T1")
+ T2::class.checkForValidType("T2")
+ T3::class.checkForValidType("T3")
+ T4::class.checkForValidType("T4")
+ T5::class.checkForValidType("T5")
+ T6::class.checkForValidType("T6")
+ T7::class.checkForValidType("T7")
+ T8::class.checkForValidType("T8")
+ T9::class.checkForValidType("T9")
+ T10::class.checkForValidType("T10")
+ T11::class.checkForValidType("T11")
+ T12::class.checkForValidType("T12")
+ T13::class.checkForValidType("T13")
+ T14::class.checkForValidType("T14")
+ T15::class.checkForValidType("T15")
+ T16::class.checkForValidType("T16")
+ T17::class.checkForValidType("T17")
+ T18::class.checkForValidType("T18")
+ T19::class.checkForValidType("T19")
+ T20::class.checkForValidType("T20")
+ T21::class.checkForValidType("T21")
+ // Register func as a UDF22 under the given name, then return a wrapper created with that name.
+ register(name, UDF22(func), schema(typeOf()).unWrapper())
+ return UDFWrapper22(name)
+}
diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
new file mode 100644
index 00000000..044ec399
--- /dev/null
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -0,0 +1,164 @@
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+package org.jetbrains.kotlinx.spark.api
+
+import io.kotest.core.spec.style.ShouldSpec
+import io.kotest.matchers.shouldBe
+import org.apache.spark.sql.Dataset
+import org.junit.jupiter.api.assertThrows
+import scala.collection.JavaConversions
+import scala.collection.mutable.WrappedArray
+
+/** Test helper: converts a Scala iterable with [JavaConversions.asJavaIterable] so Kotlin collection functions can be used on it. */
+@Suppress("unused")
+private fun scala.collection.Iterable.asIterable(): Iterable = JavaConversions.asJavaIterable(this)
+
+/**
+ * Specs for the UDF registration helpers: the `checkForValidType` parameter check, the
+ * `register` extension functions, and calling the returned UDF wrappers on dataset columns.
+ */
+@Suppress("unused")
+class UDFRegisterTest : ShouldSpec({
+ context("org.jetbrains.kotlinx.spark.api.UDFRegister") {
+ context("the function checkForValidType") {
+ // Array, collection and map classes that checkForValidType is expected to reject,
+ // both as declared classes and as the runtime classes of concrete instances.
+ val invalidTypes = listOf(
+ Array::class,
+ Iterable::class,
+ List::class,
+ MutableList::class,
+ ByteArray::class,
+ CharArray::class,
+ ShortArray::class,
+ IntArray::class,
+ LongArray::class,
+ FloatArray::class,
+ DoubleArray::class,
+ BooleanArray::class,
+ Map::class,
+ MutableMap::class,
+ Set::class,
+ MutableSet::class,
+ arrayOf("")::class,
+ listOf("")::class,
+ setOf("")::class,
+ mapOf("" to "")::class,
+ mutableListOf("")::class,
+ mutableSetOf("")::class,
+ mutableMapOf("" to "")::class,
+ )
+ // One generated spec per class; the index keeps the spec names unique.
+ invalidTypes.forEachIndexed { index, invalidType ->
+ should("$index: throw an ${TypeOfUDFParameterNotSupportedException::class.simpleName} when encountering ${invalidType.qualifiedName}") {
+ assertThrows {
+ invalidType.checkForValidType("test")
+ }
+ }
+ }
+ }
+
+ context("the register-function") {
+ withSpark {
+
+ should("fail when using a simple kotlin.Array") {
+ assertThrows {
+ udf.register("shouldFail") { array: Array ->
+ array.joinToString(" ")
+ }
+ }
+ }
+
+ should("succeed when using a WrappedArray") {
+ udf.register("shouldSucceed") { array: WrappedArray ->
+ array.asIterable().joinToString(" ")
+ }
+ }
+
+ should("succeed when return a List") {
+ udf.register>("StringToIntList") { a ->
+ a.asIterable().map { it.code }
+ }
+
+ // 'a'.code is 97 and 'b'.code is 98.
+ val result = spark.sql("select StringToIntList('ab')").`as`>().collectAsList()
+ result shouldBe listOf(listOf(97, 98))
+ }
+
+ should("succeed when using three type udf and as result to udf return type") {
+ listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
+ udf.register("stringIntDiff") { a, b ->
+ a[0].code - b
+ }
+ // 'a'.code - 1 and 'b'.code - 2 are both 96.
+ val result = spark.sql("select stringIntDiff(first, second) from test1").`as`().collectAsList()
+ result shouldBe listOf(96, 96)
+ }
+ }
+ }
+
+ context("calling the UDF-Wrapper") {
+ withSpark(logLevel = SparkLogLevel.DEBUG) {
+ should("succeed call UDF-Wrapper in withColumn") {
+
+ val stringArrayMerger = udf.register, String>("stringArrayMerger") {
+ it.asIterable().joinToString(" ")
+ }
+
+ val testData = dsOf(listOf("a", "b"))
+ val newData = testData.withColumn("text", stringArrayMerger(testData.col("value")))
+
+ // Compare each row's UDF output with the same join computed locally from the "value" column.
+ // NOTE(review): Kotlin's assert only fires when JVM assertions are enabled (-ea); confirm the test runner sets it.
+ newData.select("text").collectAsList().zip(newData.select("value").collectAsList())
+ .forEach { (text, textArray) ->
+ assert(text.getString(0) == textArray.getList(0).joinToString(" "))
+ }
+ }
+
+
+ should("succeed in dataset") {
+ val dataset: Dataset = listOf(NormalClass("a", 10), NormalClass("b", 20)).toDS()
+
+ val udfWrapper = udf.register("nameConcatAge") { name, age ->
+ "$name-$age"
+ }
+
+ // Apply the wrapper to the dataset's own columns and keep only the derived column.
+ val collectAsList = dataset.withColumn(
+ "nameAndAge",
+ udfWrapper(dataset.col("name"), dataset.col("age"))
+ )
+ .select("nameAndAge")
+ .collectAsList()
+
+ collectAsList[0][0] shouldBe "a-10"
+ collectAsList[1][0] shouldBe "b-20"
+ }
+ }
+ }
+
+ // Disabled: a UDF returning a data class currently fails with the same exception as reported at:
+ // https://forums.databricks.com/questions/13361/how-do-i-create-a-udf-in-java-which-return-a-compl.html
+// context("udf return data class") {
+// withSpark(logLevel = SparkLogLevel.DEBUG) {
+// should("return NormalClass") {
+// listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test2")
+// udf.register("toNormalClass") { a, b ->
+// NormalClass(a,b)
+// }
+// spark.sql("select toNormalClass(first, second) from test2").show()
+// }
+// }
+// }
+
+ }
+})
+
+/**
+ * Plain data class with a [name] and an [age], used as dataset element type in [UDFRegisterTest].
+ */
+data class NormalClass(val name: String, val age: Int)
diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
index d5d6aa2c..3ef0b177 100644
--- a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
@@ -20,6 +20,7 @@
package org.jetbrains.kotlinx.spark.api
import org.apache.spark.sql.SparkSession.Builder
+import org.apache.spark.sql.UDFRegistration
import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
/**
@@ -78,4 +79,5 @@ inline class KSparkSession(val spark: SparkSession) {
inline fun List.toDS() = toDS(spark)
inline fun Array.toDS() = spark.dsOf(*this)
inline fun dsOf(vararg arg: T) = spark.dsOf(*arg)
+ val udf: UDFRegistration get() = spark.udf()
}
diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
new file mode 100644
index 00000000..6dc19d58
--- /dev/null
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegister.kt
@@ -0,0 +1,1375 @@
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+@file:Suppress("DuplicatedCode")
+
+package org.jetbrains.kotlinx.spark.api
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.DataTypeWithClass
+import org.apache.spark.sql.UDFRegistration
+import org.apache.spark.sql.api.java.*
+import org.apache.spark.sql.functions
+import org.apache.spark.sql.types.DataType
+import scala.collection.mutable.WrappedArray
+import kotlin.reflect.KClass
+import kotlin.reflect.full.isSubclassOf
+import kotlin.reflect.typeOf
+
+/**
+ * Returns [this] unchanged unless it is a [DataTypeWithClass]; in that case the wrapped `dt()`
+ * is serialized to JSON and parsed back with [DataType.fromJson], and that parsed type is returned.
+ */
+fun DataType.unWrapper(): DataType =
+    if (this is DataTypeWithClass) DataType.fromJson(dt().json()) else this
+
+/**
+ * Checks whether [this] may be used as the type of a UDF parameter. [String] and [WrappedArray] subclasses
+ * are accepted up front; JVM arrays and subclasses of the classes listed below are rejected; anything else passes.
+ *
+ * @throws TypeOfUDFParameterNotSupportedException naming [parameterName] when the type is rejected
+ */
+@PublishedApi
+internal fun KClass<*>.checkForValidType(parameterName: String) {
+    // Most of the time we need strings or WrappedArrays
+    if (this == String::class || isSubclassOf(WrappedArray::class)) return
+    val rejectedBases = listOf(
+        Iterable::class, Map::class, Array::class, ByteArray::class, CharArray::class, ShortArray::class,
+        IntArray::class, LongArray::class, FloatArray::class, DoubleArray::class, BooleanArray::class
+    )
+    if (java.isArray || rejectedBases.any { isSubclassOf(it) }) throw TypeOfUDFParameterNotSupportedException(this, parameterName)
+}
+
+/**
+ * Signals that a UDF was declared with a parameter of an unsupported type.
+ * The message names the parameter and the offending class and points to [WrappedArray] for arrays.
+ */
+class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName: String) :
+    IllegalArgumentException("Parameter $parameterName is subclass of ${kClass.qualifiedName}. If you need to process an array use ${WrappedArray::class.qualifiedName}.")
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes no arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper0(private val udfName: String) {
+    /**
+     * Passes [udfName], without any columns, to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(): Column =
+        functions.callUDF(udfName)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper0] for it.
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified R> UDFRegistration.register(name: String, noinline func: () -> R): UDFWrapper0 {
+    register(name, UDF0(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper0(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 1 argument.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper1(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given column to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(param0: Column): Column =
+        functions.callUDF(udfName, param0)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper1] for it.
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified R> UDFRegistration.register(name: String, noinline func: (T0) -> R): UDFWrapper1 {
+    T0::class.checkForValidType("T0") // throws TypeOfUDFParameterNotSupportedException for a rejected parameter type
+    register(name, UDF1(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper1(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 2 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper2(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(param0: Column, param1: Column): Column =
+        functions.callUDF(udfName, param0, param1)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper2] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1) -> R
+): UDFWrapper2 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    register(name, UDF2(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper2(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 3 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper3(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(param0: Column, param1: Column, param2: Column): Column =
+        functions.callUDF(udfName, param0, param1, param2)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper3] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2) -> R
+): UDFWrapper3 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    register(name, UDF3(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper3(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 4 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper4(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column): Column =
+        functions.callUDF(udfName, param0, param1, param2, param3)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper4] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3) -> R
+): UDFWrapper4 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    register(name, UDF4(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper4(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 5 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper5(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(param0: Column, param1: Column, param2: Column, param3: Column, param4: Column): Column =
+        functions.callUDF(udfName, param0, param1, param2, param3, param4)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper5] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4) -> R
+): UDFWrapper5 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    register(name, UDF5(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper5(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 6 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper6(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column
+    ): Column =
+        functions.callUDF(udfName, param0, param1, param2, param3, param4, param5)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper6] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5) -> R
+): UDFWrapper6 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    register(name, UDF6(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper6(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 7 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper7(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column
+    ): Column =
+        functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper7] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6) -> R
+): UDFWrapper7 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    register(name, UDF7(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper7(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 8 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper8(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column
+    ): Column =
+        functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper8] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7) -> R
+): UDFWrapper8 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    register(name, UDF8(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper8(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 9 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper9(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column
+    ): Column =
+        functions.callUDF(udfName, param0, param1, param2, param3, param4, param5, param6, param7, param8)
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper9] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8) -> R
+): UDFWrapper9 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    register(name, UDF9(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper9(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 10 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper10(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper10] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) -> R
+): UDFWrapper10 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    register(name, UDF10(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper10(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 11 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper11(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper11] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) -> R
+): UDFWrapper11 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    register(name, UDF11(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper11(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 12 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper12(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper12] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) -> R
+): UDFWrapper12 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    register(name, UDF12(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper12(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 13 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper13(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper13] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) -> R
+): UDFWrapper13 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    register(name, UDF13(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper13(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 14 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper14(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper14] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) -> R
+): UDFWrapper14 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    register(name, UDF14(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper14(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 15 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper15(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper15] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) -> R
+): UDFWrapper15 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    register(name, UDF15(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper15(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 16 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper16(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14,
+        param15
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper16] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15) -> R
+): UDFWrapper16 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    register(name, UDF16(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper16(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 17 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper17(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14,
+        param15,
+        param16
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper17] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> R
+): UDFWrapper17 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    register(name, UDF17(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper17(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 18 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper18(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14,
+        param15,
+        param16,
+        param17
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper18] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17) -> R
+): UDFWrapper18 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    register(name, UDF18(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper18(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 19 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper19(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14,
+        param15,
+        param16,
+        param17,
+        param18
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper19] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18) -> R
+): UDFWrapper19 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    register(name, UDF19(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper19(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 20 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper20(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14,
+        param15,
+        param16,
+        param17,
+        param18,
+        param19
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper20] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19) -> R
+): UDFWrapper20 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    register(name, UDF20(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper20(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 21 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper21(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column,
+        param20: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14,
+        param15,
+        param16,
+        param17,
+        param18,
+        param19,
+        param20
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper21] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified T20, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20) -> R
+): UDFWrapper21 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    T20::class.checkForValidType("T20")
+    register(name, UDF21(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper21(name)
+}
+
+/**
+ * Callable handle for the UDF identified by [udfName], which takes 22 arguments.
+ *
+ * @property udfName the name of the UDF
+ */
+class UDFWrapper22(private val udfName: String) {
+    /**
+     * Passes [udfName] and the given columns to [functions.callUDF] and
+     * returns the resulting [Column].
+     */
+    operator fun invoke(
+        param0: Column,
+        param1: Column,
+        param2: Column,
+        param3: Column,
+        param4: Column,
+        param5: Column,
+        param6: Column,
+        param7: Column,
+        param8: Column,
+        param9: Column,
+        param10: Column,
+        param11: Column,
+        param12: Column,
+        param13: Column,
+        param14: Column,
+        param15: Column,
+        param16: Column,
+        param17: Column,
+        param18: Column,
+        param19: Column,
+        param20: Column,
+        param21: Column
+    ): Column = functions.callUDF(
+        udfName,
+        param0,
+        param1,
+        param2,
+        param3,
+        param4,
+        param5,
+        param6,
+        param7,
+        param8,
+        param9,
+        param10,
+        param11,
+        param12,
+        param13,
+        param14,
+        param15,
+        param16,
+        param17,
+        param18,
+        param19,
+        param20,
+        param21
+    )
+}
+
+/**
+ * Registers [func] in [this] as a UDF named [name], using `schema(typeOf<R>()).unWrapper()` as its return type, and returns a [UDFWrapper22] for it.
+ * @throws TypeOfUDFParameterNotSupportedException if [checkForValidType] rejects one of the parameter types
+ */
+@OptIn(ExperimentalStdlibApi::class)
+inline fun <reified T0, reified T1, reified T2, reified T3, reified T4, reified T5, reified T6, reified T7, reified T8, reified T9, reified T10, reified T11, reified T12, reified T13, reified T14, reified T15, reified T16, reified T17, reified T18, reified T19, reified T20, reified T21, reified R> UDFRegistration.register(
+    name: String, noinline func: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21) -> R
+): UDFWrapper22 {
+    T0::class.checkForValidType("T0")
+    T1::class.checkForValidType("T1")
+    T2::class.checkForValidType("T2")
+    T3::class.checkForValidType("T3")
+    T4::class.checkForValidType("T4")
+    T5::class.checkForValidType("T5")
+    T6::class.checkForValidType("T6")
+    T7::class.checkForValidType("T7")
+    T8::class.checkForValidType("T8")
+    T9::class.checkForValidType("T9")
+    T10::class.checkForValidType("T10")
+    T11::class.checkForValidType("T11")
+    T12::class.checkForValidType("T12")
+    T13::class.checkForValidType("T13")
+    T14::class.checkForValidType("T14")
+    T15::class.checkForValidType("T15")
+    T16::class.checkForValidType("T16")
+    T17::class.checkForValidType("T17")
+    T18::class.checkForValidType("T18")
+    T19::class.checkForValidType("T19")
+    T20::class.checkForValidType("T20")
+    T21::class.checkForValidType("T21")
+    register(name, UDF22(func), schema(typeOf<R>()).unWrapper())
+    return UDFWrapper22(name)
+}
diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
new file mode 100644
index 00000000..1926be42
--- /dev/null
+++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFRegisterTest.kt
@@ -0,0 +1,166 @@
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+package org.jetbrains.kotlinx.spark.api
+
+import io.kotest.core.spec.style.ShouldSpec
+import io.kotest.matchers.shouldBe
+import org.apache.spark.sql.Dataset
+import org.junit.jupiter.api.assertThrows
+import scala.collection.JavaConverters
+import scala.collection.mutable.WrappedArray
+
+@Suppress("unused") // file-private test helper: exposes a Scala Iterable as a JVM/Kotlin Iterable
+private fun scala.collection.Iterable.asIterable(): Iterable = JavaConverters.asJavaIterable(this) // conversion is done by scala.collection.JavaConverters
+
+@Suppress("unused")
+class UDFRegisterTest : ShouldSpec({ // Kotest ShouldSpec covering checkForValidType, UDF registration and the returned UDF wrappers
+ context("org.jetbrains.kotlinx.spark.api.UDFRegister") {
+ context("the function checkForValidType") {
+ val invalidTypes = listOf( // classes for which checkForValidType is expected to throw
+ Array::class,
+ Iterable::class,
+ List::class,
+ MutableList::class,
+ ByteArray::class,
+ CharArray::class,
+ ShortArray::class,
+ IntArray::class,
+ LongArray::class,
+ FloatArray::class,
+ DoubleArray::class,
+ BooleanArray::class,
+ Map::class,
+ MutableMap::class,
+ Set::class,
+ MutableSet::class,
+ arrayOf("")::class,
+ listOf("")::class,
+ setOf("")::class,
+ mapOf("" to "")::class,
+ mutableListOf("")::class,
+ mutableSetOf("")::class,
+ mutableMapOf("" to "")::class,
+ )
+ invalidTypes.forEachIndexed { index, invalidType -> // one generated test case per entry; the index keeps the test names unique
+ should("$index: throw an ${TypeOfUDFParameterNotSupportedException::class.simpleName} when encountering ${invalidType.qualifiedName}") {
+ assertThrows {
+ invalidType.checkForValidType("test")
+ }
+ }
+ }
+ }
+
+ context("the register-function") {
+ withSpark {
+
+ should("fail when using a simple kotlin.Array") {
+ assertThrows {
+ udf.register("shouldFail") { array: Array ->
+ array.joinToString(" ")
+ }
+ }
+ }
+
+ should("succeed when using a WrappedArray") {
+ udf.register("shouldSucceed") { array: WrappedArray ->
+ array.asIterable().joinToString(" ")
+ }
+ }
+
+ should("succeed when return a List") {
+ udf.register>("StringToIntList") { a ->
+ a.asIterable().map { it.code }
+ }
+
+ val result = spark.sql("select StringToIntList('ab')").`as`>().collectAsList()
+ result shouldBe listOf(listOf(97, 98)) // 97 and 98 are the char codes of 'a' and 'b'
+ }
+
+ should("succeed when using three type udf and as result to udf return type") {
+ listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test1")
+ udf.register("stringIntDiff") { a, b ->
+ a[0].code - b
+ }
+ val result = spark.sql("select stringIntDiff(first, second) from test1").`as`().collectAsList()
+ result shouldBe listOf(96, 96) // 'a'.code - 1 == 96 and 'b'.code - 2 == 96
+ }
+ }
+ }
+
+ context("calling the UDF-Wrapper") {
+ withSpark(logLevel = SparkLogLevel.DEBUG) {
+ should("succeed call UDF-Wrapper in withColumn") {
+
+ val stringArrayMerger = udf.register, String>("stringArrayMerger") {
+ it.asIterable().joinToString(" ")
+ }
+
+ val testData = dsOf(listOf("a", "b"))
+ val newData = testData.withColumn("text", stringArrayMerger(testData.col("value")))
+
+ newData.select("text").collectAsList().zip(newData.select("value").collectAsList())
+ .forEach { (text, textArray) ->
+ text.getString(0) shouldBe textArray.getList(0).joinToString(" ") // matcher instead of kotlin assert(), which is only evaluated when JVM assertions (-ea) are enabled
+ }
+ }
+
+
+ should("succeed in dataset") {
+ val dataset: Dataset = listOf(
+ NormalClass(name="a", age =10),
+ NormalClass(name="b", age =20)
+ ).toDS()
+
+ val udfWrapper = udf.register("nameConcatAge") { name, age ->
+ "$name-$age"
+ }
+
+ val collectAsList = dataset.withColumn(
+ "nameAndAge",
+ udfWrapper(dataset.col("name"), dataset.col("age"))
+ )
+ .select("nameAndAge")
+ .collectAsList()
+
+ collectAsList[0][0] shouldBe "a-10"
+ collectAsList[1][0] shouldBe "b-20"
+ }
+ }
+ }
+
+ context("udf return data class") {
+ withSpark(logLevel = SparkLogLevel.DEBUG) {
+ should("return NormalClass") {
+ listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test2")
+ udf.register("toNormalClass") { a, b ->
+ NormalClass(b, a)
+ }
+ spark.sql("select toNormalClass(first, second) from test2").show() // smoke test only: the result is printed, nothing is asserted
+ }
+ }
+ }
+
+ }
+})
+
+data class NormalClass( // fixture for UDFRegisterTest: used as a dataset row type and as a UDF return value
+ val age: Int,
+ val name: String
+)