Rewrote product encoding to support Scala case classes #147


Merged: 4 commits, Apr 19, 2022

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.expressions.{Expression, _}
 import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
-import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePath}
+import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, InternalRow, ScalaReflection, WalkedTypePath}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
@@ -42,11 +42,12 @@ import java.lang.Exception
  * for classes whose fields are entirely defined by constructor params but should not be
  * case classes.
  */
-trait DefinedByConstructorParams
+//trait DefinedByConstructorParams

 /**
  * KotlinReflection is heavily inspired by ScalaReflection and even extends it just to add several methods
  */
+//noinspection RedundantBlock
 object KotlinReflection extends KotlinReflection {
   /**
    * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
@@ -916,9 +917,18 @@ object KotlinReflection extends KotlinReflection {
       }
       //</editor-fold>

-      case _ if predefinedDt.isDefined => {
+      // Kotlin specific cases
+      case t if predefinedDt.isDefined => {
+
+        // if (seenTypeSet.contains(t)) {
+        //   throw new UnsupportedOperationException(
+        //     s"cannot have circular references in class, but got the circular reference of class $t"
+        //   )
+        // }
+
         predefinedDt.get match {
+
+          // Kotlin data class
           case dataType: KDataTypeWrapper => {
             val cls = dataType.cls
             val properties = getJavaBeanReadableProperties(cls)

@@ -0,0 +1,3 @@
+package org.jetbrains.kotlinx.spark.extensions
+
+case class DemoCaseClass[T](a: Int, b: T)
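
The new DemoCaseClass is a minimal generic Scala case class used as a test fixture. As a rough sketch of what the rewritten product encoding enables (assuming a standard kotlin-spark-api setup; withSpark and toDS are the library's session builder and Dataset helper), a Scala case class can now be round-tripped from Kotlin much like a data class:

import org.jetbrains.kotlinx.spark.api.*
import org.jetbrains.kotlinx.spark.extensions.DemoCaseClass

fun main() = withSpark {
    // A Scala case class is a Product, so the product encoder
    // rewritten in this PR can derive its schema from Kotlin.
    val ds = listOf(
        DemoCaseClass(1, "a"),
        DemoCaseClass(2, "b"),
    ).toDS()
    ds.show()
}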

@@ -271,8 +271,18 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
             KDataTypeWrapper(structType, klass.java, true)
         }
         klass.isSubclassOf(Product::class) -> {
-            val params = type.arguments.mapIndexed { i, it ->
-                "_${i + 1}" to it.type!!
+
+            // create map from T1, T2 to Int, String etc.
+            val typeMap = klass.constructors.first().typeParameters.map { it.name }
+                .zip(
+                    type.arguments.map { it.type }
+                )
+                .toMap()
+
+            // collect params by name and actual type
+            val params = klass.constructors.first().parameters.map {
+                val typeName = it.type.toString().replace("!", "")
+                it.name to (typeMap[typeName] ?: it.type)
             }

             val structType = DataTypes.createStructType(
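
To see the resolution step above in isolation: the class's declared type parameters (T1, T2 on a Scala tuple, or T on DemoCaseClass) are zipped with the actual type arguments carried by the KType, and each constructor parameter is then looked up in that map, falling back to its own declared type when it is already concrete. Below is a minimal, self-contained sketch of the same idea; Demo is a hypothetical stand-in for a generic Product, and it reads KClass.typeParameters where the PR reads them off the constructor:

import kotlin.reflect.KType
import kotlin.reflect.KTypeProjection
import kotlin.reflect.full.createType
import kotlin.reflect.full.primaryConstructor

// Hypothetical stand-in for a generic Product such as DemoCaseClass[T].
data class Demo<T>(val a: Int, val b: T)

fun main() {
    // The KType for Demo<String>, as schema() would receive it.
    val type: KType = Demo::class.createType(
        listOf(KTypeProjection.invariant(String::class.createType()))
    )

    // Map declared type parameter names to actual type arguments,
    // e.g. "T" -> kotlin.String.
    val typeMap: Map<String, KType?> = Demo::class.typeParameters
        .map { it.name }
        .zip(type.arguments.map { it.type })
        .toMap()

    // Resolve each constructor parameter: generic ones go through the
    // map, concrete ones keep their declared type. The replace("!", "")
    // mirrors the PR's stripping of platform-type markers.
    val params = Demo::class.primaryConstructor!!.parameters.map {
        val typeName = it.type.toString().replace("!", "")
        it.name to (typeMap[typeName] ?: it.type)
    }

    println(params) // [(a, kotlin.Int), (b, kotlin.String)]
}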

@@ -27,10 +27,8 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.jetbrains.kotlinx.spark.api.tuples.*
-import scala.Product
-import scala.Tuple1
-import scala.Tuple2
-import scala.Tuple3
+import org.jetbrains.kotlinx.spark.extensions.DemoCaseClass
+import scala.*
 import java.math.BigDecimal
 import java.sql.Date
 import java.sql.Timestamp
@@ -180,6 +178,88 @@ class EncodingTest : ShouldSpec({
     context("schema") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {

+            should("handle Scala Case class datasets") {
+                val caseClasses = listOf(
+                    DemoCaseClass(1, "1"),
+                    DemoCaseClass(2, "2"),
+                    DemoCaseClass(3, "3"),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle Scala Case class with data class datasets") {
+                val caseClasses = listOf(
+                    DemoCaseClass(1, "1" to 1L),
+                    DemoCaseClass(2, "2" to 2L),
+                    DemoCaseClass(3, "3" to 3L),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle data class with Scala Case class datasets") {
+                val caseClasses = listOf(
+                    1 to DemoCaseClass(1, "1"),
+                    2 to DemoCaseClass(2, "2"),
+                    3 to DemoCaseClass(3, "3"),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle data class with Scala Case class & deeper datasets") {
+                val caseClasses = listOf(
+                    1 to DemoCaseClass(1, "1" to DemoCaseClass(1, 1.0)),
+                    2 to DemoCaseClass(2, "2" to DemoCaseClass(2, 2.0)),
+                    3 to DemoCaseClass(3, "3" to DemoCaseClass(3, 3.0)),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+
+            xshould("handle Scala Option datasets") {
+                val caseClasses = listOf(Some(1), Some(2), Some(3))
+                val dataset = caseClasses.toDS()
+                dataset.show()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            xshould("handle Scala Option Option datasets") {
+                val caseClasses = listOf(
+                    Some(Some(1)),
+                    Some(Some(2)),
+                    Some(Some(3)),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            xshould("handle data class Scala Option datasets") {
+                val caseClasses = listOf(
+                    Some(1) to Some(2),
+                    Some(3) to Some(4),
+                    Some(5) to Some(6),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            xshould("handle Scala Option data class datasets") {
+                val caseClasses = listOf(
+                    Some(1 to 2),
+                    Some(3 to 4),
+                    Some(5 to 6),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
             should("collect data classes with doubles correctly") {
                 val ll1 = LonLat(1.0, 2.0)
                 val ll2 = LonLat(3.0, 4.0)