Skip to content

Commit aa6d3e5

Browse files
authored
Merge pull request #147 from JetBrains/scala-case-class-encoding
Rewrote product encoding to support Scala case classes
2 parents d62e3af + 2a409c6 commit aa6d3e5

File tree

4 files changed

+112
-9
lines changed

4 files changed

+112
-9
lines changed

core/3.2/src/main/scala/org/apache/spark/sql/KotlinReflection.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
2828
import org.apache.spark.sql.catalyst.expressions.objects._
2929
import org.apache.spark.sql.catalyst.expressions.{Expression, _}
3030
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
31-
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePath}
31+
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, InternalRow, ScalaReflection, WalkedTypePath}
3232
import org.apache.spark.sql.types._
3333
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
3434
import org.apache.spark.util.Utils
@@ -42,11 +42,12 @@ import java.lang.Exception
4242
* for classes whose fields are entirely defined by constructor params but should not be
4343
* case classes.
4444
*/
45-
trait DefinedByConstructorParams
45+
//trait DefinedByConstructorParams
4646

4747
/**
4848
* KotlinReflection is heavily inspired by ScalaReflection and even extends it just to add several methods
4949
*/
50+
//noinspection RedundantBlock
5051
object KotlinReflection extends KotlinReflection {
5152
/**
5253
* Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
@@ -916,9 +917,18 @@ object KotlinReflection extends KotlinReflection {
916917
}
917918
//</editor-fold>
918919

919-
case _ if predefinedDt.isDefined => {
920+
// Kotlin specific cases
921+
case t if predefinedDt.isDefined => {
922+
923+
// if (seenTypeSet.contains(t)) {
924+
// throw new UnsupportedOperationException(
925+
// s"cannot have circular references in class, but got the circular reference of class $t"
926+
// )
927+
// }
928+
920929
predefinedDt.get match {
921930

931+
// Kotlin data class
922932
case dataType: KDataTypeWrapper => {
923933
val cls = dataType.cls
924934
val properties = getJavaBeanReadableProperties(cls)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package org.jetbrains.kotlinx.spark.extensions
2+
3+
case class DemoCaseClass[T](a: Int, b: T)

kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,18 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
271271
KDataTypeWrapper(structType, klass.java, true)
272272
}
273273
klass.isSubclassOf(Product::class) -> {
274-
val params = type.arguments.mapIndexed { i, it ->
275-
"_${i + 1}" to it.type!!
274+
275+
// create map from T1, T2 to Int, String etc.
276+
val typeMap = klass.constructors.first().typeParameters.map { it.name }
277+
.zip(
278+
type.arguments.map { it.type }
279+
)
280+
.toMap()
281+
282+
// collect params by name and actual type
283+
val params = klass.constructors.first().parameters.map {
284+
val typeName = it.type.toString().replace("!", "")
285+
it.name to (typeMap[typeName] ?: it.type)
276286
}
277287

278288
val structType = DataTypes.createStructType(

kotlin-spark-api/3.2/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ import org.apache.spark.sql.Dataset
2727
import org.apache.spark.sql.types.Decimal
2828
import org.apache.spark.unsafe.types.CalendarInterval
2929
import org.jetbrains.kotlinx.spark.api.tuples.*
30-
import scala.Product
31-
import scala.Tuple1
32-
import scala.Tuple2
33-
import scala.Tuple3
30+
import org.jetbrains.kotlinx.spark.extensions.DemoCaseClass
31+
import scala.*
3432
import java.math.BigDecimal
3533
import java.sql.Date
3634
import java.sql.Timestamp
@@ -180,6 +178,88 @@ class EncodingTest : ShouldSpec({
180178
context("schema") {
181179
withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
182180

181+
should("handle Scala Case class datasets") {
182+
val caseClasses = listOf(
183+
DemoCaseClass(1, "1"),
184+
DemoCaseClass(2, "2"),
185+
DemoCaseClass(3, "3"),
186+
)
187+
val dataset = caseClasses.toDS()
188+
dataset.show()
189+
dataset.collectAsList() shouldBe caseClasses
190+
}
191+
192+
should("handle Scala Case class with data class datasets") {
193+
val caseClasses = listOf(
194+
DemoCaseClass(1, "1" to 1L),
195+
DemoCaseClass(2, "2" to 2L),
196+
DemoCaseClass(3, "3" to 3L),
197+
)
198+
val dataset = caseClasses.toDS()
199+
dataset.show()
200+
dataset.collectAsList() shouldBe caseClasses
201+
}
202+
203+
should("handle data class with Scala Case class datasets") {
204+
val caseClasses = listOf(
205+
1 to DemoCaseClass(1, "1"),
206+
2 to DemoCaseClass(2, "2"),
207+
3 to DemoCaseClass(3, "3"),
208+
)
209+
val dataset = caseClasses.toDS()
210+
dataset.show()
211+
dataset.collectAsList() shouldBe caseClasses
212+
}
213+
214+
should("handle data class with Scala Case class & deeper datasets") {
215+
val caseClasses = listOf(
216+
1 to DemoCaseClass(1, "1" to DemoCaseClass(1, 1.0)),
217+
2 to DemoCaseClass(2, "2" to DemoCaseClass(2, 2.0)),
218+
3 to DemoCaseClass(3, "3" to DemoCaseClass(3, 3.0)),
219+
)
220+
val dataset = caseClasses.toDS()
221+
dataset.show()
222+
dataset.collectAsList() shouldBe caseClasses
223+
}
224+
225+
226+
xshould("handle Scala Option datasets") {
227+
val caseClasses = listOf(Some(1), Some(2), Some(3))
228+
val dataset = caseClasses.toDS()
229+
dataset.show()
230+
dataset.collectAsList() shouldBe caseClasses
231+
}
232+
233+
xshould("handle Scala Option Option datasets") {
234+
val caseClasses = listOf(
235+
Some(Some(1)),
236+
Some(Some(2)),
237+
Some(Some(3)),
238+
)
239+
val dataset = caseClasses.toDS()
240+
dataset.collectAsList() shouldBe caseClasses
241+
}
242+
243+
xshould("handle data class Scala Option datasets") {
244+
val caseClasses = listOf(
245+
Some(1) to Some(2),
246+
Some(3) to Some(4),
247+
Some(5) to Some(6),
248+
)
249+
val dataset = caseClasses.toDS()
250+
dataset.collectAsList() shouldBe caseClasses
251+
}
252+
253+
xshould("handle Scala Option data class datasets") {
254+
val caseClasses = listOf(
255+
Some(1 to 2),
256+
Some(3 to 4),
257+
Some(5 to 6),
258+
)
259+
val dataset = caseClasses.toDS()
260+
dataset.collectAsList() shouldBe caseClasses
261+
}
262+
183263
should("collect data classes with doubles correctly") {
184264
val ll1 = LonLat(1.0, 2.0)
185265
val ll2 = LonLat(3.0, 4.0)

0 commit comments

Comments
 (0)