diff --git a/.github/workflows/generate_docs.yml b/.github/workflows/generate_docs.yml
index 0aee6d85..cdbd1949 100644
--- a/.github/workflows/generate_docs.yml
+++ b/.github/workflows/generate_docs.yml
@@ -25,5 +25,6 @@ jobs:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_branch: docs
publish_dir: ./kotlin-spark-api/3.2/target/dokka
+ force_orphan: true
diff --git a/README.md b/README.md
index 380ee310..498334d2 100644
--- a/README.md
+++ b/README.md
@@ -27,14 +27,14 @@ We have opened a Spark Project Improvement Proposal: [Kotlin support for Apache
- [Code of Conduct](#code-of-conduct)
- [License](#license)
-## Supported versions of Apache Spark #TODO
+## Supported versions of Apache Spark
| Apache Spark | Scala | Kotlin for Apache Spark |
|:------------:|:-----:|:-------------------------------:|
| 3.0.0+ | 2.12 | kotlin-spark-api-3.0:1.0.2 |
| 2.4.1+ | 2.12 | kotlin-spark-api-2.4_2.12:1.0.2 |
| 2.4.1+ | 2.11 | kotlin-spark-api-2.4_2.11:1.0.2 |
-| 3.2.0+ | 2.12 | kotlin-spark-api-2.4_2.12:1.0.3 |
+| 3.2.0+ | 2.12 | kotlin-spark-api-3.2:1.0.3 |
## Releases
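The table above gives only the artifact name and version. As a minimal, hypothetical sketch (the `org.jetbrains.kotlinx.spark` group id is an assumption here; the Releases section below has the authoritative coordinates), a Gradle Kotlin DSL dependency for the Spark 3.2 line might look like:

```kotlin
dependencies {
    // Group id assumed; artifact and version taken from the table above.
    implementation("org.jetbrains.kotlinx.spark:kotlin-spark-api-3.2:1.0.3")
}
```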
diff --git a/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala b/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
index be808af0..05ff330b 100644
--- a/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
+++ b/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.ScalaReflection.{Schema, dataTypeFor, getClassFromType, isSubtype, javaBoxedType, localTypeOf}
import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -30,8 +31,10 @@ import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePath}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils
import java.beans.{Introspector, PropertyDescriptor}
+import java.lang.Exception
/**
@@ -45,944 +48,1228 @@ trait DefinedByConstructorParams
* KotlinReflection is heavily inspired by ScalaReflection and even extends it just to add several methods
*/
object KotlinReflection extends KotlinReflection {
- /**
- * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
- * to a native type, an ObjectType is returned.
- *
- * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
- * system. As a result, ObjectType will be returned for things like boxed Integers.
- */
- private def inferExternalType(cls: Class[_]): DataType = cls match {
- case c if c == java.lang.Boolean.TYPE => BooleanType
- case c if c == java.lang.Byte.TYPE => ByteType
- case c if c == java.lang.Short.TYPE => ShortType
- case c if c == java.lang.Integer.TYPE => IntegerType
- case c if c == java.lang.Long.TYPE => LongType
- case c if c == java.lang.Float.TYPE => FloatType
- case c if c == java.lang.Double.TYPE => DoubleType
- case c if c == classOf[Array[Byte]] => BinaryType
- case _ => ObjectType(cls)
- }
-
- val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
-
- // Since we are creating a runtime mirror using the class loader of current thread,
- // we need to use def at here. So, every time we call mirror, it is using the
- // class loader of the current thread.
- override def mirror: universe.Mirror = {
- universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
- }
-
- import universe._
-
- // The Predef.Map is scala.collection.immutable.Map.
- // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
- import scala.collection.Map
-
-
- def isSubtype(t: universe.Type, t2: universe.Type): Boolean = t <:< t2
-
- /**
- * Synchronize to prevent concurrent usage of `<:<` operator.
- * This operator is not thread safe in any current version of scala; i.e.
- * (2.11.12, 2.12.10, 2.13.0-M5).
- *
- * See https://github.com/scala/bug/issues/10766
- */
- /*
- private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
- ScalaReflection.ScalaSubtypeLock.synchronized {
- tpe1 <:< tpe2
- }
- }
- */
-
- private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
- tpe.dealias match {
- case t if isSubtype(t, definitions.NullTpe) => NullType
- case t if isSubtype(t, definitions.IntTpe) => IntegerType
- case t if isSubtype(t, definitions.LongTpe) => LongType
- case t if isSubtype(t, definitions.DoubleTpe) => DoubleType
- case t if isSubtype(t, definitions.FloatTpe) => FloatType
- case t if isSubtype(t, definitions.ShortTpe) => ShortType
- case t if isSubtype(t, definitions.ByteTpe) => ByteType
- case t if isSubtype(t, definitions.BooleanTpe) => BooleanType
- case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType
- case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalType
- case t if isSubtype(t, localTypeOf[Decimal]) => DecimalType.SYSTEM_DEFAULT
- case _ =>
- val className = getClassNameFromType(tpe)
- className match {
- case "scala.Array" =>
- val TypeRef(_, _, Seq(elementType)) = tpe
- arrayClassFor(elementType)
- case _ =>
- val clazz = getClassFromType(tpe)
- ObjectType(clazz)
- }
+ /**
+ * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
+ * to a native type, an ObjectType is returned.
+ *
+ * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
+ * system. As a result, ObjectType will be returned for things like boxed Integers.
+ */
+ private def inferExternalType(cls: Class[_]): DataType = cls match {
+ case c if c == java.lang.Boolean.TYPE => BooleanType
+ case c if c == java.lang.Byte.TYPE => ByteType
+ case c if c == java.lang.Short.TYPE => ShortType
+ case c if c == java.lang.Integer.TYPE => IntegerType
+ case c if c == java.lang.Long.TYPE => LongType
+ case c if c == java.lang.Float.TYPE => FloatType
+ case c if c == java.lang.Double.TYPE => DoubleType
+ case c if c == classOf[Array[Byte]] => BinaryType
+ case c if c == classOf[Decimal] => DecimalType.SYSTEM_DEFAULT
+ case c if c == classOf[CalendarInterval] => CalendarIntervalType
+ case _ => ObjectType(cls)
}
- }
-
- /**
- * Given a type `T` this function constructs `ObjectType` that holds a class of type
- * `Array[T]`.
- *
- * Special handling is performed for primitive types to map them back to their raw
- * JVM form instead of the Scala Array that handles auto boxing.
- */
- private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects {
- val cls = tpe.dealias match {
- case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]]
- case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]]
- case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]]
- case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]]
- case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
- case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
- case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
- case other =>
- // There is probably a better way to do this, but I couldn't find it...
- val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
- java.lang.reflect.Array.newInstance(elementType, 0).getClass
+ val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
+
+ // Since we are creating a runtime mirror using the class loader of current thread,
+  // we need to use a def here, so that every time we call mirror, it uses the
+ // class loader of the current thread.
+ override def mirror: universe.Mirror = {
+ universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
}
- ObjectType(cls)
- }
-
- /**
- * Returns true if the value of this data type is same between internal and external.
- */
- def isNativeType(dt: DataType): Boolean = dt match {
- case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType | CalendarIntervalType => true
- case _ => false
- }
-
- private def baseType(tpe: `Type`): `Type` = {
- tpe.dealias match {
- case annotatedType: AnnotatedType => annotatedType.underlying
- case other => other
+
+ import universe._
+
+ // The Predef.Map is scala.collection.immutable.Map.
+  // Since the map values can be mutable, we explicitly import scala.collection.Map here.
+ import scala.collection.Map
+
+
+ def isSubtype(t: universe.Type, t2: universe.Type): Boolean = t <:< t2
+
+ /**
+ * Synchronize to prevent concurrent usage of `<:<` operator.
+ * This operator is not thread safe in any current version of scala; i.e.
+ * (2.11.12, 2.12.10, 2.13.0-M5).
+ *
+ * See https://github.com/scala/bug/issues/10766
+ */
+ /*
+ private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
+ ScalaReflection.ScalaSubtypeLock.synchronized {
+ tpe1 <:< tpe2
+ }
+ }
+ */
+
+ private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
+ tpe.dealias match {
+ case t if isSubtype(t, definitions.NullTpe) => NullType
+ case t if isSubtype(t, definitions.IntTpe) => IntegerType
+ case t if isSubtype(t, definitions.LongTpe) => LongType
+ case t if isSubtype(t, definitions.DoubleTpe) => DoubleType
+ case t if isSubtype(t, definitions.FloatTpe) => FloatType
+ case t if isSubtype(t, definitions.ShortTpe) => ShortType
+ case t if isSubtype(t, definitions.ByteTpe) => ByteType
+ case t if isSubtype(t, definitions.BooleanTpe) => BooleanType
+ case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType
+ case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalType
+ case t if isSubtype(t, localTypeOf[Decimal]) => DecimalType.SYSTEM_DEFAULT
+ case _ => {
+ val className = getClassNameFromType(tpe)
+ className match {
+ case "scala.Array" => {
+ val TypeRef(_, _, Seq(elementType)) = tpe
+ arrayClassFor(elementType)
+ }
+ case _ => {
+ val clazz = getClassFromType(tpe)
+ ObjectType(clazz)
+ }
+ }
+ }
+ }
}
- }
-
- /**
- * Returns an expression that can be used to deserialize a Spark SQL representation to an object
- * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of
- * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
- * `UnresolvedExtractValue`.
- *
- * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
- * deserializer expression when using it.
- */
- def deserializerForType(tpe: `Type`): Expression = {
- val clsName = getClassNameFromType(tpe)
- val walkedTypePath = WalkedTypePath().recordRoot(clsName)
- val Schema(dataType, nullable) = schemaFor(tpe)
-
- // Assumes we are deserializing the first column of a row.
- deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
- nullable = nullable, walkedTypePath,
- (casted, typePath) => deserializerFor(tpe, casted, typePath))
- }
-
-
- /**
- * Returns an expression that can be used to deserialize an input expression to an object of type
- * `T` with a compatible schema.
- *
- * @param tpe The `Type` of deserialized object.
- * @param path The expression which can be used to extract serialized value.
- * @param walkedTypePath The paths from top to bottom to access current field when deserializing.
- */
- private def deserializerFor(
- tpe: `Type`,
- path: Expression,
- walkedTypePath: WalkedTypePath,
- predefinedDt: Option[DataTypeWithClass] = None
- ): Expression = cleanUpReflectionObjects {
- baseType(tpe) match {
-
- //
- case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Integer])
-
- case t if isSubtype(t, localTypeOf[Int]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Integer])
-
- case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Long])
- case t if isSubtype(t, localTypeOf[Long]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Long])
-
- case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Double])
- case t if isSubtype(t, localTypeOf[Double]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Double])
-
- case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Float])
- case t if isSubtype(t, localTypeOf[Float]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Float])
-
- case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Short])
- case t if isSubtype(t, localTypeOf[Short]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Short])
-
- case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Byte])
- case t if isSubtype(t, localTypeOf[Byte]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Byte])
-
- case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Boolean])
- case t if isSubtype(t, localTypeOf[Boolean]) =>
- createDeserializerForTypesSupportValueOf(path,
- classOf[java.lang.Boolean])
-
- case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
- createDeserializerForLocalDate(path)
-
- case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
- createDeserializerForSqlDate(path)
- //
-
- case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
- createDeserializerForInstant(path)
-
- case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
- createDeserializerForSqlTimestamp(path)
-
- case t if isSubtype(t, localTypeOf[java.lang.String]) =>
- createDeserializerForString(path, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
- createDeserializerForJavaBigDecimal(path, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[BigDecimal]) =>
- createDeserializerForScalaBigDecimal(path, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
- createDeserializerForJavaBigInteger(path, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
- createDeserializerForScalaBigInt(path)
-
- case t if isSubtype(t, localTypeOf[Array[_]]) =>
- var TypeRef(_, _, Seq(elementType)) = t
- if (predefinedDt.isDefined && !elementType.dealias.typeSymbol.isClass)
- elementType = getType(predefinedDt.get.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType.asInstanceOf[DataTypeWithClass].cls)
- val Schema(dataType, elementNullable) = predefinedDt.map(it => {
- val elementInfo = it.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType.asInstanceOf[DataTypeWithClass]
- Schema(elementInfo.dt, elementInfo.nullable)
- })
- .getOrElse(schemaFor(elementType))
- val className = getClassNameFromType(elementType)
- val newTypePath = walkedTypePath.recordArray(className)
-
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataType,
- nullable = elementNullable,
- newTypePath,
- (casted, typePath) => deserializerFor(elementType, casted, typePath, predefinedDt.map(_.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType).filter(_.isInstanceOf[ComplexWrapper]).map(_.asInstanceOf[ComplexWrapper])))
+
+ /**
+ * Given a type `T` this function constructs `ObjectType` that holds a class of type
+ * `Array[T]`.
+ *
+ * Special handling is performed for primitive types to map them back to their raw
+ * JVM form instead of the Scala Array that handles auto boxing.
+ */
+ private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects {
+ val cls = tpe.dealias match {
+ case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]]
+ case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]]
+ case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]]
+ case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]]
+ case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
+ case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
+ case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
+ case t if isSubtype(t, localTypeOf[Array[Byte]]) => classOf[Array[Array[Byte]]]
+ case t if isSubtype(t, localTypeOf[CalendarInterval]) => classOf[Array[CalendarInterval]]
+ case t if isSubtype(t, localTypeOf[Decimal]) => classOf[Array[Decimal]]
+ case other => {
+ // There is probably a better way to do this, but I couldn't find it...
+ val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
+ java.lang.reflect.Array.newInstance(elementType, 0).getClass
+ }
+
}
+ ObjectType(cls)
+ }
- val arrayData = UnresolvedMapObjects(mapFunction, path)
- val arrayCls = arrayClassFor(elementType)
-
- val methodName = elementType match {
- case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
- case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
- case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
- case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
- case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
- case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
- case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
- // non-primitive
- case _ => "array"
+ /**
+   * Returns true if the value of this data type is the same between internal and external.
+ */
+ def isNativeType(dt: DataType): Boolean = dt match {
+ case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => {
+ true
}
- Invoke(arrayData, methodName, arrayCls, returnNullable = false)
+ case _ => false
+ }
- // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
- // to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
+ private def baseType(tpe: `Type`): `Type` = {
+ tpe.dealias match {
+ case annotatedType: AnnotatedType => annotatedType.underlying
+ case other => other
+ }
+ }
- case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
+ /**
+ * Returns an expression that can be used to deserialize a Spark SQL representation to an object
+ * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of
+ * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
+ * `UnresolvedExtractValue`.
+ *
+ * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
+ * deserializer expression when using it.
+ */
+ def deserializerForType(tpe: `Type`): Expression = {
+ val clsName = getClassNameFromType(tpe)
+ val walkedTypePath = WalkedTypePath().recordRoot(clsName)
+ val Schema(dataType, nullable) = schemaFor(tpe)
+
+ // Assumes we are deserializing the first column of a row.
+ deserializerForWithNullSafetyAndUpcast(
+ GetColumnByOrdinal(0, dataType), dataType,
+ nullable = nullable, walkedTypePath,
+ (casted, typePath) => deserializerFor(tpe, casted, typePath)
+ )
+ }
- val classNameForKey = getClassNameFromType(keyType)
- val classNameForValue = getClassNameFromType(valueType)
- val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
+ /**
+ * Returns an expression that can be used to deserialize an input expression to an object of type
+ * `T` with a compatible schema.
+ *
+ * @param tpe The `Type` of deserialized object.
+ * @param path The expression which can be used to extract serialized value.
+ * @param walkedTypePath The paths from top to bottom to access current field when deserializing.
+ */
+ private def deserializerFor(
+ tpe: `Type`,
+ path: Expression,
+ walkedTypePath: WalkedTypePath,
+ predefinedDt: Option[DataTypeWithClass] = None
+ ): Expression = cleanUpReflectionObjects {
+ baseType(tpe) match {
+
+ //
+ case t if (
+ try {
+ !dataTypeFor(t).isInstanceOf[ObjectType]
+ } catch {
+ case _: Throwable => false
+ }) && !predefinedDt.exists(_.isInstanceOf[ComplexWrapper]) => {
+ path
+ }
- UnresolvedCatalystToExternalMap(
- path,
- p => deserializerFor(keyType, p, newTypePath),
- p => deserializerFor(valueType, p, newTypePath),
- mirror.runtimeClass(t.typeSymbol.asClass)
- )
+ case t if isSubtype(t, localTypeOf[java.lang.Integer]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Integer])
+ }
+ case t if isSubtype(t, localTypeOf[Int]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Integer])
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Long]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Long])
+ }
+ case t if isSubtype(t, localTypeOf[Long]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Long])
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Double]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Double])
+ }
+ case t if isSubtype(t, localTypeOf[Double]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Double])
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Float]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Float])
+ }
+ case t if isSubtype(t, localTypeOf[Float]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Float])
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Short]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Short])
+ }
+ case t if isSubtype(t, localTypeOf[Short]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Short])
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Byte]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Byte])
+ }
+ case t if isSubtype(t, localTypeOf[Byte]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Byte])
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Boolean])
+ }
+ case t if isSubtype(t, localTypeOf[Boolean]) => {
+ createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Boolean])
+ }
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => {
+ createDeserializerForLocalDate(path)
+ }
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) => {
+ createDeserializerForSqlDate(path)
+ } //
- case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
- createDeserializerForTypesSupportValueOf(
- createDeserializerForString(path, returnNullable = false), Class.forName(t.toString))
-
- case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
- getConstructor().newInstance()
- val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
- Nil,
- dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
-
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
- newInstance().asInstanceOf[UserDefinedType[_]]
- val obj = NewInstance(
- udt.getClass,
- Nil,
- dataType = ObjectType(udt.getClass))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
-
- case _ if predefinedDt.isDefined =>
- predefinedDt.get match {
- case wrapper: KDataTypeWrapper =>
- val structType = wrapper.dt
- val cls = wrapper.cls
- val arguments = structType
- .fields
- .map(field => {
- val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
- val nullable = dataType.nullable
- val clsName = getClassNameFromType(getType(dataType.cls))
- val newTypePath = walkedTypePath.recordField(clsName, field.name)
-
- // For tuples, we based grab the inner fields by ordinal instead of name.
- val newPath = deserializerFor(
- getType(dataType.cls),
- addToPath(path, field.name, dataType.dt, newTypePath),
- newTypePath,
- Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
- )
- expressionWithNullSafety(
- newPath,
- nullable = nullable,
- newTypePath
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) => {
+ createDeserializerForInstant(path)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => {
+ createDeserializerForTypesSupportValueOf(
+ Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false),
+ getClassFromType(t),
)
+ }
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => {
+ createDeserializerForSqlTimestamp(path)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => {
+ createDeserializerForLocalDateTime(path)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Duration]) => {
+ createDeserializerForDuration(path)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Period]) => {
+ createDeserializerForPeriod(path)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.String]) => {
+ createDeserializerForString(path, returnNullable = false)
+ }
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => {
+ createDeserializerForJavaBigDecimal(path, returnNullable = false)
+ }
+ case t if isSubtype(t, localTypeOf[BigDecimal]) => {
+ createDeserializerForScalaBigDecimal(path, returnNullable = false)
+ }
+ case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => {
+ createDeserializerForJavaBigInteger(path, returnNullable = false)
+ }
+ case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => {
+ createDeserializerForScalaBigInt(path)
+ }
- })
- val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+ case t if isSubtype(t, localTypeOf[Array[_]]) => {
+ var TypeRef(_, _, Seq(elementType)) = t
+ if (predefinedDt.isDefined && !elementType.dealias.typeSymbol.isClass)
+ elementType = getType(predefinedDt.get.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType]
+ .elementType.asInstanceOf[DataTypeWithClass].cls
+ )
+ val Schema(dataType, elementNullable) = predefinedDt.map { it =>
+ val elementInfo = it.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType
+ .asInstanceOf[DataTypeWithClass]
+ Schema(elementInfo.dt, elementInfo.nullable)
+ }.getOrElse(schemaFor(elementType))
+ val className = getClassNameFromType(elementType)
+ val newTypePath = walkedTypePath.recordArray(className)
- org.apache.spark.sql.catalyst.expressions.If(
- IsNull(path),
- org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
- newInstance
- )
+ val mapFunction: Expression => Expression = element => {
+ // upcast the array element to the data type the encoder expected.
+ deserializerForWithNullSafetyAndUpcast(
+ element,
+ dataType,
+ nullable = elementNullable,
+ newTypePath,
+ (casted, typePath) => deserializerFor(
+ tpe = elementType,
+ path = casted,
+ walkedTypePath = typePath,
+ predefinedDt = predefinedDt
+ .map(_.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType)
+ .filter(_.isInstanceOf[ComplexWrapper])
+ .map(_.asInstanceOf[ComplexWrapper])
+ )
+ )
+ }
+
+ val arrayData = UnresolvedMapObjects(mapFunction, path)
+ val arrayCls = arrayClassFor(elementType)
+
+ val methodName = elementType match {
+ case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
+ case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
+ case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
+ case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
+ case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
+ case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
+ case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
+ // non-primitive
+ case _ => "array"
+ }
+ Invoke(arrayData, methodName, arrayCls, returnNullable = false)
+ }
+
+ // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
+ // to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
+
+ case t if isSubtype(t, localTypeOf[Map[_, _]]) => {
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
- case t: ComplexWrapper =>
- t.dt match {
- case MapType(kt, vt, _) =>
- val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls).map(getType(_))
- val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass])
val classNameForKey = getClassNameFromType(keyType)
val classNameForValue = getClassNameFromType(valueType)
val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
- val keyData =
- Invoke(
- UnresolvedMapObjects(
- p => deserializerFor(keyType, p, newTypePath, Some(keyDT).filter(_.isInstanceOf[ComplexWrapper])),
- MapKeys(path)),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- val valueData =
- Invoke(
- UnresolvedMapObjects(
- p => deserializerFor(valueType, p, newTypePath, Some(valueDT).filter(_.isInstanceOf[ComplexWrapper])),
- MapValues(path)),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- StaticInvoke(
- ArrayBasedMapData.getClass,
- ObjectType(classOf[java.util.Map[_, _]]),
- "toJavaMap",
- keyData :: valueData :: Nil,
- returnNullable = false)
-
- case ArrayType(elementType, containsNull) =>
- val dataTypeWithClass = elementType.asInstanceOf[DataTypeWithClass]
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- val et = getType(dataTypeWithClass.cls)
- val className = getClassNameFromType(et)
- val newTypePath = walkedTypePath.recordArray(className)
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataTypeWithClass.dt,
- nullable = dataTypeWithClass.nullable,
- newTypePath,
- (casted, typePath) => {
- deserializerFor(et, casted, typePath, Some(dataTypeWithClass).filter(_.isInstanceOf[ComplexWrapper]).map(_.asInstanceOf[ComplexWrapper]))
- })
+ UnresolvedCatalystToExternalMap(
+ path,
+ p => deserializerFor(keyType, p, newTypePath),
+ p => deserializerFor(valueType, p, newTypePath),
+ mirror.runtimeClass(t.typeSymbol.asClass)
+ )
+ }
+
+ case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => {
+ createDeserializerForTypesSupportValueOf(
+ createDeserializerForString(path, returnNullable = false),
+ Class.forName(t.toString),
+ )
+ }
+ case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => {
+ val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
+ getConstructor().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())
+ )
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
+ }
+
+ case t if UDTRegistration.exists(getClassNameFromType(t)) => {
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
+ newInstance().asInstanceOf[UserDefinedType[_]]
+ val obj = NewInstance(
+ udt.getClass,
+ Nil,
+ dataType = ObjectType(udt.getClass)
+ )
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
+ }
+
+ case _ if predefinedDt.isDefined => {
+ predefinedDt.get match {
+
+ case wrapper: KDataTypeWrapper => {
+ val structType = wrapper.dt
+ val cls = wrapper.cls
+ val arguments = structType
+ .fields
+ .map { field =>
+ val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
+ val nullable = dataType.nullable
+ val clsName = getClassNameFromType(getType(dataType.cls))
+ val newTypePath = walkedTypePath.recordField(clsName, field.name)
+
+                            // For tuples, we grab the inner fields by ordinal instead of name.
+ val newPath = deserializerFor(
+ tpe = getType(dataType.cls),
+ path = addToPath(path, field.name, dataType.dt, newTypePath),
+ walkedTypePath = newTypePath,
+ predefinedDt = Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
+ )
+ expressionWithNullSafety(
+ newPath,
+ nullable = nullable,
+ newTypePath
+ )
+ }
+ val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+
+ org.apache.spark.sql.catalyst.expressions.If(
+ IsNull(path),
+ org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
+ newInstance
+ )
+ }
+
+ case t: ComplexWrapper => {
+
+ t.dt match {
+ case MapType(kt, vt, _) => {
+ val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls)
+ .map(getType(_))
+ val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass])
+ val classNameForKey = getClassNameFromType(keyType)
+ val classNameForValue = getClassNameFromType(valueType)
+
+ val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
+
+ val keyData =
+ Invoke(
+ UnresolvedMapObjects(
+ p => deserializerFor(
+ keyType, p, newTypePath, Some(keyDT)
+ .filter(_.isInstanceOf[ComplexWrapper])
+ ),
+ MapKeys(path)
+ ),
+ "array",
+ ObjectType(classOf[Array[Any]])
+ )
+
+ val valueData =
+ Invoke(
+ UnresolvedMapObjects(
+ p => deserializerFor(
+ valueType, p, newTypePath, Some(valueDT)
+ .filter(_.isInstanceOf[ComplexWrapper])
+ ),
+ MapValues(path)
+ ),
+ "array",
+ ObjectType(classOf[Array[Any]])
+ )
+
+ StaticInvoke(
+ ArrayBasedMapData.getClass,
+ ObjectType(classOf[java.util.Map[_, _]]),
+ "toJavaMap",
+ keyData :: valueData :: Nil,
+ returnNullable = false
+ )
+ }
+
+ case ArrayType(elementType, containsNull) => {
+ val dataTypeWithClass = elementType.asInstanceOf[DataTypeWithClass]
+ val mapFunction: Expression => Expression = element => {
+ // upcast the array element to the data type the encoder expected.
+ val et = getType(dataTypeWithClass.cls)
+ val className = getClassNameFromType(et)
+ val newTypePath = walkedTypePath.recordArray(className)
+ deserializerForWithNullSafetyAndUpcast(
+ element,
+ dataTypeWithClass.dt,
+ nullable = dataTypeWithClass.nullable,
+ newTypePath,
+ (casted, typePath) => {
+ deserializerFor(
+ et, casted, typePath, Some(dataTypeWithClass)
+ .filter(_.isInstanceOf[ComplexWrapper])
+ .map(_.asInstanceOf[ComplexWrapper])
+ )
+ }
+ )
+ }
+
+ UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(t.cls))
+ }
+
+ case StructType(elementType: Array[StructField]) => {
+ val cls = t.cls
+
+ val arguments = elementType.map { field =>
+ val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
+ val nullable = dataType.nullable
+ val clsName = getClassNameFromType(getType(dataType.cls))
+ val newTypePath = walkedTypePath.recordField(clsName, field.name)
+
+                                        // For tuples, we grab the inner fields by ordinal instead of name.
+ val newPath = deserializerFor(
+ getType(dataType.cls),
+ addToPath(path, field.name, dataType.dt, newTypePath),
+ newTypePath,
+ Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
+ )
+ expressionWithNullSafety(
+ newPath,
+ nullable = nullable,
+ newTypePath
+ )
+ }
+ val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+
+ org.apache.spark.sql.catalyst.expressions.If(
+ IsNull(path),
+ org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
+ newInstance
+ )
+ }
+
+ case _ => {
+ throw new UnsupportedOperationException(
+ s"No Encoder found for $tpe\n" + walkedTypePath
+ )
+ }
+ }
+ }
}
+ }
- UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(t.cls))
-
- case StructType(elementType: Array[StructField]) =>
- val cls = t.cls
-
- val arguments = elementType.map { field =>
- val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
- val nullable = dataType.nullable
- val clsName = getClassNameFromType(getType(dataType.cls))
- val newTypePath = walkedTypePath.recordField(clsName, field.name)
-
- // For tuples, we based grab the inner fields by ordinal instead of name.
- val newPath = deserializerFor(
- getType(dataType.cls),
- addToPath(path, field.name, dataType.dt, newTypePath),
- newTypePath,
- Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
- )
- expressionWithNullSafety(
- newPath,
- nullable = nullable,
- newTypePath
- )
+ case t if definedByConstructorParams(t) => {
+ val params = getConstructorParameters(t)
+
+ val cls = getClassFromType(tpe)
+
+ val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
+ val Schema(dataType, nullable) = schemaFor(fieldType)
+ val clsName = getClassNameFromType(fieldType)
+ val newTypePath = walkedTypePath.recordField(clsName, fieldName)
+
+                    // For tuples, we grab the inner fields by ordinal instead of name.
+ val newPath = if (cls.getName startsWith "scala.Tuple") {
+ deserializerFor(
+ fieldType,
+ addToPathOrdinal(path, i, dataType, newTypePath),
+ newTypePath
+ )
+ } else {
+ deserializerFor(
+ fieldType,
+ addToPath(path, fieldName, dataType, newTypePath),
+ newTypePath
+ )
+ }
+ expressionWithNullSafety(
+ newPath,
+ nullable = nullable,
+ newTypePath
+ )
}
+
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
org.apache.spark.sql.catalyst.expressions.If(
- IsNull(path),
- org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
- newInstance
+ IsNull(path),
+ org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
+ newInstance
)
+ }
-
- case _ =>
+ case _ => {
throw new UnsupportedOperationException(
- s"No Encoder found for $tpe\n" + walkedTypePath)
+ s"No Encoder found for $tpe\n" + walkedTypePath
+ )
}
}
+ }
+
+ /**
+ * Returns an expression for serializing an object of type T to Spark SQL representation. The
+ * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
+ *
+   * If the given type is not supported, i.e. no encoder can be built for this type,
+   * an [[UnsupportedOperationException]] will be thrown with a detailed error message explaining
+ * the type path walked so far and which class we are not supporting.
+ * There are 4 kinds of type path:
+ * * the root type: `root class: "abc.xyz.MyClass"`
+ * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
+ * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
+ * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
+ */
+ def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
+ val clsName = getClassNameFromType(tpe)
+ val walkedTypePath = WalkedTypePath().recordRoot(clsName)
+
+    // The input object to `ExpressionEncoder` is located at the first column of a row.
+ val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
+ val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive)
+
+ serializerFor(inputObject, tpe, walkedTypePath)
+ }
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
-
- val cls = getClassFromType(tpe)
-
- val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
- val Schema(dataType, nullable) = schemaFor(fieldType)
- val clsName = getClassNameFromType(fieldType)
- val newTypePath = walkedTypePath.recordField(clsName, fieldName)
-
- // For tuples, we based grab the inner fields by ordinal instead of name.
- val newPath = if (cls.getName startsWith "scala.Tuple") {
- deserializerFor(
- fieldType,
- addToPathOrdinal(path, i, dataType, newTypePath),
- newTypePath)
- } else {
- deserializerFor(
- fieldType,
- addToPath(path, fieldName, dataType, newTypePath),
- newTypePath)
- }
- expressionWithNullSafety(
- newPath,
- nullable = nullable,
- newTypePath)
+ def getType[T](clazz: Class[T]): universe.Type = {
+ clazz match {
+ case _ if clazz == classOf[Array[Byte]] => localTypeOf[Array[Byte]]
+ case _ => {
+ val mir = runtimeMirror(clazz.getClassLoader)
+ mir.classSymbol(clazz).toType
+ }
}
- val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+ }
- org.apache.spark.sql.catalyst.expressions.If(
- IsNull(path),
- org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
- newInstance
+ def deserializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = {
+ val tpe = getType(cls)
+ val clsName = getClassNameFromType(tpe)
+ val walkedTypePath = WalkedTypePath().recordRoot(clsName)
+
+ // Assumes we are deserializing the first column of a row.
+ deserializerForWithNullSafetyAndUpcast(
+ GetColumnByOrdinal(0, dt.dt),
+ dt.dt,
+ nullable = dt.nullable,
+ walkedTypePath,
+ (casted, typePath) => deserializerFor(tpe, casted, typePath, Some(dt))
)
-
- case _ =>
- throw new UnsupportedOperationException(
- s"No Encoder found for $tpe\n" + walkedTypePath)
}
- }
-
- /**
- * Returns an expression for serializing an object of type T to Spark SQL representation. The
- * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
- *
- * If the given type is not supported, i.e. there is no encoder can be built for this type,
- * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
- * the type path walked so far and which class we are not supporting.
- * There are 4 kinds of type path:
- * * the root type: `root class: "abc.xyz.MyClass"`
- * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
- * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
- * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
- */
- def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
- val clsName = getClassNameFromType(tpe)
- val walkedTypePath = WalkedTypePath().recordRoot(clsName)
-
- // The input object to `ExpressionEncoder` is located at first column of an row.
- val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
- val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive)
-
- serializerFor(inputObject, tpe, walkedTypePath)
- }
-
- def getType[T](clazz: Class[T]): universe.Type = {
- val mir = runtimeMirror(clazz.getClassLoader)
- mir.classSymbol(clazz).toType
- }
-
- def deserializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = {
- val tpe = getType(cls)
- val clsName = getClassNameFromType(tpe)
- val walkedTypePath = WalkedTypePath().recordRoot(clsName)
-
- // Assumes we are deserializing the first column of a row.
- deserializerForWithNullSafetyAndUpcast(
- GetColumnByOrdinal(0, dt.dt),
- dt.dt,
- nullable = dt.nullable,
- walkedTypePath,
- (casted, typePath) => deserializerFor(tpe, casted, typePath, Some(dt))
- )
- }
-
-
- def serializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = {
-
- val tpe = getType(cls)
- val clsName = getClassNameFromType(tpe)
- val walkedTypePath = WalkedTypePath().recordRoot(clsName)
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- serializerFor(inputObject, tpe, walkedTypePath, predefinedDt = Some(dt))
- }
-
- /**
- * Returns an expression for serializing the value of an input expression into Spark SQL
- * internal representation.
- */
- private def serializerFor(
- inputObject: Expression,
- tpe: `Type`,
- walkedTypePath: WalkedTypePath,
- seenTypeSet: Set[`Type`] = Set.empty,
- predefinedDt: Option[DataTypeWithClass] = None
- ): Expression = cleanUpReflectionObjects {
-
- def toCatalystArray(input: Expression, elementType: `Type`, predefinedDt: Option[DataTypeWithClass] = None): Expression = {
- predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match {
-
- case dt@(MapType(_, _, _) | ArrayType(_, _) | StructType(_)) =>
- val clsName = getClassNameFromType(elementType)
- val newPath = walkedTypePath.recordArray(clsName)
- createSerializerForMapObjects(input, ObjectType(predefinedDt.get.cls),
- serializerFor(_, elementType, newPath, seenTypeSet, predefinedDt))
-
- case dt: ObjectType =>
- val clsName = getClassNameFromType(elementType)
- val newPath = walkedTypePath.recordArray(clsName)
- createSerializerForMapObjects(input, dt,
- serializerFor(_, elementType, newPath, seenTypeSet))
-
- case dt@(BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType) =>
- val cls = input.dataType.asInstanceOf[ObjectType].cls
- if (cls.isArray && cls.getComponentType.isPrimitive) {
- createSerializerForPrimitiveArray(input, dt)
- } else {
- createSerializerForGenericArray(input, dt, nullable = predefinedDt.map(_.nullable).getOrElse(schemaFor(elementType).nullable))
- }
-
- case _: StringType =>
- val clsName = getClassNameFromType(typeOf[String])
- val newPath = walkedTypePath.recordArray(clsName)
- createSerializerForMapObjects(input, ObjectType(Class.forName(getClassNameFromType(elementType))),
- serializerFor(_, elementType, newPath, seenTypeSet))
-
-
- case dt =>
- createSerializerForGenericArray(input, dt, nullable = predefinedDt.map(_.nullable).getOrElse(schemaFor(elementType).nullable))
- }
+
+
+ def serializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = {
+ val tpe = getType(cls)
+ val clsName = getClassNameFromType(tpe)
+ val walkedTypePath = WalkedTypePath().recordRoot(clsName)
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ serializerFor(inputObject, tpe, walkedTypePath, predefinedDt = Some(dt))
}
- baseType(tpe) match {
-
- //
- case _ if !inputObject.dataType.isInstanceOf[ObjectType] && !predefinedDt.exists(_.isInstanceOf[ComplexWrapper]) => inputObject
-
- case t if isSubtype(t, localTypeOf[Option[_]]) =>
- val TypeRef(_, _, Seq(optType)) = t
- val className = getClassNameFromType(optType)
- val newPath = walkedTypePath.recordOption(className)
- val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
- serializerFor(unwrapped, optType, newPath, seenTypeSet)
-
- // Since List[_] also belongs to localTypeOf[Product], we put this case before
- // "case t if definedByConstructorParams(t)" to make sure it will match to the
- // case "localTypeOf[Seq[_]]"
- case t if isSubtype(t, localTypeOf[Seq[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- toCatalystArray(inputObject, elementType)
-
- case t if isSubtype(t, localTypeOf[Array[_]]) && predefinedDt.isEmpty =>
- val TypeRef(_, _, Seq(elementType)) = t
- toCatalystArray(inputObject, elementType)
-
- case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
- val keyClsName = getClassNameFromType(keyType)
- val valueClsName = getClassNameFromType(valueType)
- val keyPath = walkedTypePath.recordKeyForMap(keyClsName)
- val valuePath = walkedTypePath.recordValueForMap(valueClsName)
-
- createSerializerForMap(
- inputObject,
- MapElementInformation(
- dataTypeFor(keyType),
- nullable = !keyType.typeSymbol.asClass.isPrimitive,
- serializerFor(_, keyType, keyPath, seenTypeSet)),
- MapElementInformation(
- dataTypeFor(valueType),
- nullable = !valueType.typeSymbol.asClass.isPrimitive,
- serializerFor(_, valueType, valuePath, seenTypeSet))
- )
+ /**
+ * Returns an expression for serializing the value of an input expression into Spark SQL
+ * internal representation.
+ */
+ private def serializerFor(
+ inputObject: Expression,
+ tpe: `Type`,
+ walkedTypePath: WalkedTypePath,
+ seenTypeSet: Set[`Type`] = Set.empty,
+ predefinedDt: Option[DataTypeWithClass] = None,
+ ): Expression = cleanUpReflectionObjects {
+
+ def toCatalystArray(
+ input: Expression,
+ elementType: `Type`,
+ predefinedDt: Option[DataTypeWithClass] = None,
+ ): Expression = {
+ val dataType = predefinedDt
+ .map(_.dt)
+ .getOrElse {
+ dataTypeFor(elementType)
+ }
- case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
+ dataType match {
- // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array.
- // Note that the property of `Set` is only kept when manipulating the data as domain object.
- val newInput =
- Invoke(
- inputObject,
- "toSeq",
- ObjectType(classOf[Seq[_]]))
-
- toCatalystArray(newInput, elementType)
-
- case t if isSubtype(t, localTypeOf[String]) =>
- createSerializerForString(inputObject)
- case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
- createSerializerForJavaInstant(inputObject)
-
- case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
- createSerializerForSqlTimestamp(inputObject)
-
- case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
- createSerializerForJavaLocalDate(inputObject)
-
- case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)
-
- case t if isSubtype(t, localTypeOf[BigDecimal]) =>
- createSerializerForScalaBigDecimal(inputObject)
-
- case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
- createSerializerForJavaBigDecimal(inputObject)
-
- case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
- createSerializerForJavaBigInteger(inputObject)
-
- case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
- createSerializerForScalaBigInt(inputObject)
-
- case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
- createSerializerForInteger(inputObject)
- case t if isSubtype(t, localTypeOf[Int]) =>
- createSerializerForInteger(inputObject)
- case t if isSubtype(t, localTypeOf[java.lang.Long]) => createSerializerForLong(inputObject)
- case t if isSubtype(t, localTypeOf[Long]) => createSerializerForLong(inputObject)
- case t if isSubtype(t, localTypeOf[java.lang.Double]) => createSerializerForDouble(inputObject)
- case t if isSubtype(t, localTypeOf[Double]) => createSerializerForDouble(inputObject)
- case t if isSubtype(t, localTypeOf[java.lang.Float]) => createSerializerForFloat(inputObject)
- case t if isSubtype(t, localTypeOf[Float]) => createSerializerForFloat(inputObject)
- case t if isSubtype(t, localTypeOf[java.lang.Short]) => createSerializerForShort(inputObject)
- case t if isSubtype(t, localTypeOf[Short]) => createSerializerForShort(inputObject)
- case t if isSubtype(t, localTypeOf[java.lang.Byte]) => createSerializerForByte(inputObject)
- case t if isSubtype(t, localTypeOf[Byte]) => createSerializerForByte(inputObject)
- case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => createSerializerForBoolean(inputObject)
- case t if isSubtype(t, localTypeOf[Boolean]) => createSerializerForBoolean(inputObject)
-
- case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
- createSerializerForString(
- Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
-
- case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t)
- .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance()
- val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
- createSerializerForUserDefinedType(inputObject, udt, udtClass)
-
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
- newInstance().asInstanceOf[UserDefinedType[_]]
- val udtClass = udt.getClass
- createSerializerForUserDefinedType(inputObject, udt, udtClass)
- //
-
- case _ if predefinedDt.isDefined =>
- predefinedDt.get match {
- case dataType: KDataTypeWrapper =>
- val cls = dataType.cls
- val properties = getJavaBeanReadableProperties(cls)
- val structFields = dataType.dt.fields.map(_.asInstanceOf[KStructField])
- val fields = structFields.map { structField =>
- val maybeProp = properties.find(it => it.getReadMethod.getName == structField.getterName)
- if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}")
- val fieldName = structField.name
- val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
- val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
- val fieldValue = Invoke(
- inputObject,
- maybeProp.get.getReadMethod.getName,
- inferExternalType(propClass),
- returnNullable = structField.nullable
- )
- val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
- (fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None))
-
- }
- createSerializerForObject(inputObject, fields)
-
- case otherTypeWrapper: ComplexWrapper =>
- otherTypeWrapper.dt match {
- case MapType(kt, vt, _) =>
- val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls).map(getType(_))
- val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass])
+ case dt @ (MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => {
+ val clsName = getClassNameFromType(elementType)
+ val newPath = walkedTypePath.recordArray(clsName)
+ createSerializerForMapObjects(
+ input, ObjectType(predefinedDt.get.cls),
+ serializerFor(_, elementType, newPath, seenTypeSet, predefinedDt)
+ )
+ }
+
+ case dt: ObjectType => {
+ val clsName = getClassNameFromType(elementType)
+ val newPath = walkedTypePath.recordArray(clsName)
+ createSerializerForMapObjects(
+ input, dt,
+ serializerFor(_, elementType, newPath, seenTypeSet)
+ )
+ }
+
+ // case dt: ByteType =>
+ // createSerializerForPrimitiveArray(input, dt)
+
+ case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => {
+ val cls = input.dataType.asInstanceOf[ObjectType].cls
+ if (cls.isArray && cls.getComponentType.isPrimitive) {
+ createSerializerForPrimitiveArray(input, dt)
+ } else {
+ createSerializerForGenericArray(
+ inputObject = input,
+ dataType = dt,
+ nullable = predefinedDt
+ .map(_.nullable)
+ .getOrElse(
+ schemaFor(elementType).nullable
+ ),
+ )
+ }
+ }
+
+ case _: StringType => {
+ val clsName = getClassNameFromType(typeOf[String])
+ val newPath = walkedTypePath.recordArray(clsName)
+ createSerializerForMapObjects(
+ input, ObjectType(Class.forName(getClassNameFromType(elementType))),
+ serializerFor(_, elementType, newPath, seenTypeSet)
+ )
+ }
+
+ case dt => {
+ createSerializerForGenericArray(
+ inputObject = input,
+ dataType = dt,
+ nullable = predefinedDt
+ .map(_.nullable)
+ .getOrElse {
+ schemaFor(elementType).nullable
+ },
+ )
+ }
+ }
+ }
+
+ baseType(tpe) match {
+
+ //
+ case _ if !inputObject.dataType.isInstanceOf[ObjectType] &&
+ !predefinedDt.exists(_.isInstanceOf[ComplexWrapper]) => {
+ inputObject
+ }
+ case t if isSubtype(t, localTypeOf[Option[_]]) => {
+ val TypeRef(_, _, Seq(optType)) = t
+ val className = getClassNameFromType(optType)
+ val newPath = walkedTypePath.recordOption(className)
+ val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
+ serializerFor(unwrapped, optType, newPath, seenTypeSet)
+ }
+
+ // Since List[_] also belongs to localTypeOf[Product], we put this case before
+ // "case t if definedByConstructorParams(t)" to make sure it will match to the
+ // case "localTypeOf[Seq[_]]"
+ case t if isSubtype(t, localTypeOf[Seq[_]]) => {
+ val TypeRef(_, _, Seq(elementType)) = t
+ toCatalystArray(inputObject, elementType)
+ }
+
+ case t if isSubtype(t, localTypeOf[Array[_]]) && predefinedDt.isEmpty => {
+ val TypeRef(_, _, Seq(elementType)) = t
+ toCatalystArray(inputObject, elementType)
+ }
+
+ case t if isSubtype(t, localTypeOf[Map[_, _]]) => {
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
val keyClsName = getClassNameFromType(keyType)
val valueClsName = getClassNameFromType(valueType)
val keyPath = walkedTypePath.recordKeyForMap(keyClsName)
val valuePath = walkedTypePath.recordValueForMap(valueClsName)
createSerializerForMap(
- inputObject,
- MapElementInformation(
- dataTypeFor(keyType),
- nullable = !keyType.typeSymbol.asClass.isPrimitive,
- serializerFor(_, keyType, keyPath, seenTypeSet, Some(keyDT).filter(_.isInstanceOf[ComplexWrapper]))),
- MapElementInformation(
- dataTypeFor(valueType),
- nullable = !valueType.typeSymbol.asClass.isPrimitive,
- serializerFor(_, valueType, valuePath, seenTypeSet, Some(valueDT).filter(_.isInstanceOf[ComplexWrapper])))
+ inputObject,
+ MapElementInformation(
+ dataTypeFor(keyType),
+ nullable = !keyType.typeSymbol.asClass.isPrimitive,
+ serializerFor(_, keyType, keyPath, seenTypeSet)
+ ),
+ MapElementInformation(
+ dataTypeFor(valueType),
+ nullable = !valueType.typeSymbol.asClass.isPrimitive,
+ serializerFor(_, valueType, valuePath, seenTypeSet)
+ )
)
- case ArrayType(elementType, _) =>
- toCatalystArray(inputObject, getType(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass]))
+ }
- case StructType(elementType: Array[StructField]) =>
- val cls = otherTypeWrapper.cls
- val names = elementType.map(_.name)
+ case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) => {
+ val TypeRef(_, _, Seq(elementType)) = t
- val beanInfo = Introspector.getBeanInfo(cls)
- val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))
+ // There's no corresponding Catalyst type for `Set`, so we serialize a `Set` to a Catalyst array.
+ // Note that the property of `Set` is only kept when manipulating the data as a domain object.
+ val newInput =
+ Invoke(
+ inputObject,
+ "toSeq",
+ ObjectType(classOf[Seq[_]])
+ )
+ toCatalystArray(newInput, elementType)
+ }
- val fields = elementType.map { structField =>
+ case t if isSubtype(t, localTypeOf[String]) => {
+ createSerializerForString(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) => {
+ createSerializerForJavaInstant(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => {
+ createSerializerForSqlTimestamp(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => {
+ createSerializerForLocalDateTime(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => {
+ createSerializerForJavaLocalDate(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) => {
+ createSerializerForSqlDate(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Duration]) => {
+ createSerializerForJavaDuration(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Period]) => {
+ createSerializerForJavaPeriod(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[BigDecimal]) => {
+ createSerializerForScalaBigDecimal(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => {
+ createSerializerForJavaBigDecimal(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => {
+ createSerializerForJavaBigInteger(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => {
+ createSerializerForScalaBigInt(inputObject)
+ }
- val maybeProp = methods.find(it => it.getName == structField.name)
- if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}")
- val fieldName = structField.name
- val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
- val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
- val fieldValue = Invoke(
- inputObject,
- maybeProp.get.getName,
- inferExternalType(propClass),
- returnNullable = propDt.nullable
- )
- val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
- (fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None))
+ case t if isSubtype(t, localTypeOf[java.lang.Integer]) => {
+ createSerializerForInteger(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[Int]) => {
+ createSerializerForInteger(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Long]) => {
+ createSerializerForLong(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[Long]) => {
+ createSerializerForLong(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Double]) => {
+ createSerializerForDouble(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[Double]) => {
+ createSerializerForDouble(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Float]) => {
+ createSerializerForFloat(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[Float]) => {
+ createSerializerForFloat(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Short]) => {
+ createSerializerForShort(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[Short]) => {
+ createSerializerForShort(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Byte]) => {
+ createSerializerForByte(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[Byte]) => {
+ createSerializerForByte(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => {
+ createSerializerForBoolean(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[Boolean]) => {
+ createSerializerForBoolean(inputObject)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => {
+ createSerializerForString(
+ Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false)
+ )
+ }
+ case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => {
+ val udt = getClassFromType(t)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance()
+ val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
+ createSerializerForUserDefinedType(inputObject, udt, udtClass)
+ }
+ case t if UDTRegistration.exists(getClassNameFromType(t)) => {
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
+ newInstance().asInstanceOf[UserDefinedType[_]]
+ val udtClass = udt.getClass
+ createSerializerForUserDefinedType(inputObject, udt, udtClass)
+ }
+ //
+
+ case _ if predefinedDt.isDefined => {
+ predefinedDt.get match {
+
+ case dataType: KDataTypeWrapper => {
+ val cls = dataType.cls
+ val properties = getJavaBeanReadableProperties(cls)
+ val structFields = dataType.dt.fields.map(_.asInstanceOf[KStructField])
+ val fields: Array[(String, Expression)] = structFields.map { structField =>
+ val maybeProp = properties.find(it => it.getReadMethod.getName == structField.getterName)
+ if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${
+ structField.name
+ } is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}"
+ )
+ val fieldName = structField.name
+ val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
+ val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
+
+ val fieldValue = Invoke(
+ inputObject,
+ maybeProp.get.getReadMethod.getName,
+ inferExternalType(propClass),
+ returnNullable = structField.nullable
+ )
+ val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
+
+ val tpe = getType(propClass)
+
+ val serializer = serializerFor(
+ inputObject = fieldValue,
+ tpe = tpe,
+ walkedTypePath = newPath,
+ seenTypeSet = seenTypeSet,
+ predefinedDt = if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None
+ )
+
+ (fieldName, serializer)
+ }
+ createSerializerForObject(inputObject, fields)
+ }
+
+ case otherTypeWrapper: ComplexWrapper => {
+
+ otherTypeWrapper.dt match {
+
+ case MapType(kt, vt, _) => {
+ val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls)
+ .map(getType(_))
+ val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass])
+ val keyClsName = getClassNameFromType(keyType)
+ val valueClsName = getClassNameFromType(valueType)
+ val keyPath = walkedTypePath.recordKeyForMap(keyClsName)
+ val valuePath = walkedTypePath.recordValueForMap(valueClsName)
+
+ createSerializerForMap(
+ inputObject,
+ MapElementInformation(
+ dataTypeFor(keyType),
+ nullable = !keyType.typeSymbol.asClass.isPrimitive,
+ serializerFor(
+ _, keyType, keyPath, seenTypeSet, Some(keyDT)
+ .filter(_.isInstanceOf[ComplexWrapper])
+ )
+ ),
+ MapElementInformation(
+ dataTypeFor(valueType),
+ nullable = !valueType.typeSymbol.asClass.isPrimitive,
+ serializerFor(
+ _, valueType, valuePath, seenTypeSet, Some(valueDT)
+ .filter(_.isInstanceOf[ComplexWrapper])
+ )
+ )
+ )
+ }
+
+ case ArrayType(elementType, _) => {
+ toCatalystArray(
+ inputObject,
+ getType(elementType.asInstanceOf[DataTypeWithClass].cls
+ ), Some(elementType.asInstanceOf[DataTypeWithClass])
+ )
+ }
+
+ case StructType(elementType: Array[StructField]) => {
+ val cls = otherTypeWrapper.cls
+ val names = elementType.map(_.name)
+
+ val beanInfo = Introspector.getBeanInfo(cls)
+ val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))
+
+
+ val fields = elementType.map { structField =>
+
+ val maybeProp = methods.find(it => it.getName == structField.name)
+ if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${
+ structField.name
+ } is not found among available props, which are: ${
+ methods.map(_.getName).mkString(", ")
+ }"
+ )
+ val fieldName = structField.name
+ val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
+ val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
+ val fieldValue = Invoke(
+ inputObject,
+ maybeProp.get.getName,
+ inferExternalType(propClass),
+ returnNullable = propDt.nullable
+ )
+ val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
+ (fieldName, serializerFor(
+ fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt
+ .isInstanceOf[ComplexWrapper]) Some(propDt) else None
+ ))
+
+ }
+ createSerializerForObject(inputObject, fields)
+ }
+
+ case _ => {
+ throw new UnsupportedOperationException(
+ s"No Encoder found for $tpe\n" + walkedTypePath
+ )
+ }
+ }
+ }
+ }
+ }
+
+ case t if definedByConstructorParams(t) => {
+ if (seenTypeSet.contains(t)) {
+ throw new UnsupportedOperationException(
+ s"cannot have circular references in class, but got the circular reference of class $t"
+ )
+ }
+
+ val params = getConstructorParameters(t)
+ val fields = params.map { case (fieldName, fieldType) =>
+ if (javaKeywords.contains(fieldName)) {
+ throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
+ "cannot be used as field name\n" + walkedTypePath
+ )
+ }
+
+ // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
+ // is necessary here, because for a nullable nested inputObject with struct data
+ // type, e.g. StructType(IntegerType, StringType), it would return nullable=true
+ // for IntegerType without KnownNotNull, which is not what we expect.
+ val fieldValue = Invoke(
+ KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
+ returnNullable = !fieldType.typeSymbol.asClass.isPrimitive
+ )
+ val clsName = getClassNameFromType(fieldType)
+ val newPath = walkedTypePath.recordField(clsName, fieldName)
+ (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
}
createSerializerForObject(inputObject, fields)
+ }
- case _ =>
+ case _ => {
throw new UnsupportedOperationException(
- s"No Encoder found for $tpe\n" + walkedTypePath)
-
+ s"No Encoder found for $tpe\n" + walkedTypePath
+ )
}
}
+ }
- case t if definedByConstructorParams(t) =>
- if (seenTypeSet.contains(t)) {
- throw new UnsupportedOperationException(
- s"cannot have circular references in class, but got the circular reference of class $t")
- }
+ def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = {
+ Invoke(
+ path, "toString", ObjectType(classOf[java.lang.String]),
+ returnNullable = returnNullable
+ )
+ }
- val params = getConstructorParameters(t)
- val fields = params.map { case (fieldName, fieldType) =>
- if (javaKeywords.contains(fieldName)) {
- throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
- "cannot be used as field name\n" + walkedTypePath)
- }
-
- // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
- // is necessary here. Because for a nullable nested inputObject with struct data
- // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
- // for IntegerType without KnownNotNull. And that's what we do not expect to.
- val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
- returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
- val clsName = getClassNameFromType(fieldType)
- val newPath = walkedTypePath.recordField(clsName, fieldName)
- (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
- }
- createSerializerForObject(inputObject, fields)
+ def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
+ val beanInfo = Introspector.getBeanInfo(beanClass)
+ beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ .filterNot(_.getName == "declaringClass")
+ .filter(_.getReadMethod != null)
+ }
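+ // For a bean-style class exposing getName()/getAge(), for example, this returns the
+ // PropertyDescriptors for `name` and `age`; the synthetic "class" and "declaringClass"
+ // properties are filtered out.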
+
+ /*
+ * Retrieves the runtime class corresponding to the provided type.
+ */
+ def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
+
+ case class Schema(dataType: DataType, nullable: Boolean)
+
+ /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+ def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects {
+
+ baseType(tpe) match {
+ // this must be the first case, since Null is a subtype of every reference type in Scala;
+ // otherwise the Null type would wrongly match one of the cases below (Option, as of now)
+ case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable = true)
- case _ =>
- throw new UnsupportedOperationException(
- s"No Encoder found for $tpe\n" + walkedTypePath)
+ case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => {
+ val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
+ getConstructor().newInstance()
+ Schema(udt, nullable = true)
+ }
+ case t if UDTRegistration.exists(getClassNameFromType(t)) => {
+ val udt = UDTRegistration
+ .getUDTFor(getClassNameFromType(t))
+ .get
+ .getConstructor()
+ .newInstance()
+ .asInstanceOf[UserDefinedType[_]]
+ Schema(udt, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[Option[_]]) => {
+ val TypeRef(_, _, Seq(optType)) = t
+ Schema(schemaFor(optType).dataType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[Array[Byte]]) => {
+ Schema(BinaryType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[Array[_]]) => {
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[Seq[_]]) => {
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[Map[_, _]]) => {
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+ Schema(
+ MapType(
+ schemaFor(keyType).dataType,
+ valueDataType, valueContainsNull = valueNullable
+ ), nullable = true
+ )
+ }
+ case t if isSubtype(t, localTypeOf[Set[_]]) => {
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[String]) => {
+ Schema(StringType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) => {
+ Schema(TimestampType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => {
+ Schema(TimestampType, nullable = true)
+ }
+ // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes.
+ case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) && Utils.isTesting => {
+ Schema(TimestampNTZType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => {
+ Schema(DateType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) => {
+ Schema(DateType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[CalendarInterval]) => {
+ Schema(CalendarIntervalType, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Duration]) => {
+ Schema(DayTimeIntervalType(), nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.time.Period]) => {
+ Schema(YearMonthIntervalType(), nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[BigDecimal]) => {
+ Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => {
+ Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => {
+ Schema(DecimalType.BigIntDecimal, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => {
+ Schema(DecimalType.BigIntDecimal, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[Decimal]) => {
+ Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+ }
+ case t if isSubtype(t, localTypeOf[java.lang.Integer]) => Schema(IntegerType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Double]) => Schema(DoubleType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Float]) => Schema(FloatType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true)
+ case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false)
+ case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false)
+ case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false)
+ case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, nullable = false)
+ case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false)
+ case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false)
+ case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false)
+ case t if definedByConstructorParams(t) => {
+ val params = getConstructorParameters(t)
+ Schema(
+ StructType(
+ params.map { case (fieldName, fieldType) =>
+ val Schema(dataType, nullable) = schemaFor(fieldType)
+ StructField(fieldName, dataType, nullable)
+ }
+ ), nullable = true
+ )
+ }
+ case other => {
+ throw new UnsupportedOperationException(s"Schema for type $other is not supported")
+ }
+ }
}
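+
+ // For instance, schemaFor(localTypeOf[Seq[Int]]) yields
+ // Schema(ArrayType(IntegerType, containsNull = false), nullable = true).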
- }
-
- def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = {
- Invoke(path, "toString", ObjectType(classOf[java.lang.String]),
- returnNullable = returnNullable)
- }
-
- def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
- val beanInfo = Introspector.getBeanInfo(beanClass)
- beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
- .filterNot(_.getName == "declaringClass")
- .filter(_.getReadMethod != null)
- }
-
- /*
- * Retrieves the runtime class corresponding to the provided type.
- */
- def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
-
- case class Schema(dataType: DataType, nullable: Boolean)
-
- /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects {
- baseType(tpe) match {
- // this must be the first case, since all objects in scala are instances of Null, therefore
- // Null type would wrongly match the first of them, which is Option as of now
- case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable = true)
- case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
- getConstructor().newInstance()
- Schema(udt, nullable = true)
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
- newInstance().asInstanceOf[UserDefinedType[_]]
- Schema(udt, nullable = true)
- case t if isSubtype(t, localTypeOf[Option[_]]) =>
- val TypeRef(_, _, Seq(optType)) = t
- Schema(schemaFor(optType).dataType, nullable = true)
- case t if isSubtype(t, localTypeOf[Array[Byte]]) => Schema(BinaryType, nullable = true)
- case t if isSubtype(t, localTypeOf[Array[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
- Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if isSubtype(t, localTypeOf[Seq[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
- Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- Schema(MapType(schemaFor(keyType).dataType,
- valueDataType, valueContainsNull = valueNullable), nullable = true)
- case t if isSubtype(t, localTypeOf[Set[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
- Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if isSubtype(t, localTypeOf[String]) => Schema(StringType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
- Schema(TimestampType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
- Schema(TimestampType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
- case t if isSubtype(t, localTypeOf[BigDecimal]) =>
- Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
- case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
- Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
- case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
- Schema(DecimalType.BigIntDecimal, nullable = true)
- case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
- Schema(DecimalType.BigIntDecimal, nullable = true)
- case t if isSubtype(t, localTypeOf[Decimal]) =>
- Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
- case t if isSubtype(t, localTypeOf[java.lang.Integer]) => Schema(IntegerType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.lang.Double]) => Schema(DoubleType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.lang.Float]) => Schema(FloatType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true)
- case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true)
- case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false)
- case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false)
- case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false)
- case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, nullable = false)
- case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false)
- case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false)
- case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false)
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
- Schema(StructType(
- params.map { case (fieldName, fieldType) =>
- val Schema(dataType, nullable) = schemaFor(fieldType)
- StructField(fieldName, dataType, nullable)
- }), nullable = true)
- case other =>
- throw new UnsupportedOperationException(s"Schema for type $other is not supported")
+
+ /**
+ * Whether the fields of the given type are defined entirely by its constructor parameters.
+ */
+ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects {
+ tpe.dealias match {
+ // `Option` is a `Product`, but we don't want to treat `Option[Int]` as a struct type.
+ case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head)
+ case _ => {
+ isSubtype(tpe.dealias, localTypeOf[Product]) ||
+ isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams])
+ }
+ }
}
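+
+ // For example, a Tuple2[Int, String] is a Product and is therefore defined by its
+ // constructor parameters, while Option[Int] delegates to Int and is not; Options are
+ // handled by the dedicated Option branch of the serializer instead.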
- }
-
- /**
- * Whether the fields of the given type is defined entirely by its constructor parameters.
- */
- def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects {
- tpe.dealias match {
- // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type.
- case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head)
- case _ => isSubtype(tpe.dealias, localTypeOf[Product]) ||
- isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams])
+
+ private val javaKeywords = Set(
+ "abstract", "assert", "boolean", "break", "byte", "case", "catch",
+ "char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false",
+ "final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
+ "interface", "long", "native", "new", "null", "package", "private", "protected", "public",
+ "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
+ "throws", "transient", "true", "try", "void", "volatile", "while"
+ )
+
+
+ @scala.annotation.tailrec
+ def javaBoxedType(dt: DataType): Class[_] = dt match {
+ case _: DecimalType => classOf[Decimal]
+ case _: DayTimeIntervalType => classOf[java.lang.Long]
+ case _: YearMonthIntervalType => classOf[java.lang.Integer]
+ case BinaryType => classOf[Array[Byte]]
+ case StringType => classOf[UTF8String]
+ case CalendarIntervalType => classOf[CalendarInterval]
+ case _: StructType => classOf[InternalRow]
+ case _: ArrayType => classOf[ArrayType]
+ case _: MapType => classOf[MapType]
+ case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
+ case ObjectType(cls) => cls
+ case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
}
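+
+ // e.g. javaBoxedType(IntegerType) resolves to java.lang.Integer via
+ // ScalaReflection.typeBoxedJavaMapping, while javaBoxedType(ObjectType(cls)) is simply cls.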
- }
-
- private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
- "char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false",
- "final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
- "interface", "long", "native", "new", "null", "package", "private", "protected", "public",
- "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
- "throws", "transient", "true", "try", "void", "volatile", "while")
-
-
- @scala.annotation.tailrec
- def javaBoxedType(dt: DataType): Class[_] = dt match {
- case _: DecimalType => classOf[Decimal]
- case BinaryType => classOf[Array[Byte]]
- case StringType => classOf[UTF8String]
- case CalendarIntervalType => classOf[CalendarInterval]
- case _: StructType => classOf[InternalRow]
- case _: ArrayType => classOf[ArrayType]
- case _: MapType => classOf[MapType]
- case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
- case ObjectType(cls) => cls
- case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
- }
}
@@ -991,120 +1278,124 @@ object KotlinReflection extends KotlinReflection {
* object, this trait able to work in both the runtime and the compile time (macro) universe.
*/
trait KotlinReflection extends Logging {
- /** The universe we work in (runtime or macro) */
- val universe: scala.reflect.api.Universe
-
- /** The mirror used to access types in the universe */
- def mirror: universe.Mirror
-
- import universe._
-
- // The Predef.Map is scala.collection.immutable.Map.
- // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
-
- /**
- * Any codes calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to
- * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to
- * `scala.reflect.runtime.JavaUniverse.undoLog`.
- *
- * @see https://github.com/scala/bug/issues/8302
- */
- def cleanUpReflectionObjects[T](func: => T): T = {
- universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func)
- }
-
- /**
- * Return the Scala Type for `T` in the current classloader mirror.
- *
- * Use this method instead of the convenience method `universe.typeOf`, which
- * assumes that all types can be found in the classloader that loaded scala-reflect classes.
- * That's not necessarily the case when running using Eclipse launchers or even
- * Sbt console or test (without `fork := true`).
- *
- * @see SPARK-5281
- */
- def localTypeOf[T: TypeTag]: `Type` = {
- val tag = implicitly[TypeTag[T]]
- tag.in(mirror).tpe.dealias
- }
-
- /**
- * Returns the full class name for a type. The returned name is the canonical
- * Scala name, where each component is separated by a period. It is NOT the
- * Java-equivalent runtime name (no dollar signs).
- *
- * In simple cases, both the Scala and Java names are the same, however when Scala
- * generates constructs that do not map to a Java equivalent, such as singleton objects
- * or nested classes in package objects, it uses the dollar sign ($) to create
- * synthetic classes, emulating behaviour in Java bytecode.
- */
- def getClassNameFromType(tpe: `Type`): String = {
- tpe.dealias.erasure.typeSymbol.asClass.fullName
- }
-
- /**
- * Returns the parameter names and types for the primary constructor of this type.
- *
- * Note that it only works for scala classes with primary constructor, and currently doesn't
- * support inner class.
- */
- def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
- val dealiasedTpe = tpe.dealias
- val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
- val params = constructParams(dealiasedTpe)
- // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
- if (actualTypeArgs.nonEmpty) {
- params.map { p =>
- p.name.decodedName.toString ->
- p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- }
- } else {
- params.map { p =>
- p.name.decodedName.toString -> p.typeSignature
- }
+ /** The universe we work in (runtime or macro) */
+ val universe: scala.reflect.api.Universe
+
+ /** The mirror used to access types in the universe */
+ def mirror: universe.Mirror
+
+ import universe._
+
+ // The Predef.Map is scala.collection.immutable.Map.
+ // Since the map values can be mutable, we explicitly import scala.collection.Map here.
+
+ /**
+ * Any code calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to
+ * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to
+ * `scala.reflect.runtime.JavaUniverse.undoLog`.
+ *
+ * @see https://github.com/scala/bug/issues/8302
+ */
+ def cleanUpReflectionObjects[T](func: => T): T = {
+ universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func)
}
- }
-
- /**
- * If our type is a Scala trait it may have a companion object that
- * only defines a constructor via `apply` method.
- */
- private def getCompanionConstructor(tpe: Type): Symbol = {
- def throwUnsupportedOperation = {
- throw new UnsupportedOperationException(s"Unable to find constructor for $tpe. " +
- s"This could happen if $tpe is an interface, or a trait without companion object " +
- "constructor.")
+
+ /**
+ * Return the Scala Type for `T` in the current classloader mirror.
+ *
+ * Use this method instead of the convenience method `universe.typeOf`, which
+ * assumes that all types can be found in the classloader that loaded scala-reflect classes.
+ * That's not necessarily the case when running using Eclipse launchers or even
+ * Sbt console or test (without `fork := true`).
+ *
+ * @see SPARK-5281
+ */
+ def localTypeOf[T: TypeTag]: `Type` = {
+ val tag = implicitly[TypeTag[T]]
+ tag.in(mirror).tpe.dealias
}
- tpe.typeSymbol.asClass.companion match {
- case NoSymbol => throwUnsupportedOperation
- case sym => sym.asTerm.typeSignature.member(universe.TermName("apply")) match {
- case NoSymbol => throwUnsupportedOperation
- case constructorSym => constructorSym
- }
+ /**
+ * Returns the full class name for a type. The returned name is the canonical
+ * Scala name, where each component is separated by a period. It is NOT the
+ * Java-equivalent runtime name (no dollar signs).
+ *
+ * In simple cases, both the Scala and Java names are the same, however when Scala
+ * generates constructs that do not map to a Java equivalent, such as singleton objects
+ * or nested classes in package objects, it uses the dollar sign ($) to create
+ * synthetic classes, emulating behaviour in Java bytecode.
+ */
+ def getClassNameFromType(tpe: `Type`): String = {
+ tpe.dealias.erasure.typeSymbol.asClass.fullName
}
- }
- protected def constructParams(tpe: Type): Seq[Symbol] = {
- val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match {
- case NoSymbol => getCompanionConstructor(tpe)
- case sym => sym
+ /**
+ * Returns the parameter names and types for the primary constructor of this type.
+ *
+ * Note that it only works for Scala classes with a primary constructor, and currently doesn't
+ * support inner classes.
+ */
+ def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
+ val dealiasedTpe = tpe.dealias
+ val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
+ val params = constructParams(dealiasedTpe)
+ // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
+ if (actualTypeArgs.nonEmpty) {
+ params.map { p =>
+ p.name.decodedName.toString ->
+ p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ }
+ } else {
+ params.map { p =>
+ p.name.decodedName.toString -> p.typeSignature
+ }
+ }
}
- val params = if (constructorSymbol.isMethod) {
- constructorSymbol.asMethod.paramLists
- } else {
- // Find the primary constructor, and use its parameter ordering.
- val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
- s => s.isMethod && s.asMethod.isPrimaryConstructor)
- if (primaryConstructorSymbol.isEmpty) {
- sys.error("Internal SQL error: Product object did not have a primary constructor.")
- } else {
- primaryConstructorSymbol.get.asMethod.paramLists
- }
+
+ /**
+ * If our type is a Scala trait it may have a companion object that
+ * only defines a constructor via an `apply` method.
+ */
+ private def getCompanionConstructor(tpe: Type): Symbol = {
+ def throwUnsupportedOperation = {
+ throw new UnsupportedOperationException(s"Unable to find constructor for $tpe. " +
+ s"This could happen if $tpe is an interface, or a trait without companion object " +
+ "constructor."
+ )
+ }
+
+ tpe.typeSymbol.asClass.companion match {
+ case NoSymbol => throwUnsupportedOperation
+ case sym => {
+ sym.asTerm.typeSignature.member(universe.TermName("apply")) match {
+ case NoSymbol => throwUnsupportedOperation
+ case constructorSym => constructorSym
+ }
+ }
+ }
+ }
+
+ protected def constructParams(tpe: Type): Seq[Symbol] = {
+ val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match {
+ case NoSymbol => getCompanionConstructor(tpe)
+ case sym => sym
+ }
+ val params = if (constructorSymbol.isMethod) {
+ constructorSymbol.asMethod.paramLists
+ } else {
+ // Find the primary constructor, and use its parameter ordering.
+ val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
+ s => s.isMethod && s.asMethod.isPrimaryConstructor
+ )
+ if (primaryConstructorSymbol.isEmpty) {
+ sys.error("Internal SQL error: Product object did not have a primary constructor.")
+ } else {
+ primaryConstructorSymbol.get.asMethod.paramLists
+ }
+ }
+ params.flatten
}
- params.flatten
- }
}
diff --git a/examples/pom-3.2_2.12.xml b/examples/pom-3.2_2.12.xml
index b5267352..5f214b69 100644
--- a/examples/pom-3.2_2.12.xml
+++ b/examples/pom-3.2_2.12.xml
@@ -24,6 +24,11 @@
spark-sql_${scala.compat.version}
${spark3.version}
+
+ org.apache.spark
+ spark-streaming_${scala.compat.version}
+ ${spark3.version}
+
diff --git a/kotlin-spark-api/3.2/pom_2.12.xml b/kotlin-spark-api/3.2/pom_2.12.xml
index 756d9c2b..826547d2 100644
--- a/kotlin-spark-api/3.2/pom_2.12.xml
+++ b/kotlin-spark-api/3.2/pom_2.12.xml
@@ -36,6 +36,12 @@
${spark3.version}
provided
+
+ org.apache.spark
+ spark-streaming_${scala.compat.version}
+ ${spark3.version}
+ provided
+
diff --git a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
index 1061a21a..fb1b5340 100644
--- a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
+++ b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
@@ -21,11 +21,11 @@
package org.jetbrains.kotlinx.spark.api
-import org.apache.hadoop.shaded.org.apache.commons.math3.exception.util.ArgUtils
import org.apache.spark.SparkContext
-import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.api.java.*
import org.apache.spark.api.java.function.*
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.*
import org.apache.spark.sql.Encoders.*
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -33,17 +33,21 @@ import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.*
+import org.apache.spark.unsafe.types.CalendarInterval
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
import scala.Product
import scala.Tuple2
+import scala.concurrent.duration.`Duration$`
import scala.reflect.ClassTag
-import scala.reflect.api.TypeTags.TypeTag
+import scala.reflect.api.StandardDefinitions
import java.beans.PropertyDescriptor
import java.math.BigDecimal
import java.sql.Date
import java.sql.Timestamp
+import java.time.Duration
import java.time.Instant
import java.time.LocalDate
+import java.time.Period
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.Any
@@ -95,10 +99,12 @@ val ENCODERS: Map<KClass<*>, Encoder<*>> = mapOf(
String::class to STRING(),
BigDecimal::class to DECIMAL(),
Date::class to DATE(),
- LocalDate::class to LOCALDATE(), // 3.0 only
+ LocalDate::class to LOCALDATE(), // 3.0+
Timestamp::class to TIMESTAMP(),
- Instant::class to INSTANT(), // 3.0 only
- ByteArray::class to BINARY()
+ Instant::class to INSTANT(), // 3.0+
+ ByteArray::class to BINARY(),
+ Duration::class to DURATION(), // 3.2+
+ Period::class to PERIOD(), // 3.2+
)
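+// With these entries, e.g. dsOf(Duration.ofSeconds(1)) or dsOf(Period.ofMonths(2)) inside withSpark { }
+// should pick up the new DURATION() and PERIOD() encoders (see the 3.2 tests below).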
@@ -154,6 +160,18 @@ inline fun <reified T> SparkSession.dsOf(vararg t: T): Dataset<T> =
inline fun <reified T> List<T>.toDS(spark: SparkSession): Dataset<T> =
spark.createDataset(this, encoder())
+/**
+ * Utility method to create a [Dataset] from an [RDD].
+ */
+inline fun <reified T> RDD<T>.toDS(spark: SparkSession): Dataset<T> =
+ spark.createDataset(this, encoder())
+
+/**
+ * Utility method to create a [Dataset] from a [JavaRDDLike] (e.g. a [JavaRDD]).
+ */
+inline fun <reified T> JavaRDDLike<T, *>.toDS(spark: SparkSession): Dataset<T> =
+ spark.createDataset(this.rdd(), encoder())
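+// Illustrative usage, assuming a SparkSession `spark` and a JavaSparkContext `sc`:
+//   val ds: Dataset<Int> = sc.parallelize(listOf(1, 2, 3)).toDS(spark)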
+
/**
* Main method of API, which gives you seamless integration with Spark:
* It creates encoder for any given supported type T
@@ -177,12 +195,16 @@ fun <T> generateEncoder(type: KType, cls: KClass<*>): Encoder<T> {
} as Encoder<T>
}
-private fun isSupportedClass(cls: KClass<*>): Boolean =
- cls.isData
- || cls.isSubclassOf(Map::class)
- || cls.isSubclassOf(Iterable::class)
- || cls.isSubclassOf(Product::class)
- || cls.java.isArray
+private fun isSupportedClass(cls: KClass<*>): Boolean = when {
+ cls == ByteArray::class -> false // uses binary encoder
+ cls.isData -> true
+ cls.isSubclassOf(Map::class) -> true
+ cls.isSubclassOf(Iterable::class) -> true
+ cls.isSubclassOf(Product::class) -> true
+ cls.java.isArray -> true
+ else -> false
+ }
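+// In other words: data classes, Map/Iterable implementations, Scala Products (Tuples, Arities)
+// and arrays are built reflectively, while ByteArray falls through to the predefined BINARY()
+// encoder from ENCODERS.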
+
private fun <T> kotlinClassEncoder(schema: DataType, kClass: KClass<*>): Encoder<T> {
return ExpressionEncoder(
@@ -1192,7 +1214,7 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
DoubleArray::class -> typeOf<Double>()
BooleanArray::class -> typeOf<Boolean>()
ShortArray::class -> typeOf<Short>()
- ByteArray::class -> typeOf<Byte>()
+// ByteArray::class -> typeOf<Byte>() handled by BinaryType
else -> types.getValue(klass.typeParameters[0].name)
}
} else types.getValue(klass.typeParameters[0].name)
@@ -1290,10 +1312,14 @@ private val knownDataTypes: Map<KClass<out Any>, DataType> = mapOf(
Float::class to DataTypes.FloatType,
Double::class to DataTypes.DoubleType,
String::class to DataTypes.StringType,
- LocalDate::class to `DateType$`.`MODULE$`,
- Date::class to `DateType$`.`MODULE$`,
- Timestamp::class to `TimestampType$`.`MODULE$`,
- Instant::class to `TimestampType$`.`MODULE$`,
+ LocalDate::class to DataTypes.DateType,
+ Date::class to DataTypes.DateType,
+ Timestamp::class to DataTypes.TimestampType,
+ Instant::class to DataTypes.TimestampType,
+ ByteArray::class to DataTypes.BinaryType,
+ Decimal::class to DecimalType.SYSTEM_DEFAULT(),
+ BigDecimal::class to DecimalType.SYSTEM_DEFAULT(),
+ CalendarInterval::class to DataTypes.CalendarIntervalType,
)
private fun transitiveMerge(a: Map<String, KType>, b: Map<String, KType>): Map<String, KType> {
diff --git a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
index 6188daae..98fdae8d 100644
--- a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
+++ b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
@@ -20,6 +20,11 @@
package org.jetbrains.kotlinx.spark.api
import org.apache.spark.SparkConf
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.JavaRDDLike
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
import org.apache.spark.sql.SparkSession.Builder
import org.apache.spark.sql.UDFRegistration
import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
@@ -78,18 +83,38 @@ inline fun withSpark(builder: Builder, logLevel: SparkLogLevel = ERROR, func: KS
KSparkSession(this).apply {
sparkContext.setLogLevel(logLevel)
func()
+ spark.stop()
}
}
- .also { it.stop() }
+}
+
+/**
+ * Wrapper for Spark session creation which copies its parameters from [sparkConf].
+ *
+ * @param sparkConf [SparkConf] whose options are applied to the session being built.
+ * @param logLevel Controls the log level. This overrides any user-defined log settings.
+ * @param func Function executed in the context of [KSparkSession] (`this` inside the block refers to the [KSparkSession]).
+ */
+@JvmOverloads
+inline fun withSpark(sparkConf: SparkConf, logLevel: SparkLogLevel = ERROR, func: KSparkSession.() -> Unit) {
+ withSpark(
+ builder = SparkSession.builder().config(sparkConf),
+ logLevel = logLevel,
+ func = func,
+ )
}
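+
+// Illustrative only, e.g.:
+//   withSpark(sparkConf = SparkConf().setAppName("example").setMaster("local[*]")) {
+//       dsOf(1, 2, 3).show()
+//   }
+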
/**
* This wrapper over [SparkSession] which provides several additional methods to create [org.apache.spark.sql.Dataset]
*/
-@Suppress("EXPERIMENTAL_FEATURE_WARNING", "unused")
-inline class KSparkSession(val spark: SparkSession) {
+class KSparkSession(val spark: SparkSession) {
+
+ val sc: JavaSparkContext by lazy { JavaSparkContext(spark.sparkContext) }
+
inline fun <reified T> List<T>.toDS() = toDS(spark)
inline fun <reified T> Array<T>.toDS() = spark.dsOf(*this)
inline fun <reified T> dsOf(vararg arg: T) = spark.dsOf(*arg)
+ inline fun <reified T> RDD<T>.toDS() = toDS(spark)
+ inline fun <reified T> JavaRDDLike<T, *>.toDS() = toDS(spark)
val udf: UDFRegistration get() = spark.udf()
}
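+
+// Illustrative only: inside withSpark { } the new members can be combined, e.g.
+//   withSpark {
+//       sc.parallelize(listOf(1, 2, 3)).toDS().show()
+//   }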
diff --git a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt
index a4b2bdd7..af870038 100644
--- a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt
+++ b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt
@@ -22,32 +22,34 @@
*/
package org.jetbrains.kotlinx.spark.api
-data class Arity1<T1>(val _1: T1)
-data class Arity2<T1, T2>(val _1: T1, val _2: T2)
-data class Arity3<T1, T2, T3>(val _1: T1, val _2: T2, val _3: T3)
-data class Arity4<T1, T2, T3, T4>(val _1: T1, val _2: T2, val _3: T3, val _4: T4)
-data class Arity5<T1, T2, T3, T4, T5>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5)
-data class Arity6<T1, T2, T3, T4, T5, T6>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6)
-data class Arity7<T1, T2, T3, T4, T5, T6, T7>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7)
-data class Arity8<T1, T2, T3, T4, T5, T6, T7, T8>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8)
-data class Arity9<T1, T2, T3, T4, T5, T6, T7, T8, T9>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9)
-data class Arity10<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10)
-data class Arity11<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11)
-data class Arity12<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12)
-data class Arity13<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13)
-data class Arity14<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14)
-data class Arity15<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15)
-data class Arity16<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16)
-data class Arity17<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17)
-data class Arity18<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18)
-data class Arity19<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19)
-data class Arity20<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20)
-data class Arity21<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21)
-data class Arity22<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22)
-data class Arity23<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23)
-data class Arity24<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24)
-data class Arity25<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24, T25>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: T25)
-data class Arity26<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24, T25, T26>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: T25, val _26: T26)
+import java.io.Serializable
+
+data class Arity1<T1>(val _1: T1): Serializable
+data class Arity2<T1, T2>(val _1: T1, val _2: T2): Serializable
+data class Arity3<T1, T2, T3>(val _1: T1, val _2: T2, val _3: T3): Serializable
+data class Arity4<T1, T2, T3, T4>(val _1: T1, val _2: T2, val _3: T3, val _4: T4): Serializable
+data class Arity5<T1, T2, T3, T4, T5>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5): Serializable
+data class Arity6<T1, T2, T3, T4, T5, T6>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6): Serializable
+data class Arity7<T1, T2, T3, T4, T5, T6, T7>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7): Serializable
+data class Arity8<T1, T2, T3, T4, T5, T6, T7, T8>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8): Serializable
+data class Arity9<T1, T2, T3, T4, T5, T6, T7, T8, T9>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9): Serializable
+data class Arity10<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10): Serializable
+data class Arity11<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11): Serializable
+data class Arity12<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12): Serializable
+data class Arity13<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13): Serializable
+data class Arity14<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14): Serializable
+data class Arity15<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15): Serializable
+data class Arity16<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16): Serializable
+data class Arity17<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17): Serializable
+data class Arity18<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18): Serializable
+data class Arity19<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19): Serializable
+data class Arity20<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20): Serializable
+data class Arity21<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21): Serializable
+data class Arity22<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22): Serializable
+data class Arity23<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23): Serializable
+data class Arity24<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24): Serializable
+data class Arity25<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24, T25>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: T25): Serializable
+data class Arity26<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, T23, T24, T25, T26>(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: T25, val _26: T26): Serializable
fun <T1> c(_1: T1) = Arity1(_1)
fun <T1, T2> c(_1: T1, _2: T2) = Arity2(_1, _2)
fun <T1, T2, T3> c(_1: T1, _2: T2, _3: T3) = Arity3(_1, _2, _3)
diff --git a/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
index ed784b13..fa320c39 100644
--- a/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
+++ b/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
@@ -20,21 +20,32 @@ package org.jetbrains.kotlinx.spark.api/*-
import ch.tutteli.atrium.api.fluent.en_GB.*
import ch.tutteli.atrium.api.verbs.expect
import io.kotest.core.spec.style.ShouldSpec
+import io.kotest.matchers.should
import io.kotest.matchers.shouldBe
+import org.apache.spark.api.java.JavaDoubleRDD
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.*
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
+import org.apache.spark.sql.types.Decimal
+import org.apache.spark.unsafe.types.CalendarInterval
import scala.Product
import scala.Tuple1
import scala.Tuple2
import scala.Tuple3
import scala.collection.Seq
import java.io.Serializable
+import java.math.BigDecimal
import java.sql.Date
import java.sql.Timestamp
+import java.time.Duration
import java.time.Instant
import java.time.LocalDate
+import java.time.Period
import kotlin.collections.Iterator
import scala.collection.Iterator as ScalaIterator
import scala.collection.Map as ScalaMap
@@ -318,24 +329,96 @@ class ApiTest : ShouldSpec({
cogrouped.count() shouldBe 4
}
should("handle LocalDate Datasets") { // uses encoder
- val dataset: Dataset<LocalDate> = dsOf(LocalDate.now(), LocalDate.now())
- dataset.show()
+ val dates = listOf(LocalDate.now(), LocalDate.now())
+ val dataset: Dataset<LocalDate> = dates.toDS()
+ dataset.collectAsList() shouldBe dates
}
should("handle Instant Datasets") { // uses encoder
- val dataset: Dataset<Instant> = dsOf(Instant.now(), Instant.now())
- dataset.show()
+ val instants = listOf(Instant.now(), Instant.now())
+ val dataset: Dataset<Instant> = instants.toDS()
+ dataset.collectAsList() shouldBe instants
+ }
+ should("Be able to serialize Instant") { // uses knownDataTypes
+ val instantPair = Instant.now() to Instant.now()
+ val dataset = dsOf(instantPair)
+ dataset.collectAsList() shouldBe listOf(instantPair)
}
should("be able to serialize Date") { // uses knownDataTypes
- val dataset: Dataset<Pair<Date, Int>> = dsOf(Date.valueOf("2020-02-10") to 5)
- dataset.show()
+ val datePair = Date.valueOf("2020-02-10") to 5
+ val dataset: Dataset<Pair<Date, Int>> = dsOf(datePair)
+ dataset.collectAsList() shouldBe listOf(datePair)
}
should("handle Timestamp Datasets") { // uses encoder
- val dataset = dsOf(Timestamp(0L))
- dataset.show()
+ val timeStamps = listOf(Timestamp(0L), Timestamp(1L))
+ val dataset = timeStamps.toDS()
+ dataset.collectAsList() shouldBe timeStamps
}
should("be able to serialize Timestamp") { // uses knownDataTypes
- val dataset = dsOf(Timestamp(0L) to 2)
- dataset.show()
+ val timestampPair = Timestamp(0L) to 2
+ val dataset = dsOf(timestampPair)
+ dataset.collectAsList() shouldBe listOf(timestampPair)
+ }
+ should("handle Duration Datasets") { // uses encoder
+ val dataset = dsOf(Duration.ZERO)
+ dataset.collectAsList() shouldBe listOf(Duration.ZERO)
+ }
+ should("handle Period Datasets") { // uses encoder
+ val periods = listOf(Period.ZERO, Period.ofDays(2))
+ val dataset = periods.toDS()
+
+ dataset.show(false)
+
+ dataset.collectAsList().let {
+ it[0] shouldBe Period.ZERO
+
+ // NOTE Spark truncates java.time.Period to months.
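+ // (Presumably the Period encoder targets Spark's year-month interval type,
+ // so the day component does not survive the round trip.)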
+ it[1] shouldBe Period.ofDays(0)
+ }
+
+ }
+ should("handle binary datasets") { // uses encoder
+ val byteArray = "Hello there".encodeToByteArray()
+ val dataset = dsOf(byteArray)
+ dataset.collectAsList() shouldBe listOf(byteArray)
+ }
+ should("be able to serialize binary") { // uses knownDataTypes
+ val byteArrayTriple = c("Hello there".encodeToByteArray(), 1, intArrayOf(1, 2, 3))
+ val dataset = dsOf(byteArrayTriple)
+
+ val (a, b, c) = dataset.collectAsList().single()
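+ // Arrays compare by reference in Kotlin, so check contents explicitly instead of using shouldBe.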
+ a contentEquals "Hello there".encodeToByteArray() shouldBe true
+ b shouldBe 1
+ c contentEquals intArrayOf(1, 2, 3) shouldBe true
+ }
+ should("be able to serialize Decimal") { // uses knownDataTypes
+ val decimalPair = c(Decimal().set(50), 12)
+ val dataset = dsOf(decimalPair)
+ dataset.collectAsList() shouldBe listOf(decimalPair)
+ }
+ should("handle BigDecimal datasets") { // uses encoder
+ val decimals = listOf(BigDecimal.ONE, BigDecimal.TEN)
+ val dataset = decimals.toDS()
+ dataset.collectAsList().let { (one, ten) ->
+ one.compareTo(BigDecimal.ONE) shouldBe 0
+ ten.compareTo(BigDecimal.TEN) shouldBe 0
+ }
+ }
+ should("be able to serialize BigDecimal") { // uses knownDataTypes
+ val decimalPair = c(BigDecimal.TEN, 12)
+ val dataset = dsOf(decimalPair)
+ val (a, b) = dataset.collectAsList().single()
+ a.compareTo(BigDecimal.TEN) shouldBe 0
+ b shouldBe 12
+ }
+ should("be able to serialize CalendarInterval") { // uses knownDataTypes
+ val calendarIntervalPair = CalendarInterval(1, 0, 0L) to 2
+ val dataset = dsOf(calendarIntervalPair)
+ dataset.collectAsList() shouldBe listOf(calendarIntervalPair)
+ }
+ should("handle nullable datasets") {
+ val ints = listOf(1, 2, 3, null)
+ val dataset = ints.toDS()
+ dataset.collectAsList() shouldBe ints
}
should("Be able to serialize Scala Tuples including data classes") {
val dataset = dsOf(
@@ -366,20 +449,20 @@ class ApiTest : ShouldSpec({
val newDS1WithAs: Dataset<IntArray> = dataset.selectTyped(
col("a").`as`<IntArray>(),
)
- newDS1WithAs.show()
+ newDS1WithAs.collectAsList()
val newDS2: Dataset<Pair<IntArray, Int>> = dataset.selectTyped(
col(SomeClass::a), // NOTE: this only works on 3.0, returning a data class with an array in it
col(SomeClass::b),
)
- newDS2.show()
+ newDS2.collectAsList()
val newDS3: Dataset<Triple<IntArray, Int, Int>> = dataset.selectTyped(
col(SomeClass::a),
col(SomeClass::b),
col(SomeClass::b),
)
- newDS3.show()
+ newDS3.collectAsList()
val newDS4: Dataset<Arity4<IntArray, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::a),
@@ -387,7 +470,7 @@ class ApiTest : ShouldSpec({
col(SomeClass::b),
col(SomeClass::b),
)
- newDS4.show()
+ newDS4.collectAsList()
val newDS5: Dataset<Arity5<IntArray, Int, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::a),
@@ -396,7 +479,7 @@ class ApiTest : ShouldSpec({
col(SomeClass::b),
col(SomeClass::b),
)
- newDS5.show()
+ newDS5.collectAsList()
}
should("Access columns using invoke on datasets") {
val dataset = dsOf(
@@ -449,19 +532,18 @@ class ApiTest : ShouldSpec({
dataset(SomeOtherClass::a),
col(SomeOtherClass::c),
)
- b.show()
+ b.collectAsList()
}
should("Handle some where queries using column operator functions") {
val dataset = dsOf(
SomeOtherClass(intArrayOf(1, 2, 3), 4, true),
SomeOtherClass(intArrayOf(4, 3, 2), 1, true),
)
- dataset.show()
+ dataset.collectAsList()
val column = col("b").`as`()
val b = dataset.where(column gt 3 and col(SomeOtherClass::c))
- b.show()
b.count() shouldBe 1
}
@@ -470,21 +552,51 @@ class ApiTest : ShouldSpec({
listOf(SomeClass(intArrayOf(1, 2, 3), 4)),
listOf(SomeClass(intArrayOf(3, 2, 1), 0)),
)
- dataset.show()
+
+ val (first, second) = dataset.collectAsList()
+
+ first.single().let { (a, b) ->
+ a.contentEquals(intArrayOf(1, 2, 3)) shouldBe true
+ b shouldBe 4
+ }
+ second.single().let { (a, b) ->
+ a.contentEquals(intArrayOf(3, 2, 1)) shouldBe true
+ b shouldBe 0
+ }
}
should("Be able to serialize arrays of data classes") {
val dataset = dsOf(
arrayOf(SomeClass(intArrayOf(1, 2, 3), 4)),
arrayOf(SomeClass(intArrayOf(3, 2, 1), 0)),
)
- dataset.show()
+
+ val (first, second) = dataset.collectAsList()
+
+ first.single().let { (a, b) ->
+ a.contentEquals(intArrayOf(1, 2, 3)) shouldBe true
+ b shouldBe 4
+ }
+ second.single().let { (a, b) ->
+ a.contentEquals(intArrayOf(3, 2, 1)) shouldBe true
+ b shouldBe 0
+ }
}
should("Be able to serialize lists of tuples") {
val dataset = dsOf(
listOf(Tuple2(intArrayOf(1, 2, 3), 4)),
listOf(Tuple2(intArrayOf(3, 2, 1), 0)),
)
- dataset.show()
+
+ val (first, second) = dataset.collectAsList()
+
+ first.single().let {
+ it._1().contentEquals(intArrayOf(1, 2, 3)) shouldBe true
+ it._2() shouldBe 4
+ }
+ second.single().let {
+ it._1().contentEquals(intArrayOf(3, 2, 1)) shouldBe true
+ it._2() shouldBe 0
+ }
}
should("Allow simple forEachPartition in datasets") {
val dataset = dsOf(
@@ -593,10 +705,73 @@ class ApiTest : ShouldSpec({
it.nullable() shouldBe true
}
}
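+ // RDD -> Dataset round trips: toDS() is exercised below for plain RDDs, JavaRDDs,
+ // JavaDoubleRDDs and JavaPairRDDs (sc being the JavaSparkContext in scope).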
+ should("Convert Scala RDD to Dataset") {
+ val rdd0: RDD<Int> = sc.parallelize(
+ listOf(1, 2, 3, 4, 5, 6)
+ ).rdd()
+ val dataset0: Dataset<Int> = rdd0.toDS()
+
+ dataset0.toList() shouldBe listOf(1, 2, 3, 4, 5, 6)
+ }
+
+ should("Convert a JavaRDD to a Dataset") {
+ val rdd1: JavaRDD<Int> = sc.parallelize(
+ listOf(1, 2, 3, 4, 5, 6)
+ )
+ val dataset1: Dataset<Int> = rdd1.toDS()
+
+ dataset1.toList() shouldBe listOf(1, 2, 3, 4, 5, 6)
+ }
+ should("Convert JavaDoubleRDD to Dataset") {
+
+ // JavaDoubleRDD
+ val rdd2: JavaDoubleRDD = sc.parallelizeDoubles(
+ listOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
+ )
+ val dataset2: Dataset<Double> = rdd2.toDS()
+
+ dataset2.toList() shouldBe listOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
+ }
+ should("Convert JavaPairRDD to Dataset") {
+ val rdd3: JavaPairRDD<Int, Double> = sc.parallelizePairs(
+ listOf(Tuple2(1, 1.0), Tuple2(2, 2.0), Tuple2(3, 3.0))
+ )
+ val dataset3: Dataset<Tuple2<Int, Double>> = rdd3.toDS()
+
+ dataset3.toList<Tuple2<Int, Double>>() shouldBe listOf(Tuple2(1, 1.0), Tuple2(2, 2.0), Tuple2(3, 3.0))
+ }
+ should("Convert Kotlin Serializable data class RDD to Dataset") {
+ val rdd4 = sc.parallelize(
+ listOf(SomeClass(intArrayOf(1, 2), 0))
+ )
+ val dataset4 = rdd4.toDS()
+
+ dataset4.toList().first().let { (a, b) ->
+ a contentEquals intArrayOf(1, 2) shouldBe true
+ b shouldBe 0
+ }
+ }
+ should("Convert Arity RDD to Dataset") {
+ val rdd5 = sc.parallelize(
+ listOf(c(1.0, 4))
+ )
+ val dataset5 = rdd5.toDS()
+
+ dataset5.toList<Arity2<Double, Int>>() shouldBe listOf(c(1.0, 4))
+ }
+ should("Convert List RDD to Dataset") {
+ val rdd6 = sc.parallelize(
+ listOf(listOf(1, 2, 3), listOf(4, 5, 6))
+ )
+ val dataset6 = rdd6.toDS()
+
+ dataset6.toList<List<Int>>() shouldBe listOf(listOf(1, 2, 3), listOf(4, 5, 6))
+ }
}
}
})
+
data class DataClassWithTuple<T>(val tuple: T)
data class LonLat(val lon: Double, val lat: Double)
@@ -626,5 +801,5 @@ data class ComplexEnumDataClass(
data class NullFieldAbleDataClass(
val optionList: List<Int>?,
- val optionMap: Map<String, Int>?
-)
\ No newline at end of file
+ val optionMap: Map<String, Int>?,
+)
diff --git a/pom.xml b/pom.xml
index 47043737..0df3adac 100644
--- a/pom.xml
+++ b/pom.xml
@@ -15,7 +15,7 @@
0.16.0
4.6.0
1.0.1
- 3.2.0
+ 3.2.1
2.10.0
@@ -32,6 +32,7 @@
3.0.0-M5
1.6.8
4.5.6
+ official