diff --git a/.github/workflows/generate_docs.yml b/.github/workflows/generate_docs.yml index 0aee6d85..cdbd1949 100644 --- a/.github/workflows/generate_docs.yml +++ b/.github/workflows/generate_docs.yml @@ -25,5 +25,6 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} publish_branch: docs publish_dir: ./kotlin-spark-api/3.2/target/dokka + force_orphan: true diff --git a/README.md b/README.md index 380ee310..498334d2 100644 --- a/README.md +++ b/README.md @@ -27,14 +27,14 @@ We have opened a Spark Project Improvement Proposal: [Kotlin support for Apache - [Code of Conduct](#code-of-conduct) - [License](#license) -## Supported versions of Apache Spark #TODO +## Supported versions of Apache Spark | Apache Spark | Scala | Kotlin for Apache Spark | |:------------:|:-----:|:-------------------------------:| | 3.0.0+ | 2.12 | kotlin-spark-api-3.0:1.0.2 | | 2.4.1+ | 2.12 | kotlin-spark-api-2.4_2.12:1.0.2 | | 2.4.1+ | 2.11 | kotlin-spark-api-2.4_2.11:1.0.2 | -| 3.2.0+ | 2.12 | kotlin-spark-api-2.4_2.12:1.0.3 | +| 3.2.0+ | 2.12 | kotlin-spark-api-3.2:1.0.3 | ## Releases diff --git a/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala b/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala index be808af0..05ff330b 100644 --- a/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala +++ b/core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala @@ -22,6 +22,7 @@ package org.apache.spark.sql import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ +import org.apache.spark.sql.catalyst.ScalaReflection.{Schema, dataTypeFor, getClassFromType, isSubtype, javaBoxedType, localTypeOf} import org.apache.spark.sql.catalyst.SerializerBuildHelper._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.objects._ @@ -30,8 +31,10 @@ import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePath} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.Utils import java.beans.{Introspector, PropertyDescriptor} +import java.lang.Exception /** @@ -45,944 +48,1228 @@ trait DefinedByConstructorParams * KotlinReflection is heavily inspired by ScalaReflection and even extends it just to add several methods */ object KotlinReflection extends KotlinReflection { - /** - * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping - * to a native type, an ObjectType is returned. - * - * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type - * system. As a result, ObjectType will be returned for things like boxed Integers. - */ - private def inferExternalType(cls: Class[_]): DataType = cls match { - case c if c == java.lang.Boolean.TYPE => BooleanType - case c if c == java.lang.Byte.TYPE => ByteType - case c if c == java.lang.Short.TYPE => ShortType - case c if c == java.lang.Integer.TYPE => IntegerType - case c if c == java.lang.Long.TYPE => LongType - case c if c == java.lang.Float.TYPE => FloatType - case c if c == java.lang.Double.TYPE => DoubleType - case c if c == classOf[Array[Byte]] => BinaryType - case _ => ObjectType(cls) - } - - val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - - // Since we are creating a runtime mirror using the class loader of current thread, - // we need to use def at here. 
So, every time we call mirror, it is using the - // class loader of the current thread. - override def mirror: universe.Mirror = { - universe.runtimeMirror(Thread.currentThread().getContextClassLoader) - } - - import universe._ - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map - - - def isSubtype(t: universe.Type, t2: universe.Type): Boolean = t <:< t2 - - /** - * Synchronize to prevent concurrent usage of `<:<` operator. - * This operator is not thread safe in any current version of scala; i.e. - * (2.11.12, 2.12.10, 2.13.0-M5). - * - * See https://github.com/scala/bug/issues/10766 - */ - /* - private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = { - ScalaReflection.ScalaSubtypeLock.synchronized { - tpe1 <:< tpe2 - } - } - */ - - private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { - tpe.dealias match { - case t if isSubtype(t, definitions.NullTpe) => NullType - case t if isSubtype(t, definitions.IntTpe) => IntegerType - case t if isSubtype(t, definitions.LongTpe) => LongType - case t if isSubtype(t, definitions.DoubleTpe) => DoubleType - case t if isSubtype(t, definitions.FloatTpe) => FloatType - case t if isSubtype(t, definitions.ShortTpe) => ShortType - case t if isSubtype(t, definitions.ByteTpe) => ByteType - case t if isSubtype(t, definitions.BooleanTpe) => BooleanType - case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType - case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalType - case t if isSubtype(t, localTypeOf[Decimal]) => DecimalType.SYSTEM_DEFAULT - case _ => - val className = getClassNameFromType(tpe) - className match { - case "scala.Array" => - val TypeRef(_, _, Seq(elementType)) = tpe - arrayClassFor(elementType) - case _ => - val clazz = getClassFromType(tpe) - ObjectType(clazz) - } + /** + * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. + * + * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers. + */ + private def inferExternalType(cls: Class[_]): DataType = cls match { + case c if c == java.lang.Boolean.TYPE => BooleanType + case c if c == java.lang.Byte.TYPE => ByteType + case c if c == java.lang.Short.TYPE => ShortType + case c if c == java.lang.Integer.TYPE => IntegerType + case c if c == java.lang.Long.TYPE => LongType + case c if c == java.lang.Float.TYPE => FloatType + case c if c == java.lang.Double.TYPE => DoubleType + case c if c == classOf[Array[Byte]] => BinaryType + case c if c == classOf[Decimal] => DecimalType.SYSTEM_DEFAULT + case c if c == classOf[CalendarInterval] => CalendarIntervalType + case _ => ObjectType(cls) } - } - - /** - * Given a type `T` this function constructs `ObjectType` that holds a class of type - * `Array[T]`. - * - * Special handling is performed for primitive types to map them back to their raw - * JVM form instead of the Scala Array that handles auto boxing. 
- */ - private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects { - val cls = tpe.dealias match { - case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]] - case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]] - case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]] - case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]] - case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]] - case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]] - case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]] - case other => - // There is probably a better way to do this, but I couldn't find it... - val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls - java.lang.reflect.Array.newInstance(elementType, 0).getClass + val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe + + // Since we are creating a runtime mirror using the class loader of current thread, + // we need to use def at here. So, every time we call mirror, it is using the + // class loader of the current thread. + override def mirror: universe.Mirror = { + universe.runtimeMirror(Thread.currentThread().getContextClassLoader) } - ObjectType(cls) - } - - /** - * Returns true if the value of this data type is same between internal and external. - */ - def isNativeType(dt: DataType): Boolean = dt match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => true - case _ => false - } - - private def baseType(tpe: `Type`): `Type` = { - tpe.dealias match { - case annotatedType: AnnotatedType => annotatedType.underlying - case other => other + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + + def isSubtype(t: universe.Type, t2: universe.Type): Boolean = t <:< t2 + + /** + * Synchronize to prevent concurrent usage of `<:<` operator. + * This operator is not thread safe in any current version of scala; i.e. + * (2.11.12, 2.12.10, 2.13.0-M5). 
+ * + * See https://github.com/scala/bug/issues/10766 + */ + /* + private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = { + ScalaReflection.ScalaSubtypeLock.synchronized { + tpe1 <:< tpe2 + } + } + */ + + private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { + tpe.dealias match { + case t if isSubtype(t, definitions.NullTpe) => NullType + case t if isSubtype(t, definitions.IntTpe) => IntegerType + case t if isSubtype(t, definitions.LongTpe) => LongType + case t if isSubtype(t, definitions.DoubleTpe) => DoubleType + case t if isSubtype(t, definitions.FloatTpe) => FloatType + case t if isSubtype(t, definitions.ShortTpe) => ShortType + case t if isSubtype(t, definitions.ByteTpe) => ByteType + case t if isSubtype(t, definitions.BooleanTpe) => BooleanType + case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType + case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalType + case t if isSubtype(t, localTypeOf[Decimal]) => DecimalType.SYSTEM_DEFAULT + case _ => { + val className = getClassNameFromType(tpe) + className match { + case "scala.Array" => { + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + } + case _ => { + val clazz = getClassFromType(tpe) + ObjectType(clazz) + } + } + } + } } - } - - /** - * Returns an expression that can be used to deserialize a Spark SQL representation to an object - * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of - * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using - * `UnresolvedExtractValue`. - * - * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this - * deserializer expression when using it. - */ - def deserializerForType(tpe: `Type`): Expression = { - val clsName = getClassNameFromType(tpe) - val walkedTypePath = WalkedTypePath().recordRoot(clsName) - val Schema(dataType, nullable) = schemaFor(tpe) - - // Assumes we are deserializing the first column of a row. - deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, - nullable = nullable, walkedTypePath, - (casted, typePath) => deserializerFor(tpe, casted, typePath)) - } - - - /** - * Returns an expression that can be used to deserialize an input expression to an object of type - * `T` with a compatible schema. - * - * @param tpe The `Type` of deserialized object. - * @param path The expression which can be used to extract serialized value. - * @param walkedTypePath The paths from top to bottom to access current field when deserializing. 
- */ - private def deserializerFor( - tpe: `Type`, - path: Expression, - walkedTypePath: WalkedTypePath, - predefinedDt: Option[DataTypeWithClass] = None - ): Expression = cleanUpReflectionObjects { - baseType(tpe) match { - - // - case t if isSubtype(t, localTypeOf[java.lang.Integer]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Integer]) - - case t if isSubtype(t, localTypeOf[Int]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Integer]) - - case t if isSubtype(t, localTypeOf[java.lang.Long]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Long]) - case t if isSubtype(t, localTypeOf[Long]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Long]) - - case t if isSubtype(t, localTypeOf[java.lang.Double]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Double]) - case t if isSubtype(t, localTypeOf[Double]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Double]) - - case t if isSubtype(t, localTypeOf[java.lang.Float]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Float]) - case t if isSubtype(t, localTypeOf[Float]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Float]) - - case t if isSubtype(t, localTypeOf[java.lang.Short]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Short]) - case t if isSubtype(t, localTypeOf[Short]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Short]) - - case t if isSubtype(t, localTypeOf[java.lang.Byte]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Byte]) - case t if isSubtype(t, localTypeOf[Byte]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Byte]) - - case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Boolean]) - case t if isSubtype(t, localTypeOf[Boolean]) => - createDeserializerForTypesSupportValueOf(path, - classOf[java.lang.Boolean]) - - case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => - createDeserializerForLocalDate(path) - - case t if isSubtype(t, localTypeOf[java.sql.Date]) => - createDeserializerForSqlDate(path) - // - - case t if isSubtype(t, localTypeOf[java.time.Instant]) => - createDeserializerForInstant(path) - - case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => - createDeserializerForSqlTimestamp(path) - - case t if isSubtype(t, localTypeOf[java.lang.String]) => - createDeserializerForString(path, returnNullable = false) - - case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => - createDeserializerForJavaBigDecimal(path, returnNullable = false) - - case t if isSubtype(t, localTypeOf[BigDecimal]) => - createDeserializerForScalaBigDecimal(path, returnNullable = false) - - case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => - createDeserializerForJavaBigInteger(path, returnNullable = false) - - case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => - createDeserializerForScalaBigInt(path) - - case t if isSubtype(t, localTypeOf[Array[_]]) => - var TypeRef(_, _, Seq(elementType)) = t - if (predefinedDt.isDefined && !elementType.dealias.typeSymbol.isClass) - elementType = getType(predefinedDt.get.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType.asInstanceOf[DataTypeWithClass].cls) - val Schema(dataType, elementNullable) = predefinedDt.map(it => { - val elementInfo = 
it.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType.asInstanceOf[DataTypeWithClass] - Schema(elementInfo.dt, elementInfo.nullable) - }) - .getOrElse(schemaFor(elementType)) - val className = getClassNameFromType(elementType) - val newTypePath = walkedTypePath.recordArray(className) - - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. - deserializerForWithNullSafetyAndUpcast( - element, - dataType, - nullable = elementNullable, - newTypePath, - (casted, typePath) => deserializerFor(elementType, casted, typePath, predefinedDt.map(_.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType).filter(_.isInstanceOf[ComplexWrapper]).map(_.asInstanceOf[ComplexWrapper]))) + + /** + * Given a type `T` this function constructs `ObjectType` that holds a class of type + * `Array[T]`. + * + * Special handling is performed for primitive types to map them back to their raw + * JVM form instead of the Scala Array that handles auto boxing. + */ + private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects { + val cls = tpe.dealias match { + case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]] + case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]] + case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]] + case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]] + case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]] + case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]] + case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]] + case t if isSubtype(t, localTypeOf[Array[Byte]]) => classOf[Array[Array[Byte]]] + case t if isSubtype(t, localTypeOf[CalendarInterval]) => classOf[Array[CalendarInterval]] + case t if isSubtype(t, localTypeOf[Decimal]) => classOf[Array[Decimal]] + case other => { + // There is probably a better way to do this, but I couldn't find it... + val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls + java.lang.reflect.Array.newInstance(elementType, 0).getClass + } + } + ObjectType(cls) + } - val arrayData = UnresolvedMapObjects(mapFunction, path) - val arrayCls = arrayClassFor(elementType) - - val methodName = elementType match { - case t if isSubtype(t, definitions.IntTpe) => "toIntArray" - case t if isSubtype(t, definitions.LongTpe) => "toLongArray" - case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray" - case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray" - case t if isSubtype(t, definitions.ShortTpe) => "toShortArray" - case t if isSubtype(t, definitions.ByteTpe) => "toByteArray" - case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray" - // non-primitive - case _ => "array" + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType | CalendarIntervalType => { + true } - Invoke(arrayData, methodName, arrayCls, returnNullable = false) + case _ => false + } - // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array - // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. 
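For context (not part of the patch itself): the methodName dispatch above keeps primitive element types unboxed, so a Catalyst array of ints is materialized via toIntArray while reference element types fall back to the generic "array" call. A minimal Kotlin sketch of the user-visible effect, assuming the usual withSpark/dsOf entry points of the Kotlin API (not shown in this diff):

import org.jetbrains.kotlinx.spark.api.*

data class Readings(val samples: IntArray, val tags: Array<String>)

fun main() = withSpark {
    // `samples` is expected to be rebuilt through the primitive `toIntArray` branch,
    // while `tags` falls through to the generic "array" branch.
    val ds = dsOf(Readings(intArrayOf(1, 2, 3), arrayOf("a", "b")))
    ds.show()
}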
+ private def baseType(tpe: `Type`): `Type` = { + tpe.dealias match { + case annotatedType: AnnotatedType => annotatedType.underlying + case other => other + } + } - case t if isSubtype(t, localTypeOf[Map[_, _]]) => - val TypeRef(_, _, Seq(keyType, valueType)) = t + /** + * Returns an expression that can be used to deserialize a Spark SQL representation to an object + * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of + * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using + * `UnresolvedExtractValue`. + * + * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this + * deserializer expression when using it. + */ + def deserializerForType(tpe: `Type`): Expression = { + val clsName = getClassNameFromType(tpe) + val walkedTypePath = WalkedTypePath().recordRoot(clsName) + val Schema(dataType, nullable) = schemaFor(tpe) + + // Assumes we are deserializing the first column of a row. + deserializerForWithNullSafetyAndUpcast( + GetColumnByOrdinal(0, dataType), dataType, + nullable = nullable, walkedTypePath, + (casted, typePath) => deserializerFor(tpe, casted, typePath) + ) + } - val classNameForKey = getClassNameFromType(keyType) - val classNameForValue = getClassNameFromType(valueType) - val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue) + /** + * Returns an expression that can be used to deserialize an input expression to an object of type + * `T` with a compatible schema. + * + * @param tpe The `Type` of deserialized object. + * @param path The expression which can be used to extract serialized value. + * @param walkedTypePath The paths from top to bottom to access current field when deserializing. + */ + private def deserializerFor( + tpe: `Type`, + path: Expression, + walkedTypePath: WalkedTypePath, + predefinedDt: Option[DataTypeWithClass] = None + ): Expression = cleanUpReflectionObjects { + baseType(tpe) match { + + // + case t if ( + try { + !dataTypeFor(t).isInstanceOf[ObjectType] + } catch { + case _: Throwable => false + }) && !predefinedDt.exists(_.isInstanceOf[ComplexWrapper]) => { + path + } - UnresolvedCatalystToExternalMap( - path, - p => deserializerFor(keyType, p, newTypePath), - p => deserializerFor(valueType, p, newTypePath), - mirror.runtimeClass(t.typeSymbol.asClass) - ) + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Integer]) + } + case t if isSubtype(t, localTypeOf[Int]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Integer]) + } + case t if isSubtype(t, localTypeOf[java.lang.Long]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Long]) + } + case t if isSubtype(t, localTypeOf[Long]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Long]) + } + case t if isSubtype(t, localTypeOf[java.lang.Double]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Double]) + } + case t if isSubtype(t, localTypeOf[Double]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Double]) + } + case t if isSubtype(t, localTypeOf[java.lang.Float]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Float]) + } + case t if isSubtype(t, localTypeOf[Float]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Float]) + } + case t if isSubtype(t, localTypeOf[java.lang.Short]) => { + 
createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Short]) + } + case t if isSubtype(t, localTypeOf[Short]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Short]) + } + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Byte]) + } + case t if isSubtype(t, localTypeOf[Byte]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Byte]) + } + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Boolean]) + } + case t if isSubtype(t, localTypeOf[Boolean]) => { + createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Boolean]) + } + case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => { + createDeserializerForLocalDate(path) + } + case t if isSubtype(t, localTypeOf[java.sql.Date]) => { + createDeserializerForSqlDate(path) + } // - case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => - createDeserializerForTypesSupportValueOf( - createDeserializerForString(path, returnNullable = false), Class.forName(t.toString)) - - case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). - getConstructor().newInstance() - val obj = NewInstance( - udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), - Nil, - dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) - - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). - newInstance().asInstanceOf[UserDefinedType[_]] - val obj = NewInstance( - udt.getClass, - Nil, - dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) - - case _ if predefinedDt.isDefined => - predefinedDt.get match { - case wrapper: KDataTypeWrapper => - val structType = wrapper.dt - val cls = wrapper.cls - val arguments = structType - .fields - .map(field => { - val dataType = field.dataType.asInstanceOf[DataTypeWithClass] - val nullable = dataType.nullable - val clsName = getClassNameFromType(getType(dataType.cls)) - val newTypePath = walkedTypePath.recordField(clsName, field.name) - - // For tuples, we based grab the inner fields by ordinal instead of name. 
- val newPath = deserializerFor( - getType(dataType.cls), - addToPath(path, field.name, dataType.dt, newTypePath), - newTypePath, - Some(dataType).filter(_.isInstanceOf[ComplexWrapper]) - ) - expressionWithNullSafety( - newPath, - nullable = nullable, - newTypePath + case t if isSubtype(t, localTypeOf[java.time.Instant]) => { + createDeserializerForInstant(path) + } + case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => { + createDeserializerForTypesSupportValueOf( + Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false), + getClassFromType(t), ) + } + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => { + createDeserializerForSqlTimestamp(path) + } + case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => { + createDeserializerForLocalDateTime(path) + } + case t if isSubtype(t, localTypeOf[java.time.Duration]) => { + createDeserializerForDuration(path) + } + case t if isSubtype(t, localTypeOf[java.time.Period]) => { + createDeserializerForPeriod(path) + } + case t if isSubtype(t, localTypeOf[java.lang.String]) => { + createDeserializerForString(path, returnNullable = false) + } + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => { + createDeserializerForJavaBigDecimal(path, returnNullable = false) + } + case t if isSubtype(t, localTypeOf[BigDecimal]) => { + createDeserializerForScalaBigDecimal(path, returnNullable = false) + } + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => { + createDeserializerForJavaBigInteger(path, returnNullable = false) + } + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => { + createDeserializerForScalaBigInt(path) + } - }) - val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) + case t if isSubtype(t, localTypeOf[Array[_]]) => { + var TypeRef(_, _, Seq(elementType)) = t + if (predefinedDt.isDefined && !elementType.dealias.typeSymbol.isClass) + elementType = getType(predefinedDt.get.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType] + .elementType.asInstanceOf[DataTypeWithClass].cls + ) + val Schema(dataType, elementNullable) = predefinedDt.map { it => + val elementInfo = it.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType + .asInstanceOf[DataTypeWithClass] + Schema(elementInfo.dt, elementInfo.nullable) + }.getOrElse(schemaFor(elementType)) + val className = getClassNameFromType(elementType) + val newTypePath = walkedTypePath.recordArray(className) - org.apache.spark.sql.catalyst.expressions.If( - IsNull(path), - org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. 
+ deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor( + tpe = elementType, + path = casted, + walkedTypePath = typePath, + predefinedDt = predefinedDt + .map(_.asInstanceOf[KComplexTypeWrapper].dt.asInstanceOf[ArrayType].elementType) + .filter(_.isInstanceOf[ComplexWrapper]) + .map(_.asInstanceOf[ComplexWrapper]) + ) + ) + } + + val arrayData = UnresolvedMapObjects(mapFunction, path) + val arrayCls = arrayClassFor(elementType) + + val methodName = elementType match { + case t if isSubtype(t, definitions.IntTpe) => "toIntArray" + case t if isSubtype(t, definitions.LongTpe) => "toLongArray" + case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray" + case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray" + case t if isSubtype(t, definitions.ShortTpe) => "toShortArray" + case t if isSubtype(t, definitions.ByteTpe) => "toByteArray" + case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray" + // non-primitive + case _ => "array" + } + Invoke(arrayData, methodName, arrayCls, returnNullable = false) + } + + // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array + // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. + + case t if isSubtype(t, localTypeOf[Map[_, _]]) => { + val TypeRef(_, _, Seq(keyType, valueType)) = t - case t: ComplexWrapper => - t.dt match { - case MapType(kt, vt, _) => - val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls).map(getType(_)) - val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass]) val classNameForKey = getClassNameFromType(keyType) val classNameForValue = getClassNameFromType(valueType) val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue) - val keyData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(keyType, p, newTypePath, Some(keyDT).filter(_.isInstanceOf[ComplexWrapper])), - MapKeys(path)), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(valueType, p, newTypePath, Some(valueDT).filter(_.isInstanceOf[ComplexWrapper])), - MapValues(path)), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[java.util.Map[_, _]]), - "toJavaMap", - keyData :: valueData :: Nil, - returnNullable = false) - - case ArrayType(elementType, containsNull) => - val dataTypeWithClass = elementType.asInstanceOf[DataTypeWithClass] - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. 
- val et = getType(dataTypeWithClass.cls) - val className = getClassNameFromType(et) - val newTypePath = walkedTypePath.recordArray(className) - deserializerForWithNullSafetyAndUpcast( - element, - dataTypeWithClass.dt, - nullable = dataTypeWithClass.nullable, - newTypePath, - (casted, typePath) => { - deserializerFor(et, casted, typePath, Some(dataTypeWithClass).filter(_.isInstanceOf[ComplexWrapper]).map(_.asInstanceOf[ComplexWrapper])) - }) + UnresolvedCatalystToExternalMap( + path, + p => deserializerFor(keyType, p, newTypePath), + p => deserializerFor(valueType, p, newTypePath), + mirror.runtimeClass(t.typeSymbol.asClass) + ) + } + + case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => { + createDeserializerForTypesSupportValueOf( + createDeserializerForString(path, returnNullable = false), + Class.forName(t.toString), + ) + } + case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => { + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). + getConstructor().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()) + ) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) + } + + case t if UDTRegistration.exists(getClassNameFromType(t)) => { + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] + val obj = NewInstance( + udt.getClass, + Nil, + dataType = ObjectType(udt.getClass) + ) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) + } + + case _ if predefinedDt.isDefined => { + predefinedDt.get match { + + case wrapper: KDataTypeWrapper => { + val structType = wrapper.dt + val cls = wrapper.cls + val arguments = structType + .fields + .map { field => + val dataType = field.dataType.asInstanceOf[DataTypeWithClass] + val nullable = dataType.nullable + val clsName = getClassNameFromType(getType(dataType.cls)) + val newTypePath = walkedTypePath.recordField(clsName, field.name) + + // For tuples, we based grab the inner fields by ordinal instead of name. 
+ val newPath = deserializerFor( + tpe = getType(dataType.cls), + path = addToPath(path, field.name, dataType.dt, newTypePath), + walkedTypePath = newTypePath, + predefinedDt = Some(dataType).filter(_.isInstanceOf[ComplexWrapper]) + ) + expressionWithNullSafety( + newPath, + nullable = nullable, + newTypePath + ) + } + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) + + org.apache.spark.sql.catalyst.expressions.If( + IsNull(path), + org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } + + case t: ComplexWrapper => { + + t.dt match { + case MapType(kt, vt, _) => { + val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls) + .map(getType(_)) + val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass]) + val classNameForKey = getClassNameFromType(keyType) + val classNameForValue = getClassNameFromType(valueType) + + val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue) + + val keyData = + Invoke( + UnresolvedMapObjects( + p => deserializerFor( + keyType, p, newTypePath, Some(keyDT) + .filter(_.isInstanceOf[ComplexWrapper]) + ), + MapKeys(path) + ), + "array", + ObjectType(classOf[Array[Any]]) + ) + + val valueData = + Invoke( + UnresolvedMapObjects( + p => deserializerFor( + valueType, p, newTypePath, Some(valueDT) + .filter(_.isInstanceOf[ComplexWrapper]) + ), + MapValues(path) + ), + "array", + ObjectType(classOf[Array[Any]]) + ) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[java.util.Map[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil, + returnNullable = false + ) + } + + case ArrayType(elementType, containsNull) => { + val dataTypeWithClass = elementType.asInstanceOf[DataTypeWithClass] + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. + val et = getType(dataTypeWithClass.cls) + val className = getClassNameFromType(et) + val newTypePath = walkedTypePath.recordArray(className) + deserializerForWithNullSafetyAndUpcast( + element, + dataTypeWithClass.dt, + nullable = dataTypeWithClass.nullable, + newTypePath, + (casted, typePath) => { + deserializerFor( + et, casted, typePath, Some(dataTypeWithClass) + .filter(_.isInstanceOf[ComplexWrapper]) + .map(_.asInstanceOf[ComplexWrapper]) + ) + } + ) + } + + UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(t.cls)) + } + + case StructType(elementType: Array[StructField]) => { + val cls = t.cls + + val arguments = elementType.map { field => + val dataType = field.dataType.asInstanceOf[DataTypeWithClass] + val nullable = dataType.nullable + val clsName = getClassNameFromType(getType(dataType.cls)) + val newTypePath = walkedTypePath.recordField(clsName, field.name) + + // For tuples, we based grab the inner fields by ordinal instead of name. 
+ val newPath = deserializerFor( + getType(dataType.cls), + addToPath(path, field.name, dataType.dt, newTypePath), + newTypePath, + Some(dataType).filter(_.isInstanceOf[ComplexWrapper]) + ) + expressionWithNullSafety( + newPath, + nullable = nullable, + newTypePath + ) + } + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) + + org.apache.spark.sql.catalyst.expressions.If( + IsNull(path), + org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } + + case _ => { + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath + ) + } + } + } } + } - UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(t.cls)) - - case StructType(elementType: Array[StructField]) => - val cls = t.cls - - val arguments = elementType.map { field => - val dataType = field.dataType.asInstanceOf[DataTypeWithClass] - val nullable = dataType.nullable - val clsName = getClassNameFromType(getType(dataType.cls)) - val newTypePath = walkedTypePath.recordField(clsName, field.name) - - // For tuples, we based grab the inner fields by ordinal instead of name. - val newPath = deserializerFor( - getType(dataType.cls), - addToPath(path, field.name, dataType.dt, newTypePath), - newTypePath, - Some(dataType).filter(_.isInstanceOf[ComplexWrapper]) - ) - expressionWithNullSafety( - newPath, - nullable = nullable, - newTypePath - ) + case t if definedByConstructorParams(t) => { + val params = getConstructorParameters(t) + + val cls = getClassFromType(tpe) + + val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => + val Schema(dataType, nullable) = schemaFor(fieldType) + val clsName = getClassNameFromType(fieldType) + val newTypePath = walkedTypePath.recordField(clsName, fieldName) + + // For tuples, we based grab the inner fields by ordinal instead of name. + val newPath = if (cls.getName startsWith "scala.Tuple") { + deserializerFor( + fieldType, + addToPathOrdinal(path, i, dataType, newTypePath), + newTypePath + ) + } else { + deserializerFor( + fieldType, + addToPath(path, fieldName, dataType, newTypePath), + newTypePath + ) + } + expressionWithNullSafety( + newPath, + nullable = nullable, + newTypePath + ) } + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) org.apache.spark.sql.catalyst.expressions.If( - IsNull(path), - org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)), - newInstance + IsNull(path), + org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)), + newInstance ) + } - - case _ => + case _ => { throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath) + s"No Encoder found for $tpe\n" + walkedTypePath + ) } } + } + + /** + * Returns an expression for serializing an object of type T to Spark SQL representation. The + * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. + * + * If the given type is not supported, i.e. there is no encoder can be built for this type, + * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain + * the type path walked so far and which class we are not supporting. 
+ * There are 4 kinds of type path: + * * the root type: `root class: "abc.xyz.MyClass"` + * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` + * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` + * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + */ + def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { + val clsName = getClassNameFromType(tpe) + val walkedTypePath = WalkedTypePath().recordRoot(clsName) + + // The input object to `ExpressionEncoder` is located at first column of an row. + val isPrimitive = tpe.typeSymbol.asClass.isPrimitive + val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive) + + serializerFor(inputObject, tpe, walkedTypePath) + } - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - - val cls = getClassFromType(tpe) - - val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => - val Schema(dataType, nullable) = schemaFor(fieldType) - val clsName = getClassNameFromType(fieldType) - val newTypePath = walkedTypePath.recordField(clsName, fieldName) - - // For tuples, we based grab the inner fields by ordinal instead of name. - val newPath = if (cls.getName startsWith "scala.Tuple") { - deserializerFor( - fieldType, - addToPathOrdinal(path, i, dataType, newTypePath), - newTypePath) - } else { - deserializerFor( - fieldType, - addToPath(path, fieldName, dataType, newTypePath), - newTypePath) - } - expressionWithNullSafety( - newPath, - nullable = nullable, - newTypePath) + def getType[T](clazz: Class[T]): universe.Type = { + clazz match { + case _ if clazz == classOf[Array[Byte]] => localTypeOf[Array[Byte]] + case _ => { + val mir = runtimeMirror(clazz.getClassLoader) + mir.classSymbol(clazz).toType + } } - val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) + } - org.apache.spark.sql.catalyst.expressions.If( - IsNull(path), - org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)), - newInstance + def deserializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = { + val tpe = getType(cls) + val clsName = getClassNameFromType(tpe) + val walkedTypePath = WalkedTypePath().recordRoot(clsName) + + // Assumes we are deserializing the first column of a row. + deserializerForWithNullSafetyAndUpcast( + GetColumnByOrdinal(0, dt.dt), + dt.dt, + nullable = dt.nullable, + walkedTypePath, + (casted, typePath) => deserializerFor(tpe, casted, typePath, Some(dt)) ) - - case _ => - throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath) } - } - - /** - * Returns an expression for serializing an object of type T to Spark SQL representation. The - * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. - * - * If the given type is not supported, i.e. there is no encoder can be built for this type, - * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain - * the type path walked so far and which class we are not supporting. 
- * There are 4 kinds of type path: - * * the root type: `root class: "abc.xyz.MyClass"` - * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` - * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` - * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` - */ - def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { - val clsName = getClassNameFromType(tpe) - val walkedTypePath = WalkedTypePath().recordRoot(clsName) - - // The input object to `ExpressionEncoder` is located at first column of an row. - val isPrimitive = tpe.typeSymbol.asClass.isPrimitive - val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive) - - serializerFor(inputObject, tpe, walkedTypePath) - } - - def getType[T](clazz: Class[T]): universe.Type = { - val mir = runtimeMirror(clazz.getClassLoader) - mir.classSymbol(clazz).toType - } - - def deserializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = { - val tpe = getType(cls) - val clsName = getClassNameFromType(tpe) - val walkedTypePath = WalkedTypePath().recordRoot(clsName) - - // Assumes we are deserializing the first column of a row. - deserializerForWithNullSafetyAndUpcast( - GetColumnByOrdinal(0, dt.dt), - dt.dt, - nullable = dt.nullable, - walkedTypePath, - (casted, typePath) => deserializerFor(tpe, casted, typePath, Some(dt)) - ) - } - - - def serializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = { - - val tpe = getType(cls) - val clsName = getClassNameFromType(tpe) - val walkedTypePath = WalkedTypePath().recordRoot(clsName) - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - serializerFor(inputObject, tpe, walkedTypePath, predefinedDt = Some(dt)) - } - - /** - * Returns an expression for serializing the value of an input expression into Spark SQL - * internal representation. 
- */ - private def serializerFor( - inputObject: Expression, - tpe: `Type`, - walkedTypePath: WalkedTypePath, - seenTypeSet: Set[`Type`] = Set.empty, - predefinedDt: Option[DataTypeWithClass] = None - ): Expression = cleanUpReflectionObjects { - - def toCatalystArray(input: Expression, elementType: `Type`, predefinedDt: Option[DataTypeWithClass] = None): Expression = { - predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match { - - case dt@(MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => - val clsName = getClassNameFromType(elementType) - val newPath = walkedTypePath.recordArray(clsName) - createSerializerForMapObjects(input, ObjectType(predefinedDt.get.cls), - serializerFor(_, elementType, newPath, seenTypeSet, predefinedDt)) - - case dt: ObjectType => - val clsName = getClassNameFromType(elementType) - val newPath = walkedTypePath.recordArray(clsName) - createSerializerForMapObjects(input, dt, - serializerFor(_, elementType, newPath, seenTypeSet)) - - case dt@(BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType) => - val cls = input.dataType.asInstanceOf[ObjectType].cls - if (cls.isArray && cls.getComponentType.isPrimitive) { - createSerializerForPrimitiveArray(input, dt) - } else { - createSerializerForGenericArray(input, dt, nullable = predefinedDt.map(_.nullable).getOrElse(schemaFor(elementType).nullable)) - } - - case _: StringType => - val clsName = getClassNameFromType(typeOf[String]) - val newPath = walkedTypePath.recordArray(clsName) - createSerializerForMapObjects(input, ObjectType(Class.forName(getClassNameFromType(elementType))), - serializerFor(_, elementType, newPath, seenTypeSet)) - - - case dt => - createSerializerForGenericArray(input, dt, nullable = predefinedDt.map(_.nullable).getOrElse(schemaFor(elementType).nullable)) - } + + + def serializerFor(cls: java.lang.Class[_], dt: DataTypeWithClass): Expression = { + val tpe = getType(cls) + val clsName = getClassNameFromType(tpe) + val walkedTypePath = WalkedTypePath().recordRoot(clsName) + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + serializerFor(inputObject, tpe, walkedTypePath, predefinedDt = Some(dt)) } - baseType(tpe) match { - - // - case _ if !inputObject.dataType.isInstanceOf[ObjectType] && !predefinedDt.exists(_.isInstanceOf[ComplexWrapper]) => inputObject - - case t if isSubtype(t, localTypeOf[Option[_]]) => - val TypeRef(_, _, Seq(optType)) = t - val className = getClassNameFromType(optType) - val newPath = walkedTypePath.recordOption(className) - val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath, seenTypeSet) - - // Since List[_] also belongs to localTypeOf[Product], we put this case before - // "case t if definedByConstructorParams(t)" to make sure it will match to the - // case "localTypeOf[Seq[_]]" - case t if isSubtype(t, localTypeOf[Seq[_]]) => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if isSubtype(t, localTypeOf[Array[_]]) && predefinedDt.isEmpty => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if isSubtype(t, localTypeOf[Map[_, _]]) => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val keyClsName = getClassNameFromType(keyType) - val valueClsName = getClassNameFromType(valueType) - val keyPath = walkedTypePath.recordKeyForMap(keyClsName) - val valuePath = walkedTypePath.recordValueForMap(valueClsName) - - createSerializerForMap( - inputObject, - 
MapElementInformation( - dataTypeFor(keyType), - nullable = !keyType.typeSymbol.asClass.isPrimitive, - serializerFor(_, keyType, keyPath, seenTypeSet)), - MapElementInformation( - dataTypeFor(valueType), - nullable = !valueType.typeSymbol.asClass.isPrimitive, - serializerFor(_, valueType, valuePath, seenTypeSet)) - ) + /** + * Returns an expression for serializing the value of an input expression into Spark SQL + * internal representation. + */ + private def serializerFor( + inputObject: Expression, + tpe: `Type`, + walkedTypePath: WalkedTypePath, + seenTypeSet: Set[`Type`] = Set.empty, + predefinedDt: Option[DataTypeWithClass] = None, + ): Expression = cleanUpReflectionObjects { + + def toCatalystArray( + input: Expression, + elementType: `Type`, + predefinedDt: Option[DataTypeWithClass] = None, + ): Expression = { + val dataType = predefinedDt + .map(_.dt) + .getOrElse { + dataTypeFor(elementType) + } - case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) => - val TypeRef(_, _, Seq(elementType)) = t + dataType match { - // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. - // Note that the property of `Set` is only kept when manipulating the data as domain object. - val newInput = - Invoke( - inputObject, - "toSeq", - ObjectType(classOf[Seq[_]])) - - toCatalystArray(newInput, elementType) - - case t if isSubtype(t, localTypeOf[String]) => - createSerializerForString(inputObject) - case t if isSubtype(t, localTypeOf[java.time.Instant]) => - createSerializerForJavaInstant(inputObject) - - case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => - createSerializerForSqlTimestamp(inputObject) - - case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => - createSerializerForJavaLocalDate(inputObject) - - case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject) - - case t if isSubtype(t, localTypeOf[BigDecimal]) => - createSerializerForScalaBigDecimal(inputObject) - - case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => - createSerializerForJavaBigDecimal(inputObject) - - case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => - createSerializerForJavaBigInteger(inputObject) - - case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => - createSerializerForScalaBigInt(inputObject) - - case t if isSubtype(t, localTypeOf[java.lang.Integer]) => - createSerializerForInteger(inputObject) - case t if isSubtype(t, localTypeOf[Int]) => - createSerializerForInteger(inputObject) - case t if isSubtype(t, localTypeOf[java.lang.Long]) => createSerializerForLong(inputObject) - case t if isSubtype(t, localTypeOf[Long]) => createSerializerForLong(inputObject) - case t if isSubtype(t, localTypeOf[java.lang.Double]) => createSerializerForDouble(inputObject) - case t if isSubtype(t, localTypeOf[Double]) => createSerializerForDouble(inputObject) - case t if isSubtype(t, localTypeOf[java.lang.Float]) => createSerializerForFloat(inputObject) - case t if isSubtype(t, localTypeOf[Float]) => createSerializerForFloat(inputObject) - case t if isSubtype(t, localTypeOf[java.lang.Short]) => createSerializerForShort(inputObject) - case t if isSubtype(t, localTypeOf[Short]) => createSerializerForShort(inputObject) - case t if isSubtype(t, localTypeOf[java.lang.Byte]) => createSerializerForByte(inputObject) - case t if isSubtype(t, localTypeOf[Byte]) => createSerializerForByte(inputObject) - case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => createSerializerForBoolean(inputObject) - case t if isSubtype(t, 
localTypeOf[Boolean]) => createSerializerForBoolean(inputObject) - - case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => - createSerializerForString( - Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false)) - - case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t) - .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance() - val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt() - createSerializerForUserDefinedType(inputObject, udt, udtClass) - - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). - newInstance().asInstanceOf[UserDefinedType[_]] - val udtClass = udt.getClass - createSerializerForUserDefinedType(inputObject, udt, udtClass) - // - - case _ if predefinedDt.isDefined => - predefinedDt.get match { - case dataType: KDataTypeWrapper => - val cls = dataType.cls - val properties = getJavaBeanReadableProperties(cls) - val structFields = dataType.dt.fields.map(_.asInstanceOf[KStructField]) - val fields = structFields.map { structField => - val maybeProp = properties.find(it => it.getReadMethod.getName == structField.getterName) - if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}") - val fieldName = structField.name - val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls - val propDt = structField.dataType.asInstanceOf[DataTypeWithClass] - val fieldValue = Invoke( - inputObject, - maybeProp.get.getReadMethod.getName, - inferExternalType(propClass), - returnNullable = structField.nullable - ) - val newPath = walkedTypePath.recordField(propClass.getName, fieldName) - (fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None)) - - } - createSerializerForObject(inputObject, fields) - - case otherTypeWrapper: ComplexWrapper => - otherTypeWrapper.dt match { - case MapType(kt, vt, _) => - val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls).map(getType(_)) - val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass]) + case dt @ (MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => { + val clsName = getClassNameFromType(elementType) + val newPath = walkedTypePath.recordArray(clsName) + createSerializerForMapObjects( + input, ObjectType(predefinedDt.get.cls), + serializerFor(_, elementType, newPath, seenTypeSet, predefinedDt) + ) + } + + case dt: ObjectType => { + val clsName = getClassNameFromType(elementType) + val newPath = walkedTypePath.recordArray(clsName) + createSerializerForMapObjects( + input, dt, + serializerFor(_, elementType, newPath, seenTypeSet) + ) + } + + // case dt: ByteType => + // createSerializerForPrimitiveArray(input, dt) + + case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => { + val cls = input.dataType.asInstanceOf[ObjectType].cls + if (cls.isArray && cls.getComponentType.isPrimitive) { + createSerializerForPrimitiveArray(input, dt) + } else { + createSerializerForGenericArray( + inputObject = input, + dataType = dt, + nullable = predefinedDt + .map(_.nullable) + .getOrElse( + schemaFor(elementType).nullable + ), + ) + } + } + + case _: StringType => { + val clsName = getClassNameFromType(typeOf[String]) + 
val newPath = walkedTypePath.recordArray(clsName) + createSerializerForMapObjects( + input, ObjectType(Class.forName(getClassNameFromType(elementType))), + serializerFor(_, elementType, newPath, seenTypeSet) + ) + } + + case dt => { + createSerializerForGenericArray( + inputObject = input, + dataType = dt, + nullable = predefinedDt + .map(_.nullable) + .getOrElse { + schemaFor(elementType).nullable + }, + ) + } + } + } + + baseType(tpe) match { + + // + case _ if !inputObject.dataType.isInstanceOf[ObjectType] && + !predefinedDt.exists(_.isInstanceOf[ComplexWrapper]) => { + inputObject + } + case t if isSubtype(t, localTypeOf[Option[_]]) => { + val TypeRef(_, _, Seq(optType)) = t + val className = getClassNameFromType(optType) + val newPath = walkedTypePath.recordOption(className) + val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) + serializerFor(unwrapped, optType, newPath, seenTypeSet) + } + + // Since List[_] also belongs to localTypeOf[Product], we put this case before + // "case t if definedByConstructorParams(t)" to make sure it will match to the + // case "localTypeOf[Seq[_]]" + case t if isSubtype(t, localTypeOf[Seq[_]]) => { + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + } + + case t if isSubtype(t, localTypeOf[Array[_]]) && predefinedDt.isEmpty => { + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + } + + case t if isSubtype(t, localTypeOf[Map[_, _]]) => { + val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) val keyPath = walkedTypePath.recordKeyForMap(keyClsName) val valuePath = walkedTypePath.recordValueForMap(valueClsName) createSerializerForMap( - inputObject, - MapElementInformation( - dataTypeFor(keyType), - nullable = !keyType.typeSymbol.asClass.isPrimitive, - serializerFor(_, keyType, keyPath, seenTypeSet, Some(keyDT).filter(_.isInstanceOf[ComplexWrapper]))), - MapElementInformation( - dataTypeFor(valueType), - nullable = !valueType.typeSymbol.asClass.isPrimitive, - serializerFor(_, valueType, valuePath, seenTypeSet, Some(valueDT).filter(_.isInstanceOf[ComplexWrapper]))) + inputObject, + MapElementInformation( + dataTypeFor(keyType), + nullable = !keyType.typeSymbol.asClass.isPrimitive, + serializerFor(_, keyType, keyPath, seenTypeSet) + ), + MapElementInformation( + dataTypeFor(valueType), + nullable = !valueType.typeSymbol.asClass.isPrimitive, + serializerFor(_, valueType, valuePath, seenTypeSet) + ) ) - case ArrayType(elementType, _) => - toCatalystArray(inputObject, getType(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass])) + } - case StructType(elementType: Array[StructField]) => - val cls = otherTypeWrapper.cls - val names = elementType.map(_.name) + case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) => { + val TypeRef(_, _, Seq(elementType)) = t - val beanInfo = Introspector.getBeanInfo(cls) - val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName)) + // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. + // Note that the property of `Set` is only kept when manipulating the data as domain object. 
+ val newInput = + Invoke( + inputObject, + "toSeq", + ObjectType(classOf[Seq[_]]) + ) + toCatalystArray(newInput, elementType) + } - val fields = elementType.map { structField => + case t if isSubtype(t, localTypeOf[String]) => { + createSerializerForString(inputObject) + } + case t if isSubtype(t, localTypeOf[java.time.Instant]) => { + createSerializerForJavaInstant(inputObject) + } + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => { + createSerializerForSqlTimestamp(inputObject) + } + case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => { + createSerializerForLocalDateTime(inputObject) + } + case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => { + createSerializerForJavaLocalDate(inputObject) + } + case t if isSubtype(t, localTypeOf[java.sql.Date]) => { + createSerializerForSqlDate(inputObject) + } + case t if isSubtype(t, localTypeOf[java.time.Duration]) => { + createSerializerForJavaDuration(inputObject) + } + case t if isSubtype(t, localTypeOf[java.time.Period]) => { + createSerializerForJavaPeriod(inputObject) + } + case t if isSubtype(t, localTypeOf[BigDecimal]) => { + createSerializerForScalaBigDecimal(inputObject) + } + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => { + createSerializerForJavaBigDecimal(inputObject) + } + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => { + createSerializerForJavaBigInteger(inputObject) + } + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => { + createSerializerForScalaBigInt(inputObject) + } - val maybeProp = methods.find(it => it.getName == structField.name) - if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}") - val fieldName = structField.name - val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls - val propDt = structField.dataType.asInstanceOf[DataTypeWithClass] - val fieldValue = Invoke( - inputObject, - maybeProp.get.getName, - inferExternalType(propClass), - returnNullable = propDt.nullable - ) - val newPath = walkedTypePath.recordField(propClass.getName, fieldName) - (fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None)) + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => { + createSerializerForInteger(inputObject) + } + case t if isSubtype(t, localTypeOf[Int]) => { + createSerializerForInteger(inputObject) + } + case t if isSubtype(t, localTypeOf[java.lang.Long]) => { + createSerializerForLong(inputObject) + } + case t if isSubtype(t, localTypeOf[Long]) => { + createSerializerForLong(inputObject) + } + case t if isSubtype(t, localTypeOf[java.lang.Double]) => { + createSerializerForDouble(inputObject) + } + case t if isSubtype(t, localTypeOf[Double]) => { + createSerializerForDouble(inputObject) + } + case t if isSubtype(t, localTypeOf[java.lang.Float]) => { + createSerializerForFloat(inputObject) + } + case t if isSubtype(t, localTypeOf[Float]) => { + createSerializerForFloat(inputObject) + } + case t if isSubtype(t, localTypeOf[java.lang.Short]) => { + createSerializerForShort(inputObject) + } + case t if isSubtype(t, localTypeOf[Short]) => { + createSerializerForShort(inputObject) + } + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => { + createSerializerForByte(inputObject) + } + case t if isSubtype(t, localTypeOf[Byte]) => { + createSerializerForByte(inputObject) + } + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => 
{ + createSerializerForBoolean(inputObject) + } + case t if isSubtype(t, localTypeOf[Boolean]) => { + createSerializerForBoolean(inputObject) + } + case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => { + createSerializerForString( + Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) + ) + } + case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => { + val udt = getClassFromType(t) + .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance() + val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt() + createSerializerForUserDefinedType(inputObject, udt, udtClass) + } + case t if UDTRegistration.exists(getClassNameFromType(t)) => { + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). + newInstance().asInstanceOf[UserDefinedType[_]] + val udtClass = udt.getClass + createSerializerForUserDefinedType(inputObject, udt, udtClass) + } + // + + case _ if predefinedDt.isDefined => { + predefinedDt.get match { + + case dataType: KDataTypeWrapper => { + val cls = dataType.cls + val properties = getJavaBeanReadableProperties(cls) + val structFields = dataType.dt.fields.map(_.asInstanceOf[KStructField]) + val fields: Array[(String, Expression)] = structFields.map { structField => + val maybeProp = properties.find(it => it.getReadMethod.getName == structField.getterName) + if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${ + structField.name + } is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}" + ) + val fieldName = structField.name + val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls + val propDt = structField.dataType.asInstanceOf[DataTypeWithClass] + + val fieldValue = Invoke( + inputObject, + maybeProp.get.getReadMethod.getName, + inferExternalType(propClass), + returnNullable = structField.nullable + ) + val newPath = walkedTypePath.recordField(propClass.getName, fieldName) + + val tpe = getType(propClass) + + val serializer = serializerFor( + inputObject = fieldValue, + tpe = tpe, + walkedTypePath = newPath, + seenTypeSet = seenTypeSet, + predefinedDt = if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None + ) + + (fieldName, serializer) + } + createSerializerForObject(inputObject, fields) + } + + case otherTypeWrapper: ComplexWrapper => { + + otherTypeWrapper.dt match { + + case MapType(kt, vt, _) => { + val Seq(keyType, valueType) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass].cls) + .map(getType(_)) + val Seq(keyDT, valueDT) = Seq(kt, vt).map(_.asInstanceOf[DataTypeWithClass]) + val keyClsName = getClassNameFromType(keyType) + val valueClsName = getClassNameFromType(valueType) + val keyPath = walkedTypePath.recordKeyForMap(keyClsName) + val valuePath = walkedTypePath.recordValueForMap(valueClsName) + + createSerializerForMap( + inputObject, + MapElementInformation( + dataTypeFor(keyType), + nullable = !keyType.typeSymbol.asClass.isPrimitive, + serializerFor( + _, keyType, keyPath, seenTypeSet, Some(keyDT) + .filter(_.isInstanceOf[ComplexWrapper]) + ) + ), + MapElementInformation( + dataTypeFor(valueType), + nullable = !valueType.typeSymbol.asClass.isPrimitive, + serializerFor( + _, valueType, valuePath, seenTypeSet, Some(valueDT) + .filter(_.isInstanceOf[ComplexWrapper]) + ) + ) + ) + } + + case ArrayType(elementType, _) => { + toCatalystArray( + inputObject, + getType(elementType.asInstanceOf[DataTypeWithClass].cls + ), 
Some(elementType.asInstanceOf[DataTypeWithClass]) + ) + } + + case StructType(elementType: Array[StructField]) => { + val cls = otherTypeWrapper.cls + val names = elementType.map(_.name) + + val beanInfo = Introspector.getBeanInfo(cls) + val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName)) + + + val fields = elementType.map { structField => + + val maybeProp = methods.find(it => it.getName == structField.name) + if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${ + structField.name + } is not found among available props, which are: ${ + methods.map(_.getName).mkString(", ") + }" + ) + val fieldName = structField.name + val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls + val propDt = structField.dataType.asInstanceOf[DataTypeWithClass] + val fieldValue = Invoke( + inputObject, + maybeProp.get.getName, + inferExternalType(propClass), + returnNullable = propDt.nullable + ) + val newPath = walkedTypePath.recordField(propClass.getName, fieldName) + (fieldName, serializerFor( + fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt + .isInstanceOf[ComplexWrapper]) Some(propDt) else None + )) + + } + createSerializerForObject(inputObject, fields) + } + + case _ => { + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath + ) + } + } + } + } + } + + case t if definedByConstructorParams(t) => { + if (seenTypeSet.contains(t)) { + throw new UnsupportedOperationException( + s"cannot have circular references in class, but got the circular reference of class $t" + ) + } + + val params = getConstructorParameters(t) + val fields = params.map { case (fieldName, fieldType) => + if (javaKeywords.contains(fieldName)) { + throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + + "cannot be used as field name\n" + walkedTypePath + ) + } + + // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul + // is necessary here. Because for a nullable nested inputObject with struct data + // type, e.g. StructType(IntegerType, StringType), it will return nullable=true + // for IntegerType without KnownNotNull. And that's what we do not expect to. 
+ val fieldValue = Invoke( + KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType), + returnNullable = !fieldType.typeSymbol.asClass.isPrimitive + ) + val clsName = getClassNameFromType(fieldType) + val newPath = walkedTypePath.recordField(clsName, fieldName) + (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t)) } createSerializerForObject(inputObject, fields) + } - case _ => + case _ => { throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath) - + s"No Encoder found for $tpe\n" + walkedTypePath + ) } } + } - case t if definedByConstructorParams(t) => - if (seenTypeSet.contains(t)) { - throw new UnsupportedOperationException( - s"cannot have circular references in class, but got the circular reference of class $t") - } + def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = { + Invoke( + path, "toString", ObjectType(classOf[java.lang.String]), + returnNullable = returnNullable + ) + } - val params = getConstructorParameters(t) - val fields = params.map { case (fieldName, fieldType) => - if (javaKeywords.contains(fieldName)) { - throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + - "cannot be used as field name\n" + walkedTypePath) - } - - // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul - // is necessary here. Because for a nullable nested inputObject with struct data - // type, e.g. StructType(IntegerType, StringType), it will return nullable=true - // for IntegerType without KnownNotNull. And that's what we do not expect to. - val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType), - returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) - val clsName = getClassNameFromType(fieldType) - val newPath = walkedTypePath.recordField(clsName, fieldName) - (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t)) - } - createSerializerForObject(inputObject, fields) + def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + val beanInfo = Introspector.getBeanInfo(beanClass) + beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + .filterNot(_.getName == "declaringClass") + .filter(_.getReadMethod != null) + } + + /* + * Retrieves the runtime class corresponding to the provided type. + */ + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass) + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { + + baseType(tpe) match { + // this must be the first case, since all objects in scala are instances of Null, therefore + // Null type would wrongly match the first of them, which is Option as of now + case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable = true) - case _ => - throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath) + case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => { + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). 
+ getConstructor().newInstance() + Schema(udt, nullable = true) + } + case t if UDTRegistration.exists(getClassNameFromType(t)) => { + val udt = UDTRegistration + .getUDTFor(getClassNameFromType(t)) + .get + .getConstructor() + .newInstance() + .asInstanceOf[UserDefinedType[_]] + Schema(udt, nullable = true) + } + case t if isSubtype(t, localTypeOf[Option[_]]) => { + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType).dataType, nullable = true) + } + case t if isSubtype(t, localTypeOf[Array[Byte]]) => { + Schema(BinaryType, nullable = true) + } + case t if isSubtype(t, localTypeOf[Array[_]]) => { + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + } + case t if isSubtype(t, localTypeOf[Seq[_]]) => { + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + } + case t if isSubtype(t, localTypeOf[Map[_, _]]) => { + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema( + MapType( + schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable + ), nullable = true + ) + } + case t if isSubtype(t, localTypeOf[Set[_]]) => { + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + } + case t if isSubtype(t, localTypeOf[String]) => { + Schema(StringType, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.time.Instant]) => { + Schema(TimestampType, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => { + Schema(TimestampType, nullable = true) + } + // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes. 
+ case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) && Utils.isTesting => { + Schema(TimestampNTZType, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => { + Schema(DateType, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.sql.Date]) => { + Schema(DateType, nullable = true) + } + case t if isSubtype(t, localTypeOf[CalendarInterval]) => { + Schema(CalendarIntervalType, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.time.Duration]) => { + Schema(DayTimeIntervalType(), nullable = true) + } + case t if isSubtype(t, localTypeOf[java.time.Period]) => { + Schema(YearMonthIntervalType(), nullable = true) + } + case t if isSubtype(t, localTypeOf[BigDecimal]) => { + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => { + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => { + Schema(DecimalType.BigIntDecimal, nullable = true) + } + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => { + Schema(DecimalType.BigIntDecimal, nullable = true) + } + case t if isSubtype(t, localTypeOf[Decimal]) => { + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + } + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => Schema(IntegerType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Double]) => Schema(DoubleType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Float]) => Schema(FloatType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true) + case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false) + case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false) + case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false) + case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, nullable = false) + case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false) + case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false) + case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false) + case t if definedByConstructorParams(t) => { + val params = getConstructorParameters(t) + Schema( + StructType( + params.map { case (fieldName, fieldType) => + val Schema(dataType, nullable) = schemaFor(fieldType) + StructField(fieldName, dataType, nullable) + } + ), nullable = true + ) + } + case other => { + throw new UnsupportedOperationException(s"Schema for type $other is not supported") + } + } } - } - - def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = { - Invoke(path, "toString", ObjectType(classOf[java.lang.String]), - returnNullable = returnNullable) - } - - def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { - val beanInfo = Introspector.getBeanInfo(beanClass) - beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - .filterNot(_.getName == "declaringClass") - .filter(_.getReadMethod != null) - } - - /* - * Retrieves the runtime class corresponding to the provided type. 
- */ - def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass) - - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { - baseType(tpe) match { - // this must be the first case, since all objects in scala are instances of Null, therefore - // Null type would wrongly match the first of them, which is Option as of now - case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable = true) - case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). - getConstructor().newInstance() - Schema(udt, nullable = true) - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). - newInstance().asInstanceOf[UserDefinedType[_]] - Schema(udt, nullable = true) - case t if isSubtype(t, localTypeOf[Option[_]]) => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) - case t if isSubtype(t, localTypeOf[Array[Byte]]) => Schema(BinaryType, nullable = true) - case t if isSubtype(t, localTypeOf[Array[_]]) => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if isSubtype(t, localTypeOf[Seq[_]]) => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if isSubtype(t, localTypeOf[Map[_, _]]) => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if isSubtype(t, localTypeOf[Set[_]]) => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if isSubtype(t, localTypeOf[String]) => Schema(StringType, nullable = true) - case t if isSubtype(t, localTypeOf[java.time.Instant]) => - Schema(TimestampType, nullable = true) - case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => - Schema(TimestampType, nullable = true) - case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true) - case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true) - case t if isSubtype(t, localTypeOf[BigDecimal]) => - Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => - Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => - Schema(DecimalType.BigIntDecimal, nullable = true) - case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => - Schema(DecimalType.BigIntDecimal, nullable = true) - case t if isSubtype(t, localTypeOf[Decimal]) => - Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if isSubtype(t, localTypeOf[java.lang.Integer]) => Schema(IntegerType, nullable = true) - case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, nullable = true) - case t if isSubtype(t, localTypeOf[java.lang.Double]) => 
Schema(DoubleType, nullable = true) - case t if isSubtype(t, localTypeOf[java.lang.Float]) => Schema(FloatType, nullable = true) - case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true) - case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true) - case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true) - case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false) - case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false) - case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false) - case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, nullable = false) - case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false) - case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false) - case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false) - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - Schema(StructType( - params.map { case (fieldName, fieldType) => - val Schema(dataType, nullable) = schemaFor(fieldType) - StructField(fieldName, dataType, nullable) - }), nullable = true) - case other => - throw new UnsupportedOperationException(s"Schema for type $other is not supported") + + /** + * Whether the fields of the given type is defined entirely by its constructor parameters. + */ + def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { + tpe.dealias match { + // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. + case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head) + case _ => { + isSubtype(tpe.dealias, localTypeOf[Product]) || + isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams]) + } + } } - } - - /** - * Whether the fields of the given type is defined entirely by its constructor parameters. - */ - def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { - tpe.dealias match { - // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. 
- case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head) - case _ => isSubtype(tpe.dealias, localTypeOf[Product]) || - isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams]) + + private val javaKeywords = Set( + "abstract", "assert", "boolean", "break", "byte", "case", "catch", + "char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false", + "final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int", + "interface", "long", "native", "new", "null", "package", "private", "protected", "public", + "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", + "throws", "transient", "true", "try", "void", "volatile", "while" + ) + + + @scala.annotation.tailrec + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case _: DayTimeIntervalType => classOf[java.lang.Long] + case _: YearMonthIntervalType => classOf[java.lang.Integer] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) } - } - - private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", - "char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false", - "final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int", - "interface", "long", "native", "new", "null", "package", "private", "protected", "public", - "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", - "throws", "transient", "true", "try", "void", "volatile", "while") - - - @scala.annotation.tailrec - def javaBoxedType(dt: DataType): Class[_] = dt match { - case _: DecimalType => classOf[Decimal] - case BinaryType => classOf[Array[Byte]] - case StringType => classOf[UTF8String] - case CalendarIntervalType => classOf[CalendarInterval] - case _: StructType => classOf[InternalRow] - case _: ArrayType => classOf[ArrayType] - case _: MapType => classOf[MapType] - case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) - case ObjectType(cls) => cls - case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) - } } @@ -991,120 +1278,124 @@ object KotlinReflection extends KotlinReflection { * object, this trait able to work in both the runtime and the compile time (macro) universe. */ trait KotlinReflection extends Logging { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror - - import universe._ - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - - /** - * Any codes calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to - * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to - * `scala.reflect.runtime.JavaUniverse.undoLog`. 
- * - * @see https://github.com/scala/bug/issues/8302 - */ - def cleanUpReflectionObjects[T](func: => T): T = { - universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func) - } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - def localTypeOf[T: TypeTag]: `Type` = { - val tag = implicitly[TypeTag[T]] - tag.in(mirror).tpe.dealias - } - - /** - * Returns the full class name for a type. The returned name is the canonical - * Scala name, where each component is separated by a period. It is NOT the - * Java-equivalent runtime name (no dollar signs). - * - * In simple cases, both the Scala and Java names are the same, however when Scala - * generates constructs that do not map to a Java equivalent, such as singleton objects - * or nested classes in package objects, it uses the dollar sign ($) to create - * synthetic classes, emulating behaviour in Java bytecode. - */ - def getClassNameFromType(tpe: `Type`): String = { - tpe.dealias.erasure.typeSymbol.asClass.fullName - } - - /** - * Returns the parameter names and types for the primary constructor of this type. - * - * Note that it only works for scala classes with primary constructor, and currently doesn't - * support inner class. - */ - def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { - val dealiasedTpe = tpe.dealias - val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = dealiasedTpe - val params = constructParams(dealiasedTpe) - // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) - if (actualTypeArgs.nonEmpty) { - params.map { p => - p.name.decodedName.toString -> - p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - } - } else { - params.map { p => - p.name.decodedName.toString -> p.typeSignature - } + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + + /** + * Any codes calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to + * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to + * `scala.reflect.runtime.JavaUniverse.undoLog`. + * + * @see https://github.com/scala/bug/issues/8302 + */ + def cleanUpReflectionObjects[T](func: => T): T = { + universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func) } - } - - /** - * If our type is a Scala trait it may have a companion object that - * only defines a constructor via `apply` method. - */ - private def getCompanionConstructor(tpe: Type): Symbol = { - def throwUnsupportedOperation = { - throw new UnsupportedOperationException(s"Unable to find constructor for $tpe. " + - s"This could happen if $tpe is an interface, or a trait without companion object " + - "constructor.") + + /** + * Return the Scala Type for `T` in the current classloader mirror. 
+ * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = { + val tag = implicitly[TypeTag[T]] + tag.in(mirror).tpe.dealias } - tpe.typeSymbol.asClass.companion match { - case NoSymbol => throwUnsupportedOperation - case sym => sym.asTerm.typeSignature.member(universe.TermName("apply")) match { - case NoSymbol => throwUnsupportedOperation - case constructorSym => constructorSym - } + /** + * Returns the full class name for a type. The returned name is the canonical + * Scala name, where each component is separated by a period. It is NOT the + * Java-equivalent runtime name (no dollar signs). + * + * In simple cases, both the Scala and Java names are the same, however when Scala + * generates constructs that do not map to a Java equivalent, such as singleton objects + * or nested classes in package objects, it uses the dollar sign ($) to create + * synthetic classes, emulating behaviour in Java bytecode. + */ + def getClassNameFromType(tpe: `Type`): String = { + tpe.dealias.erasure.typeSymbol.asClass.fullName } - } - protected def constructParams(tpe: Type): Seq[Symbol] = { - val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match { - case NoSymbol => getCompanionConstructor(tpe) - case sym => sym + /** + * Returns the parameter names and types for the primary constructor of this type. + * + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. + */ + def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { + val dealiasedTpe = tpe.dealias + val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = dealiasedTpe + val params = constructParams(dealiasedTpe) + // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) + if (actualTypeArgs.nonEmpty) { + params.map { p => + p.name.decodedName.toString -> + p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } else { + params.map { p => + p.name.decodedName.toString -> p.typeSignature + } + } } - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramLists - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( - s => s.isMethod && s.asMethod.isPrimaryConstructor) - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramLists - } + + /** + * If our type is a Scala trait it may have a companion object that + * only defines a constructor via `apply` method. + */ + private def getCompanionConstructor(tpe: Type): Symbol = { + def throwUnsupportedOperation = { + throw new UnsupportedOperationException(s"Unable to find constructor for $tpe. " + + s"This could happen if $tpe is an interface, or a trait without companion object " + + "constructor." 
+ ) + } + + tpe.typeSymbol.asClass.companion match { + case NoSymbol => throwUnsupportedOperation + case sym => { + sym.asTerm.typeSignature.member(universe.TermName("apply")) match { + case NoSymbol => throwUnsupportedOperation + case constructorSym => constructorSym + } + } + } + } + + protected def constructParams(tpe: Type): Seq[Symbol] = { + val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match { + case NoSymbol => getCompanionConstructor(tpe) + case sym => sym + } + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramLists + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( + s => s.isMethod && s.asMethod.isPrimaryConstructor + ) + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramLists + } + } + params.flatten } - params.flatten - } } diff --git a/examples/pom-3.2_2.12.xml b/examples/pom-3.2_2.12.xml index b5267352..5f214b69 100644 --- a/examples/pom-3.2_2.12.xml +++ b/examples/pom-3.2_2.12.xml @@ -24,6 +24,11 @@ spark-sql_${scala.compat.version} ${spark3.version} + + org.apache.spark + spark-streaming_${scala.compat.version} + ${spark3.version} + diff --git a/kotlin-spark-api/3.2/pom_2.12.xml b/kotlin-spark-api/3.2/pom_2.12.xml index 756d9c2b..826547d2 100644 --- a/kotlin-spark-api/3.2/pom_2.12.xml +++ b/kotlin-spark-api/3.2/pom_2.12.xml @@ -36,6 +36,12 @@ ${spark3.version} provided + + org.apache.spark + spark-streaming_${scala.compat.version} + ${spark3.version} + provided + diff --git a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt index 1061a21a..fb1b5340 100644 --- a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt +++ b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt @@ -21,11 +21,11 @@ package org.jetbrains.kotlinx.spark.api -import org.apache.hadoop.shaded.org.apache.commons.math3.exception.util.ArgUtils import org.apache.spark.SparkContext -import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.api.java.* import org.apache.spark.api.java.function.* import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.sql.* import org.apache.spark.sql.Encoders.* import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -33,17 +33,21 @@ import org.apache.spark.sql.streaming.GroupState import org.apache.spark.sql.streaming.GroupStateTimeout import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.* +import org.apache.spark.unsafe.types.CalendarInterval import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions import scala.Product import scala.Tuple2 +import scala.concurrent.duration.`Duration$` import scala.reflect.ClassTag -import scala.reflect.api.TypeTags.TypeTag +import scala.reflect.api.StandardDefinitions import java.beans.PropertyDescriptor import java.math.BigDecimal import java.sql.Date import java.sql.Timestamp +import java.time.Duration import java.time.Instant import java.time.LocalDate +import java.time.Period import java.util.* import java.util.concurrent.ConcurrentHashMap import kotlin.Any @@ -95,10 +99,12 @@ val ENCODERS: Map, Encoder<*>> = mapOf( String::class to STRING(), BigDecimal::class to DECIMAL(), Date::class 
to DATE(), - LocalDate::class to LOCALDATE(), // 3.0 only + LocalDate::class to LOCALDATE(), // 3.0+ Timestamp::class to TIMESTAMP(), - Instant::class to INSTANT(), // 3.0 only - ByteArray::class to BINARY() + Instant::class to INSTANT(), // 3.0+ + ByteArray::class to BINARY(), + Duration::class to DURATION(), // 3.2+ + Period::class to PERIOD(), // 3.2+ ) @@ -154,6 +160,18 @@ inline fun SparkSession.dsOf(vararg t: T): Dataset = inline fun List.toDS(spark: SparkSession): Dataset = spark.createDataset(this, encoder()) +/** + * Utility method to create dataset from RDD + */ +inline fun RDD.toDS(spark: SparkSession): Dataset = + spark.createDataset(this, encoder()) + +/** + * Utility method to create dataset from JavaRDD + */ +inline fun JavaRDDLike.toDS(spark: SparkSession): Dataset = + spark.createDataset(this.rdd(), encoder()) + /** * Main method of API, which gives you seamless integration with Spark: * It creates encoder for any given supported type T @@ -177,12 +195,16 @@ fun generateEncoder(type: KType, cls: KClass<*>): Encoder { } as Encoder } -private fun isSupportedClass(cls: KClass<*>): Boolean = - cls.isData - || cls.isSubclassOf(Map::class) - || cls.isSubclassOf(Iterable::class) - || cls.isSubclassOf(Product::class) - || cls.java.isArray +private fun isSupportedClass(cls: KClass<*>): Boolean = when { + cls == ByteArray::class -> false // uses binary encoder + cls.isData -> true + cls.isSubclassOf(Map::class) -> true + cls.isSubclassOf(Iterable::class) -> true + cls.isSubclassOf(Product::class) -> true + cls.java.isArray -> true + else -> false + } + private fun kotlinClassEncoder(schema: DataType, kClass: KClass<*>): Encoder { return ExpressionEncoder( @@ -1192,7 +1214,7 @@ fun schema(type: KType, map: Map = mapOf()): DataType { DoubleArray::class -> typeOf() BooleanArray::class -> typeOf() ShortArray::class -> typeOf() - ByteArray::class -> typeOf() +// ByteArray::class -> typeOf() handled by BinaryType else -> types.getValue(klass.typeParameters[0].name) } } else types.getValue(klass.typeParameters[0].name) @@ -1290,10 +1312,14 @@ private val knownDataTypes: Map, DataType> = mapOf( Float::class to DataTypes.FloatType, Double::class to DataTypes.DoubleType, String::class to DataTypes.StringType, - LocalDate::class to `DateType$`.`MODULE$`, - Date::class to `DateType$`.`MODULE$`, - Timestamp::class to `TimestampType$`.`MODULE$`, - Instant::class to `TimestampType$`.`MODULE$`, + LocalDate::class to DataTypes.DateType, + Date::class to DataTypes.DateType, + Timestamp::class to DataTypes.TimestampType, + Instant::class to DataTypes.TimestampType, + ByteArray::class to DataTypes.BinaryType, + Decimal::class to DecimalType.SYSTEM_DEFAULT(), + BigDecimal::class to DecimalType.SYSTEM_DEFAULT(), + CalendarInterval::class to DataTypes.CalendarIntervalType, ) private fun transitiveMerge(a: Map, b: Map): Map { diff --git a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt index 6188daae..98fdae8d 100644 --- a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt +++ b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt @@ -20,6 +20,11 @@ package org.jetbrains.kotlinx.spark.api import org.apache.spark.SparkConf +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaRDDLike +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.rdd.RDD +import 
org.apache.spark.sql.Dataset import org.apache.spark.sql.SparkSession.Builder import org.apache.spark.sql.UDFRegistration import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR @@ -78,18 +83,38 @@ inline fun withSpark(builder: Builder, logLevel: SparkLogLevel = ERROR, func: KS KSparkSession(this).apply { sparkContext.setLogLevel(logLevel) func() + spark.stop() } } - .also { it.stop() } +} + +/** + * Wrapper for spark creation which copies params from [sparkConf]. + * + * @param sparkConf Sets a list of config options based on this. + * @param logLevel Control our logLevel. This overrides any user-defined log settings. + * @param func function which will be executed in context of [KSparkSession] (it means that `this` inside block will point to [KSparkSession]) + */ +@JvmOverloads +inline fun withSpark(sparkConf: SparkConf, logLevel: SparkLogLevel = ERROR, func: KSparkSession.() -> Unit) { + withSpark( + builder = SparkSession.builder().config(sparkConf), + logLevel = logLevel, + func = func, + ) } /** * This wrapper over [SparkSession] which provides several additional methods to create [org.apache.spark.sql.Dataset] */ -@Suppress("EXPERIMENTAL_FEATURE_WARNING", "unused") -inline class KSparkSession(val spark: SparkSession) { +class KSparkSession(val spark: SparkSession) { + + val sc: JavaSparkContext by lazy { JavaSparkContext(spark.sparkContext) } + inline fun List.toDS() = toDS(spark) inline fun Array.toDS() = spark.dsOf(*this) inline fun dsOf(vararg arg: T) = spark.dsOf(*arg) + inline fun RDD.toDS() = toDS(spark) + inline fun JavaRDDLike.toDS() = toDS(spark) val udf: UDFRegistration get() = spark.udf() } diff --git a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt index a4b2bdd7..af870038 100644 --- a/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt +++ b/kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/VarArities.kt @@ -22,32 +22,34 @@ */ package org.jetbrains.kotlinx.spark.api -data class Arity1(val _1: T1) -data class Arity2(val _1: T1, val _2: T2) -data class Arity3(val _1: T1, val _2: T2, val _3: T3) -data class Arity4(val _1: T1, val _2: T2, val _3: T3, val _4: T4) -data class Arity5(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5) -data class Arity6(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6) -data class Arity7(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7) -data class Arity8(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8) -data class Arity9(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9) -data class Arity10(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10) -data class Arity11(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11) -data class Arity12(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12) -data class Arity13(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13) -data class Arity14(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val 
_7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14) -data class Arity15(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15) -data class Arity16(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16) -data class Arity17(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17) -data class Arity18(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18) -data class Arity19(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19) -data class Arity20(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20) -data class Arity21(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21) -data class Arity22(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22) -data class Arity23(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23) -data class Arity24(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24) -data class Arity25(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: T25) -data class Arity26(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: 
T25, val _26: T26) +import java.io.Serializable + +data class Arity1(val _1: T1): Serializable +data class Arity2(val _1: T1, val _2: T2): Serializable +data class Arity3(val _1: T1, val _2: T2, val _3: T3): Serializable +data class Arity4(val _1: T1, val _2: T2, val _3: T3, val _4: T4): Serializable +data class Arity5(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5): Serializable +data class Arity6(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6): Serializable +data class Arity7(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7): Serializable +data class Arity8(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8): Serializable +data class Arity9(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9): Serializable +data class Arity10(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10): Serializable +data class Arity11(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11): Serializable +data class Arity12(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12): Serializable +data class Arity13(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13): Serializable +data class Arity14(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14): Serializable +data class Arity15(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15): Serializable +data class Arity16(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16): Serializable +data class Arity17(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17): Serializable +data class Arity18(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18): Serializable +data class Arity19(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19): Serializable +data class Arity20(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20): Serializable +data class Arity21(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: 
T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21): Serializable +data class Arity22(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22): Serializable +data class Arity23(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23): Serializable +data class Arity24(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24): Serializable +data class Arity25(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: T25): Serializable +data class Arity26(val _1: T1, val _2: T2, val _3: T3, val _4: T4, val _5: T5, val _6: T6, val _7: T7, val _8: T8, val _9: T9, val _10: T10, val _11: T11, val _12: T12, val _13: T13, val _14: T14, val _15: T15, val _16: T16, val _17: T17, val _18: T18, val _19: T19, val _20: T20, val _21: T21, val _22: T22, val _23: T23, val _24: T24, val _25: T25, val _26: T26): Serializable fun c(_1: T1) = Arity1(_1) fun c(_1: T1, _2: T2) = Arity2(_1, _2) fun c(_1: T1, _2: T2, _3: T3) = Arity3(_1, _2, _3) diff --git a/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt index ed784b13..fa320c39 100644 --- a/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt +++ b/kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt @@ -20,21 +20,32 @@ package org.jetbrains.kotlinx.spark.api/*- import ch.tutteli.atrium.api.fluent.en_GB.* import ch.tutteli.atrium.api.verbs.expect import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.should import io.kotest.matchers.shouldBe +import org.apache.spark.api.java.JavaDoubleRDD +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions.* import org.apache.spark.sql.streaming.GroupState import org.apache.spark.sql.streaming.GroupStateTimeout +import org.apache.spark.sql.types.Decimal +import org.apache.spark.unsafe.types.CalendarInterval import scala.Product import scala.Tuple1 import scala.Tuple2 import scala.Tuple3 import scala.collection.Seq import java.io.Serializable +import java.math.BigDecimal import java.sql.Date import java.sql.Timestamp +import java.time.Duration import java.time.Instant import java.time.LocalDate +import java.time.Period import kotlin.collections.Iterator import scala.collection.Iterator as ScalaIterator 
import scala.collection.Map as ScalaMap @@ -318,24 +329,96 @@ class ApiTest : ShouldSpec({ cogrouped.count() shouldBe 4 } should("handle LocalDate Datasets") { // uses encoder - val dataset: Dataset = dsOf(LocalDate.now(), LocalDate.now()) - dataset.show() + val dates = listOf(LocalDate.now(), LocalDate.now()) + val dataset: Dataset = dates.toDS() + dataset.collectAsList() shouldBe dates } should("handle Instant Datasets") { // uses encoder - val dataset: Dataset = dsOf(Instant.now(), Instant.now()) - dataset.show() + val instants = listOf(Instant.now(), Instant.now()) + val dataset: Dataset = instants.toDS() + dataset.collectAsList() shouldBe instants + } + should("Be able to serialize Instant") { // uses knownDataTypes + val instantPair = Instant.now() to Instant.now() + val dataset = dsOf(instantPair) + dataset.collectAsList() shouldBe listOf(instantPair) } should("be able to serialize Date") { // uses knownDataTypes - val dataset: Dataset> = dsOf(Date.valueOf("2020-02-10") to 5) - dataset.show() + val datePair = Date.valueOf("2020-02-10") to 5 + val dataset: Dataset> = dsOf(datePair) + dataset.collectAsList() shouldBe listOf(datePair) } should("handle Timestamp Datasets") { // uses encoder - val dataset = dsOf(Timestamp(0L)) - dataset.show() + val timeStamps = listOf(Timestamp(0L), Timestamp(1L)) + val dataset = timeStamps.toDS() + dataset.collectAsList() shouldBe timeStamps } should("be able to serialize Timestamp") { // uses knownDataTypes - val dataset = dsOf(Timestamp(0L) to 2) - dataset.show() + val timestampPair = Timestamp(0L) to 2 + val dataset = dsOf(timestampPair) + dataset.collectAsList() shouldBe listOf(timestampPair) + } + should("handle Duration Datasets") { // uses encoder + val dataset = dsOf(Duration.ZERO) + dataset.collectAsList() shouldBe listOf(Duration.ZERO) + } + should("handle Period Datasets") { // uses encoder + val periods = listOf(Period.ZERO, Period.ofDays(2)) + val dataset = periods.toDS() + + dataset.show(false) + + dataset.collectAsList().let { + it[0] shouldBe Period.ZERO + + // NOTE Spark truncates java.time.Period to months. 
+                it[1] shouldBe Period.ofDays(0)
+            }
+
+        }
+        should("handle binary datasets") { // uses encoder
+            val byteArray = "Hello there".encodeToByteArray()
+            val dataset = dsOf(byteArray)
+            dataset.collectAsList() shouldBe listOf(byteArray)
+        }
+        should("be able to serialize binary") { // uses knownDataTypes
+            val byteArrayTriple = c("Hello there".encodeToByteArray(), 1, intArrayOf(1, 2, 3))
+            val dataset = dsOf(byteArrayTriple)
+
+            val (a, b, c) = dataset.collectAsList().single()
+            a contentEquals "Hello there".encodeToByteArray() shouldBe true
+            b shouldBe 1
+            c contentEquals intArrayOf(1, 2, 3) shouldBe true
+        }
+        should("be able to serialize Decimal") { // uses knownDataTypes
+            val decimalPair = c(Decimal().set(50), 12)
+            val dataset = dsOf(decimalPair)
+            dataset.collectAsList() shouldBe listOf(decimalPair)
+        }
+        should("handle BigDecimal datasets") { // uses encoder
+            val decimals = listOf(BigDecimal.ONE, BigDecimal.TEN)
+            val dataset = decimals.toDS()
+            dataset.collectAsList().let { (one, ten) ->
+                one.compareTo(BigDecimal.ONE) shouldBe 0
+                ten.compareTo(BigDecimal.TEN) shouldBe 0
+            }
+        }
+        should("be able to serialize BigDecimal") { // uses knownDataTypes
+            val decimalPair = c(BigDecimal.TEN, 12)
+            val dataset = dsOf(decimalPair)
+            val (a, b) = dataset.collectAsList().single()
+            a.compareTo(BigDecimal.TEN) shouldBe 0
+            b shouldBe 12
+        }
+        should("be able to serialize CalendarInterval") { // uses knownDataTypes
+            val calendarIntervalPair = CalendarInterval(1, 0, 0L) to 2
+            val dataset = dsOf(calendarIntervalPair)
+            dataset.collectAsList() shouldBe listOf(calendarIntervalPair)
+        }
+        should("handle nullable datasets") {
+            val ints = listOf(1, 2, 3, null)
+            val dataset = ints.toDS()
+            dataset.collectAsList() shouldBe ints
         }
         should("Be able to serialize Scala Tuples including data classes") {
             val dataset = dsOf(
@@ -366,20 +449,20 @@ class ApiTest : ShouldSpec({
             val newDS1WithAs: Dataset<IntArray> = dataset.selectTyped(
                 col("a").`as`<IntArray>(),
             )
-            newDS1WithAs.show()
+            newDS1WithAs.collectAsList()
 
             val newDS2: Dataset<Pair<IntArray, Int>> = dataset.selectTyped(
                 col(SomeClass::a), // NOTE: this only works on 3.0, returning a data class with an array in it
                 col(SomeClass::b),
             )
-            newDS2.show()
+            newDS2.collectAsList()
 
             val newDS3: Dataset<Triple<IntArray, Int, Int>> = dataset.selectTyped(
                 col(SomeClass::a),
                 col(SomeClass::b),
                 col(SomeClass::b),
             )
-            newDS3.show()
+            newDS3.collectAsList()
 
             val newDS4: Dataset<Arity4<IntArray, Int, Int, Int>> = dataset.selectTyped(
                 col(SomeClass::a),
@@ -387,7 +470,7 @@ class ApiTest : ShouldSpec({
                 col(SomeClass::b),
                 col(SomeClass::b),
             )
-            newDS4.show()
+            newDS4.collectAsList()
 
             val newDS5: Dataset<Arity5<IntArray, Int, Int, Int, Int>> = dataset.selectTyped(
                 col(SomeClass::a),
@@ -396,7 +479,7 @@ class ApiTest : ShouldSpec({
                 col(SomeClass::b),
                 col(SomeClass::b),
             )
-            newDS5.show()
+            newDS5.collectAsList()
         }
         should("Access columns using invoke on datasets") {
             val dataset = dsOf(
@@ -449,19 +532,18 @@ class ApiTest : ShouldSpec({
                 dataset(SomeOtherClass::a),
                 col(SomeOtherClass::c),
             )
-            b.show()
+            b.collectAsList()
         }
         should("Handle some where queries using column operator functions") {
             val dataset = dsOf(
                 SomeOtherClass(intArrayOf(1, 2, 3), 4, true),
                 SomeOtherClass(intArrayOf(4, 3, 2), 1, true),
             )
-            dataset.show()
+            dataset.collectAsList()
 
             val column = col("b").`as`<Int>()
 
             val b = dataset.where(column gt 3 and col(SomeOtherClass::c))
-            b.show()
 
             b.count() shouldBe 1
         }
@@ -470,21 +552,51 @@ class ApiTest : ShouldSpec({
                 listOf(SomeClass(intArrayOf(1, 2, 3), 4)),
                 listOf(SomeClass(intArrayOf(3, 2, 1), 0)),
             )
-            dataset.show()
+
+            val (first, second) = dataset.collectAsList()
+
+            first.single().let { (a, b) ->
+                a.contentEquals(intArrayOf(1, 2, 3)) shouldBe true
+                b shouldBe 4
+            }
+            second.single().let { (a, b) ->
+                a.contentEquals(intArrayOf(3, 2, 1)) shouldBe true
+                b shouldBe 0
+            }
         }
         should("Be able to serialize arrays of data classes") {
             val dataset = dsOf(
                 arrayOf(SomeClass(intArrayOf(1, 2, 3), 4)),
                 arrayOf(SomeClass(intArrayOf(3, 2, 1), 0)),
             )
-            dataset.show()
+
+            val (first, second) = dataset.collectAsList()
+
+            first.single().let { (a, b) ->
+                a.contentEquals(intArrayOf(1, 2, 3)) shouldBe true
+                b shouldBe 4
+            }
+            second.single().let { (a, b) ->
+                a.contentEquals(intArrayOf(3, 2, 1)) shouldBe true
+                b shouldBe 0
+            }
         }
         should("Be able to serialize lists of tuples") {
             val dataset = dsOf(
                 listOf(Tuple2(intArrayOf(1, 2, 3), 4)),
                 listOf(Tuple2(intArrayOf(3, 2, 1), 0)),
             )
-            dataset.show()
+
+            val (first, second) = dataset.collectAsList()
+
+            first.single().let {
+                it._1().contentEquals(intArrayOf(1, 2, 3)) shouldBe true
+                it._2() shouldBe 4
+            }
+            second.single().let {
+                it._1().contentEquals(intArrayOf(3, 2, 1)) shouldBe true
+                it._2() shouldBe 0
+            }
         }
         should("Allow simple forEachPartition in datasets") {
             val dataset = dsOf(
@@ -593,10 +705,73 @@ class ApiTest : ShouldSpec({
                 it.nullable() shouldBe true
             }
         }
+        should("Convert Scala RDD to Dataset") {
+            val rdd0: RDD<Int> = sc.parallelize(
+                listOf(1, 2, 3, 4, 5, 6)
+            ).rdd()
+            val dataset0: Dataset<Int> = rdd0.toDS()
+
+            dataset0.toList<Int>() shouldBe listOf(1, 2, 3, 4, 5, 6)
+        }
+
+        should("Convert a JavaRDD to a Dataset") {
+            val rdd1: JavaRDD<Int> = sc.parallelize(
+                listOf(1, 2, 3, 4, 5, 6)
+            )
+            val dataset1: Dataset<Int> = rdd1.toDS()
+
+            dataset1.toList<Int>() shouldBe listOf(1, 2, 3, 4, 5, 6)
+        }
+        should("Convert JavaDoubleRDD to Dataset") {
+
+            // JavaDoubleRDD
+            val rdd2: JavaDoubleRDD = sc.parallelizeDoubles(
+                listOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
+            )
+            val dataset2: Dataset<Double> = rdd2.toDS()
+
+            dataset2.toList<Double>() shouldBe listOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
+        }
+        should("Convert JavaPairRDD to Dataset") {
+            val rdd3: JavaPairRDD<Int, Double> = sc.parallelizePairs(
+                listOf(Tuple2(1, 1.0), Tuple2(2, 2.0), Tuple2(3, 3.0))
+            )
+            val dataset3: Dataset<Tuple2<Int, Double>> = rdd3.toDS()
+
+            dataset3.toList<Tuple2<Int, Double>>() shouldBe listOf(Tuple2(1, 1.0), Tuple2(2, 2.0), Tuple2(3, 3.0))
+        }
+        should("Convert Kotlin Serializable data class RDD to Dataset") {
+            val rdd4 = sc.parallelize(
+                listOf(SomeClass(intArrayOf(1, 2), 0))
+            )
+            val dataset4 = rdd4.toDS()
+
+            dataset4.toList<SomeClass>().first().let { (a, b) ->
+                a contentEquals intArrayOf(1, 2) shouldBe true
+                b shouldBe 0
+            }
+        }
+        should("Convert Arity RDD to Dataset") {
+            val rdd5 = sc.parallelize(
+                listOf(c(1.0, 4))
+            )
+            val dataset5 = rdd5.toDS()
+
+            dataset5.toList<Arity2<Double, Int>>() shouldBe listOf(c(1.0, 4))
+        }
+        should("Convert List RDD to Dataset") {
+            val rdd6 = sc.parallelize(
+                listOf(listOf(1, 2, 3), listOf(4, 5, 6))
+            )
+            val dataset6 = rdd6.toDS()
+
+            dataset6.toList<List<Int>>() shouldBe listOf(listOf(1, 2, 3), listOf(4, 5, 6))
+        }
     }
 }
 })
+
 data class DataClassWithTuple<T>(val tuple: T)
 
 data class LonLat(val lon: Double, val lat: Double)
@@ -626,5 +801,5 @@ data class ComplexEnumDataClass(
 data class NullFieldAbleDataClass(
     val optionList: List?,
-    val optionMap: Map?
-)
\ No newline at end of file
+    val optionMap: Map?,
+)
diff --git a/pom.xml b/pom.xml
index 47043737..0df3adac 100644
--- a/pom.xml
+++ b/pom.xml
@@ -15,7 +15,7 @@
         0.16.0
         4.6.0
         1.0.1
-        3.2.0
+        3.2.1
         2.10.0
@@ -32,6 +32,7 @@
         3.0.0-M5
         1.6.8
         4.5.6
+        official
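
A minimal usage sketch (not part of the patch) of what this change set exercises: converting RDDs to Datasets and encoding java.time values. It assumes the library's `withSpark` entry point and the `toDS()`/`dsOf()` helpers used in the tests above; the app name and data below are illustrative only.

// Sketch only: relies on the RDD toDS() extension and the java.time encoders shown in ApiTest.kt.
import org.apache.spark.api.java.JavaSparkContext
import org.jetbrains.kotlinx.spark.api.*
import java.time.Duration

fun main() = withSpark(appName = "rdd-to-dataset-sketch") {
    // A JavaRDD built from the JavaSparkContext can be turned into a Dataset
    // inside the session scope via the toDS() extension added here.
    val sc = JavaSparkContext(spark.sparkContext())
    val fromRdd = sc.parallelize(listOf(1, 2, 3)).toDS()
    fromRdd.show()

    // java.time.Duration values round-trip through the new interval encoders;
    // java.time.Period is truncated to months, as the test above notes.
    val durations = dsOf(Duration.ofSeconds(1), Duration.ofSeconds(2))
    durations.show()
}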