Skip to content

Commit a888ea4

Browse files
authored
Add enum support for ApiV1Kt.generateEncoder (#99)
Co-authored-by: can wang <[email protected]>
1 parent b43e3f1 commit a888ea4

File tree

5 files changed

+111
-1
lines changed
  • core/3.0/src/main/scala/org/apache/spark/sql
  • kotlin-spark-api
    • 2.4/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api
    • 3.0/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api

5 files changed

+111
-1
lines changed

core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ object KotlinReflection extends KotlinReflection {
336336
mirror.runtimeClass(t.typeSymbol.asClass)
337337
)
338338

339+
case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
340+
createDeserializerForTypesSupportValueOf(
341+
createDeserializerForString(path, returnNullable = false), Class.forName(t.toString))
342+
339343
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
340344
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
341345
getConstructor().newInstance()
@@ -612,7 +616,7 @@ object KotlinReflection extends KotlinReflection {
612616
case _: StringType =>
613617
val clsName = getClassNameFromType(typeOf[String])
614618
val newPath = walkedTypePath.recordArray(clsName)
615-
createSerializerForMapObjects(input, ObjectType(classOf[String]),
619+
createSerializerForMapObjects(input, ObjectType(Class.forName(getClassNameFromType(elementType))),
616620
serializerFor(_, elementType, newPath, seenTypeSet))
617621

618622

@@ -718,6 +722,10 @@ object KotlinReflection extends KotlinReflection {
718722
case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => createSerializerForBoolean(inputObject)
719723
case t if isSubtype(t, localTypeOf[Boolean]) => createSerializerForBoolean(inputObject)
720724

725+
case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
726+
createSerializerForString(
727+
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
728+
721729
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
722730
val udt = getClassFromType(t)
723731
.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance()

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,9 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
858858
it.first.name to it.second.type!!
859859
}.toMap())
860860
return when {
861+
klass.isSubclassOf(Enum::class) -> {
862+
KSimpleTypeWrapper(DataTypes.StringType, klass.java, type.isMarkedNullable)
863+
}
861864
klass.isSubclassOf(Iterable::class) || klass.java.isArray -> {
862865
val listParam = if (klass.java.isArray) {
863866
when (klass) {

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,37 @@ class ApiTest : ShouldSpec({
473473
dataset.sort(SomeClass::a, SomeClass::b)
474474
dataset.takeAsList(1).first().b shouldBe 2
475475
}
476+
// Builds a one-row Dataset from a ComplexEnumDataClass, reads the row back, and
// compares every property with the value that was put in. The fixture uses enums
// as plain properties, as List elements, as Array elements, and as Map key/value.
should("Generate encoder correctly with complex enum data class") {
477+
val dataset: Dataset<ComplexEnumDataClass> =
478+
dsOf(
479+
ComplexEnumDataClass(
480+
1,
481+
"string",
482+
listOf("1", "2"),
483+
SomeEnum.A,
484+
SomeOtherEnum.C,
485+
listOf(SomeEnum.A, SomeEnum.B),
486+
listOf(SomeOtherEnum.C, SomeOtherEnum.D),
487+
arrayOf(SomeEnum.A, SomeEnum.B),
488+
arrayOf(SomeOtherEnum.C, SomeOtherEnum.D),
489+
mapOf(SomeEnum.A to SomeOtherEnum.C)
490+
)
491+
)
492+
493+
dataset.show(false)
494+
// Take the single row back out so each property can be checked individually.
val first = dataset.takeAsList(1).first()
495+
496+
first.int shouldBe 1
497+
first.string shouldBe "string"
498+
first.strings shouldBe listOf("1","2")
499+
first.someEnum shouldBe SomeEnum.A
500+
first.someOtherEnum shouldBe SomeOtherEnum.C
501+
first.someEnums shouldBe listOf(SomeEnum.A, SomeEnum.B)
502+
first.someOtherEnums shouldBe listOf(SomeOtherEnum.C, SomeOtherEnum.D)
503+
first.someEnumArray shouldBe arrayOf(SomeEnum.A, SomeEnum.B)
504+
first.someOtherArray shouldBe arrayOf(SomeOtherEnum.C, SomeOtherEnum.D)
505+
first.enumMap shouldBe mapOf(SomeEnum.A to SomeOtherEnum.C)
506+
}
476507
}
477508
}
478509
})
@@ -505,3 +536,20 @@ data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>) {
505536
data class SomeClass(val a: IntArray, val b: Int) : Serializable
506537

507538
data class SomeOtherClass(val a: IntArray, val b: Int, val c: Boolean) : Serializable
539+
540+
/** Test fixture: enum with two constants and no constructor parameters. */
enum class SomeEnum { A, B }
541+
542+
/** Test fixture: enum whose constants each carry an [Int] in [value]. */
enum class SomeOtherEnum(val value: Int) { C(1), D(2) }
543+
544+
/**
 * Test fixture mixing primitive/String properties with enum-typed properties.
 *
 * Enums appear as plain properties ([someEnum], [someOtherEnum]), as list elements
 * ([someEnums], [someOtherEnums]), as array elements ([someEnumArray], [someOtherArray]),
 * and as both key and value of a map ([enumMap]).
 *
 * NOTE(review): the two `Array` properties take part in the generated `equals`/`hashCode`
 * by reference, so two instances with equal array contents are not `==`; compare
 * properties individually when asserting on this class.
 */
data class ComplexEnumDataClass(
545+
val int: Int,
546+
val string: String,
547+
val strings: List<String>,
548+
val someEnum: SomeEnum,
549+
val someOtherEnum: SomeOtherEnum,
550+
val someEnums: List<SomeEnum>,
551+
val someOtherEnums: List<SomeOtherEnum>,
552+
val someEnumArray: Array<SomeEnum>,
553+
val someOtherArray: Array<SomeOtherEnum>,
554+
val enumMap: Map<SomeEnum, SomeOtherEnum>
555+
)

kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,9 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
851851
it.first.name to it.second.type!!
852852
}.toMap())
853853
return when {
854+
klass.isSubclassOf(Enum::class) -> {
855+
KSimpleTypeWrapper(DataTypes.StringType, klass.java, type.isMarkedNullable)
856+
}
854857
klass.isSubclassOf(Iterable::class) || klass.java.isArray -> {
855858
val listParam = if (klass.java.isArray) {
856859
when (klass) {

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,37 @@ class ApiTest : ShouldSpec({
515515
dataset.sort(SomeClass::a, SomeClass::b)
516516
dataset.takeAsList(1).first().b shouldBe 2
517517
}
518+
// Builds a one-row Dataset from a ComplexEnumDataClass, reads the row back, and
// compares every property with the value that was put in. The fixture uses enums
// as plain properties, as List elements, as Array elements, and as Map key/value.
should("Generate encoder correctly with complex enum data class") {
519+
val dataset: Dataset<ComplexEnumDataClass> =
520+
dsOf(
521+
ComplexEnumDataClass(
522+
1,
523+
"string",
524+
listOf("1", "2"),
525+
SomeEnum.A,
526+
SomeOtherEnum.C,
527+
listOf(SomeEnum.A, SomeEnum.B),
528+
listOf(SomeOtherEnum.C, SomeOtherEnum.D),
529+
arrayOf(SomeEnum.A, SomeEnum.B),
530+
arrayOf(SomeOtherEnum.C, SomeOtherEnum.D),
531+
mapOf(SomeEnum.A to SomeOtherEnum.C)
532+
)
533+
)
534+
535+
dataset.show(false)
536+
// Take the single row back out so each property can be checked individually.
val first = dataset.takeAsList(1).first()
537+
538+
first.int shouldBe 1
539+
first.string shouldBe "string"
540+
first.strings shouldBe listOf("1","2")
541+
first.someEnum shouldBe SomeEnum.A
542+
first.someOtherEnum shouldBe SomeOtherEnum.C
543+
first.someEnums shouldBe listOf(SomeEnum.A, SomeEnum.B)
544+
first.someOtherEnums shouldBe listOf(SomeOtherEnum.C, SomeOtherEnum.D)
545+
first.someEnumArray shouldBe arrayOf(SomeEnum.A, SomeEnum.B)
546+
first.someOtherArray shouldBe arrayOf(SomeOtherEnum.C, SomeOtherEnum.D)
547+
first.enumMap shouldBe mapOf(SomeEnum.A to SomeOtherEnum.C)
548+
}
518549
}
519550
}
520551
})
@@ -527,3 +558,20 @@ data class LonLat(val lon: Double, val lat: Double)
527558
data class SomeClass(val a: IntArray, val b: Int) : Serializable
528559

529560
data class SomeOtherClass(val a: IntArray, val b: Int, val c: Boolean) : Serializable
561+
562+
/** Test fixture: enum with two constants and no constructor parameters. */
enum class SomeEnum { A, B }
563+
564+
/** Test fixture: enum whose constants each carry an [Int] in [value]. */
enum class SomeOtherEnum(val value: Int) { C(1), D(2) }
565+
566+
/**
 * Test fixture mixing primitive/String properties with enum-typed properties.
 *
 * Enums appear as plain properties ([someEnum], [someOtherEnum]), as list elements
 * ([someEnums], [someOtherEnums]), as array elements ([someEnumArray], [someOtherArray]),
 * and as both key and value of a map ([enumMap]).
 *
 * NOTE(review): the two `Array` properties take part in the generated `equals`/`hashCode`
 * by reference, so two instances with equal array contents are not `==`; compare
 * properties individually when asserting on this class.
 */
data class ComplexEnumDataClass(
567+
val int: Int,
568+
val string: String,
569+
val strings: List<String>,
570+
val someEnum: SomeEnum,
571+
val someOtherEnum: SomeOtherEnum,
572+
val someEnums: List<SomeEnum>,
573+
val someOtherEnums: List<SomeOtherEnum>,
574+
val someEnumArray: Array<SomeEnum>,
575+
val someOtherArray: Array<SomeOtherEnum>,
576+
val enumMap: Map<SomeEnum, SomeOtherEnum>
577+
)

0 commit comments

Comments
 (0)