diff --git a/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala b/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
index 2eb3d974..cf89bc3e 100644
--- a/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
+++ b/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
@@ -336,6 +336,10 @@ object KotlinReflection extends KotlinReflection {
           mirror.runtimeClass(t.typeSymbol.asClass)
         )
 
+      case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
+        createDeserializerForTypesSupportValueOf(
+          createDeserializerForString(path, returnNullable = false), Class.forName(t.toString))
+
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
         val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
           getConstructor().newInstance()
@@ -612,7 +616,7 @@ object KotlinReflection extends KotlinReflection {
           case _: StringType =>
             val clsName = getClassNameFromType(typeOf[String])
             val newPath = walkedTypePath.recordArray(clsName)
-            createSerializerForMapObjects(input, ObjectType(classOf[String]),
+            createSerializerForMapObjects(input, ObjectType(Class.forName(getClassNameFromType(elementType))),
               serializerFor(_, elementType, newPath, seenTypeSet))
 
 
@@ -718,6 +722,10 @@ object KotlinReflection extends KotlinReflection {
       case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => createSerializerForBoolean(inputObject)
       case t if isSubtype(t, localTypeOf[Boolean]) => createSerializerForBoolean(inputObject)
 
+      case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
+        createSerializerForString(
+          Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
+
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
         val udt = getClassFromType(t)
           .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance()
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
index 20fbfea0..4bc4811b 100644
--- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
@@ -858,6 +858,9 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
         it.first.name to it.second.type!!
     }.toMap())
     return when {
+        klass.isSubclassOf(Enum::class) -> {
+            KSimpleTypeWrapper(DataTypes.StringType, klass.java, type.isMarkedNullable)
+        }
         klass.isSubclassOf(Iterable::class) || klass.java.isArray -> {
             val listParam = if (klass.java.isArray) {
                 when (klass) {
diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
index 73d3c44d..8da89fb9 100644
--- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
@@ -473,6 +473,37 @@ class ApiTest : ShouldSpec({
             dataset.sort(SomeClass::a, SomeClass::b)
             dataset.takeAsList(1).first().b shouldBe 2
         }
+        should("Generate encoder correctly with complex enum data class") {
+            val dataset: Dataset<ComplexEnumDataClass> =
+                dsOf(
+                    ComplexEnumDataClass(
+                        1,
+                        "string",
+                        listOf("1", "2"),
+                        SomeEnum.A,
+                        SomeOtherEnum.C,
+                        listOf(SomeEnum.A, SomeEnum.B),
+                        listOf(SomeOtherEnum.C, SomeOtherEnum.D),
+                        arrayOf(SomeEnum.A, SomeEnum.B),
+                        arrayOf(SomeOtherEnum.C, SomeOtherEnum.D),
+                        mapOf(SomeEnum.A to SomeOtherEnum.C)
+                    )
+                )
+
+            dataset.show(false)
+            val first = dataset.takeAsList(1).first()
+
+            first.int shouldBe 1
+            first.string shouldBe "string"
+            first.strings shouldBe listOf("1","2")
+            first.someEnum shouldBe SomeEnum.A
+            first.someOtherEnum shouldBe SomeOtherEnum.C
+            first.someEnums shouldBe listOf(SomeEnum.A, SomeEnum.B)
+            first.someOtherEnums shouldBe listOf(SomeOtherEnum.C, SomeOtherEnum.D)
+            first.someEnumArray shouldBe arrayOf(SomeEnum.A, SomeEnum.B)
+            first.someOtherArray shouldBe arrayOf(SomeOtherEnum.C, SomeOtherEnum.D)
+            first.enumMap shouldBe mapOf(SomeEnum.A to SomeOtherEnum.C)
+        }
     }
 }
 })
@@ -505,3 +536,20 @@ data class Test(val id: Long, val data: Array<Pair<String, Int>>) {
 data class SomeClass(val a: IntArray, val b: Int) : Serializable
 
 data class SomeOtherClass(val a: IntArray, val b: Int, val c: Boolean) : Serializable
+
+enum class SomeEnum { A, B }
+
+enum class SomeOtherEnum(val value: Int) { C(1), D(2) }
+
+data class ComplexEnumDataClass(
+    val int: Int,
+    val string: String,
+    val strings: List<String>,
+    val someEnum: SomeEnum,
+    val someOtherEnum: SomeOtherEnum,
+    val someEnums: List<SomeEnum>,
+    val someOtherEnums: List<SomeOtherEnum>,
+    val someEnumArray: Array<SomeEnum>,
+    val someOtherArray: Array<SomeOtherEnum>,
+    val enumMap: Map<SomeEnum, SomeOtherEnum>
+)
diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
index 5735fe06..f568c4c4 100644
--- a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
@@ -851,6 +851,9 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
         it.first.name to it.second.type!!
     }.toMap())
     return when {
+        klass.isSubclassOf(Enum::class) -> {
+            KSimpleTypeWrapper(DataTypes.StringType, klass.java, type.isMarkedNullable)
+        }
         klass.isSubclassOf(Iterable::class) || klass.java.isArray -> {
             val listParam = if (klass.java.isArray) {
                 when (klass) {
diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
index 43a67f2a..36cb35fc 100644
--- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
+++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
@@ -515,6 +515,37 @@ class ApiTest : ShouldSpec({
             dataset.sort(SomeClass::a, SomeClass::b)
             dataset.takeAsList(1).first().b shouldBe 2
         }
+        should("Generate encoder correctly with complex enum data class") {
+            val dataset: Dataset<ComplexEnumDataClass> =
+                dsOf(
+                    ComplexEnumDataClass(
+                        1,
+                        "string",
+                        listOf("1", "2"),
+                        SomeEnum.A,
+                        SomeOtherEnum.C,
+                        listOf(SomeEnum.A, SomeEnum.B),
+                        listOf(SomeOtherEnum.C, SomeOtherEnum.D),
+                        arrayOf(SomeEnum.A, SomeEnum.B),
+                        arrayOf(SomeOtherEnum.C, SomeOtherEnum.D),
+                        mapOf(SomeEnum.A to SomeOtherEnum.C)
+                    )
+                )
+
+            dataset.show(false)
+            val first = dataset.takeAsList(1).first()
+
+            first.int shouldBe 1
+            first.string shouldBe "string"
+            first.strings shouldBe listOf("1","2")
+            first.someEnum shouldBe SomeEnum.A
+            first.someOtherEnum shouldBe SomeOtherEnum.C
+            first.someEnums shouldBe listOf(SomeEnum.A, SomeEnum.B)
+            first.someOtherEnums shouldBe listOf(SomeOtherEnum.C, SomeOtherEnum.D)
+            first.someEnumArray shouldBe arrayOf(SomeEnum.A, SomeEnum.B)
+            first.someOtherArray shouldBe arrayOf(SomeOtherEnum.C, SomeOtherEnum.D)
+            first.enumMap shouldBe mapOf(SomeEnum.A to SomeOtherEnum.C)
+        }
     }
 }
 })
@@ -527,3 +558,20 @@ data class LonLat(val lon: Double, val lat: Double)
 data class SomeClass(val a: IntArray, val b: Int) : Serializable
 
 data class SomeOtherClass(val a: IntArray, val b: Int, val c: Boolean) : Serializable
+
+enum class SomeEnum { A, B }
+
+enum class SomeOtherEnum(val value: Int) { C(1), D(2) }
+
+data class ComplexEnumDataClass(
+    val int: Int,
+    val string: String,
+    val strings: List<String>,
+    val someEnum: SomeEnum,
+    val someOtherEnum: SomeOtherEnum,
+    val someEnums: List<SomeEnum>,
+    val someOtherEnums: List<SomeOtherEnum>,
+    val someEnumArray: Array<SomeEnum>,
+    val someOtherArray: Array<SomeOtherEnum>,
+    val enumMap: Map<SomeEnum, SomeOtherEnum>
+)
\ No newline at end of file