Skip to content

[Fix] Encoding isSomething names in data classes #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 39 additions & 33 deletions core/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.util.Utils

import java.beans.{Introspector, PropertyDescriptor}
import java.lang.Exception
import java.lang.reflect.Method


/**
Expand Down Expand Up @@ -212,11 +213,11 @@ object KotlinReflection extends KotlinReflection {
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
*/
private def deserializerFor(
tpe: `Type`,
path: Expression,
walkedTypePath: WalkedTypePath,
predefinedDt: Option[DataTypeWithClass] = None
): Expression = cleanUpReflectionObjects {
tpe: `Type`,
path: Expression,
walkedTypePath: WalkedTypePath,
predefinedDt: Option[DataTypeWithClass] = None
): Expression = cleanUpReflectionObjects {
baseType(tpe) match {

//<editor-fold desc="Description">
Expand Down Expand Up @@ -685,18 +686,18 @@ object KotlinReflection extends KotlinReflection {
* internal representation.
*/
private def serializerFor(
inputObject: Expression,
tpe: `Type`,
walkedTypePath: WalkedTypePath,
seenTypeSet: Set[`Type`] = Set.empty,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = cleanUpReflectionObjects {
inputObject: Expression,
tpe: `Type`,
walkedTypePath: WalkedTypePath,
seenTypeSet: Set[`Type`] = Set.empty,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = cleanUpReflectionObjects {

def toCatalystArray(
input: Expression,
elementType: `Type`,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = {
input: Expression,
elementType: `Type`,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = {
val dataType = predefinedDt
.map(_.dt)
.getOrElse {
Expand All @@ -705,7 +706,7 @@ object KotlinReflection extends KotlinReflection {

dataType match {

case dt @ (MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => {
case dt@(MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => {
val clsName = getClassNameFromType(elementType)
val newPath = walkedTypePath.recordArray(clsName)
createSerializerForMapObjects(
Expand All @@ -726,7 +727,7 @@ object KotlinReflection extends KotlinReflection {
// case dt: ByteType =>
// createSerializerForPrimitiveArray(input, dt)

case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => {
case dt@(BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => {
val cls = input.dataType.asInstanceOf[ObjectType].cls
if (cls.isArray && cls.getComponentType.isPrimitive) {
createSerializerForPrimitiveArray(input, dt)
Expand Down Expand Up @@ -945,11 +946,11 @@ object KotlinReflection extends KotlinReflection {
// Kotlin specific cases
case t if predefinedDt.isDefined => {

// if (seenTypeSet.contains(t)) {
// throw new UnsupportedOperationException(
// s"cannot have circular references in class, but got the circular reference of class $t"
// )
// }
// if (seenTypeSet.contains(t)) {
// throw new UnsupportedOperationException(
// s"cannot have circular references in class, but got the circular reference of class $t"
// )
// }

predefinedDt.get match {

Expand All @@ -959,18 +960,20 @@ object KotlinReflection extends KotlinReflection {
val properties = getJavaBeanReadableProperties(cls)
val structFields = dataType.dt.fields.map(_.asInstanceOf[KStructField])
val fields: Array[(String, Expression)] = structFields.map { structField =>
val maybeProp = properties.find(it => it.getReadMethod.getName == structField.getterName)
if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${
structField.name
} is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}"
)
val maybeProp = properties.find {
_.getName == structField.getterName
}
if (maybeProp.isEmpty)
throw new IllegalArgumentException(
s"Field ${structField.name} is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}"
)
val fieldName = structField.name
val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]

val fieldValue = Invoke(
inputObject,
maybeProp.get.getReadMethod.getName,
maybeProp.get.getName,
inferExternalType(propClass),
returnNullable = structField.nullable
)
Expand Down Expand Up @@ -1124,11 +1127,14 @@ object KotlinReflection extends KotlinReflection {
)
}

def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[Method] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
.filterNot(_.getName == "declaringClass")
.filter(_.getReadMethod != null)
beanInfo
.getMethodDescriptors
.filter { it => it.getName.startsWith("is") || it.getName.startsWith("get") }
.filterNot { _.getName == "getClass" }
.filterNot { _.getName == "getDeclaringClass" }
.map { _.getMethod }
}

/*
Expand Down Expand Up @@ -1296,7 +1302,7 @@ object KotlinReflection extends KotlinReflection {
val params = method.typeSignature.paramLists.head
// Check that the needed params are the same length and of matching types
params.size == paramTypes.tail.size &&
params.zip(paramTypes.tail).forall { case(ps, pc) =>
params.zip(paramTypes.tail).forall { case (ps, pc) =>
ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
}
}.map { applyMethodSymbol =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,20 +250,24 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
}

klass.isData -> {

val structType = StructType(
klass
.primaryConstructor!!
.parameters
.filter { it.findAnnotation<Transient>() == null }
.map {
val projectedType = types[it.type.toString()] ?: it.type

val readMethodName = when {
it.name!!.startsWith("is") -> it.name!!
else -> "get${it.name!!.replaceFirstChar { it.uppercase() }}"
}

val propertyDescriptor = PropertyDescriptor(
/* propertyName = */ it.name,
/* beanClass = */ klass.java,
/* readMethodName = */ "is" + it.name?.replaceFirstChar {
if (it.isLowerCase()) it.titlecase(Locale.getDefault())
else it.toString()
},
/* readMethodName = */ readMethodName,
/* writeMethodName = */ null
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ class EncodingTest : ShouldSpec({
val dataset = ints.toDS()
dataset.collectAsList() shouldBe ints
}

should("handle data classes with isSomething") {
val dataClasses = listOf(
IsSomethingClass(true, false, true, 1.0, 2.0, 0.0),
IsSomethingClass(false, true, true, 1.0, 2.0, 0.0),
)
val dataset = dataClasses.toDS().showDS()
dataset.collectAsList() shouldBe dataClasses
}
}
}
context("known dataTypes") {
Expand Down Expand Up @@ -174,6 +183,16 @@ class EncodingTest : ShouldSpec({
asList.first() shouldBe t("a", t("a", 1, LonLat(1.0, 1.0)))
}

should("Be able to serialize Scala Tuples including isSomething data classes") {
val dataset = dsOf(
t("a", t("a", 1, IsSomethingClass(true, false, true, 1.0, 2.0, 0.0))),
t("b", t("b", 2, IsSomethingClass(false, true, true, 1.0, 2.0, 0.0))),
)
dataset.show()
val asList = dataset.takeAsList(2)
asList.first() shouldBe t("a", t("a", 1, IsSomethingClass(true, false, true, 1.0, 2.0, 0.0)))
}

should("Be able to serialize data classes with tuples") {
val dataset = dsOf(
DataClassWithTuple(t(5L, "test", t(""))),
Expand Down Expand Up @@ -495,6 +514,15 @@ class EncodingTest : ShouldSpec({
}
})

data class IsSomethingClass(
val enabled: Boolean,
val isEnabled: Boolean,
val getEnabled: Boolean,
val double: Double,
val isDouble: Double,
val getDouble: Double
)

data class DataClassWithTuple<T : Product>(val tuple: T)

data class LonLat(val lon: Double, val lat: Double)
Expand Down