Skip to content

Commit 36d2253

Browse files
authored
feat: implements selectTyped functions
fixes #85
1 parent 1cfe42b commit 36d2253

File tree

4 files changed

+171
-3
lines changed
  • kotlin-spark-api
    • 2.4/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api
    • 3.0/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api

4 files changed

+171
-3
lines changed

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,49 @@ inline operator fun <reified T, reified U> Dataset<T>.invoke(column: KProperty1<
724724
*/
725725
fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply { show(numRows, truncate) }
726726

727+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The two selected values are wrapped in a Kotlin [Pair], so the result stays fully typed
 * instead of falling back to an untyped row.
 *
 * @param c1 typed column whose value becomes the first element of the pair
 * @param c2 typed column whose value becomes the second element of the pair
 * @return a Dataset with one [Pair] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
): Dataset<Pair<U1, U2>> =
    select(c1, c2).map { Pair(it._1(), it._2()) }
735+
736+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The three selected values are wrapped in a Kotlin [Triple], so the result stays fully typed.
 *
 * @param c1 typed column whose value becomes the first element of the triple
 * @param c2 typed column whose value becomes the second element of the triple
 * @param c3 typed column whose value becomes the third element of the triple
 * @return a Dataset with one [Triple] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2, reified U3> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
    c3: TypedColumn<T, U3>,
): Dataset<Triple<U1, U2, U3>> =
    select(c1, c2, c3).map { Triple(it._1(), it._2(), it._3()) }
745+
746+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The four selected values are wrapped in an [Arity4], so the result stays fully typed.
 *
 * @param c1 typed column whose value becomes `_1` of the result
 * @param c2 typed column whose value becomes `_2` of the result
 * @param c3 typed column whose value becomes `_3` of the result
 * @param c4 typed column whose value becomes `_4` of the result
 * @return a Dataset with one [Arity4] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2, reified U3, reified U4> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
    c3: TypedColumn<T, U3>,
    c4: TypedColumn<T, U4>,
): Dataset<Arity4<U1, U2, U3, U4>> =
    select(c1, c2, c3, c4).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
756+
757+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The five selected values are wrapped in an [Arity5], so the result stays fully typed.
 *
 * @param c1 typed column whose value becomes `_1` of the result
 * @param c2 typed column whose value becomes `_2` of the result
 * @param c3 typed column whose value becomes `_3` of the result
 * @param c4 typed column whose value becomes `_4` of the result
 * @param c5 typed column whose value becomes `_5` of the result
 * @return a Dataset with one [Arity5] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U5> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
    c3: TypedColumn<T, U3>,
    c4: TypedColumn<T, U4>,
    c5: TypedColumn<T, U5>,
): Dataset<Arity5<U1, U2, U3, U4, U5>> =
    select(c1, c2, c3, c4, c5).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }
768+
769+
727770
@OptIn(ExperimentalStdlibApi::class)
728771
inline fun <reified T> schema(map: Map<String, KType> = mapOf()) = schema(typeOf<T>(), map)
729772

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.streaming.GroupState
2626
import org.apache.spark.sql.streaming.GroupStateTimeout
2727
import scala.collection.Seq
2828
import org.apache.spark.sql.Dataset
29+
import org.apache.spark.sql.TypedColumn
2930
import org.apache.spark.sql.functions.*
3031
import scala.Product
3132
import scala.Tuple1
@@ -35,6 +36,7 @@ import java.io.Serializable
3536
import java.sql.Date
3637
import java.sql.Timestamp
3738
import java.time.LocalDate
39+
import kotlin.reflect.KProperty1
3840
import scala.collection.Iterator as ScalaIterator
3941
import scala.collection.Map as ScalaMap
4042
import scala.collection.mutable.Map as ScalaMutableMap
@@ -326,6 +328,46 @@ class ApiTest : ShouldSpec({
326328
val asList = dataset.takeAsList(2)
327329
asList.first().tuple shouldBe Tuple3(5L, "test", Tuple1(""))
328330
}
331+
@Suppress("UNCHECKED_CAST")
332+
should("support dataset select") {
333+
val dataset = dsOf(
334+
SomeClass(intArrayOf(1, 2, 3), 3),
335+
SomeClass(intArrayOf(1, 2, 4), 5),
336+
)
337+
338+
val typedColumnA: TypedColumn<Any, IntArray> = dataset.col("a").`as`(encoder())
339+
340+
val newDS2 = dataset.selectTyped(
341+
// col(SomeClass::a), // NOTE: this doesn't work on 2.4, returning a data class with an array in it
342+
col(SomeClass::b),
343+
col(SomeClass::b),
344+
)
345+
newDS2.show()
346+
347+
val newDS3 = dataset.selectTyped(
348+
col(SomeClass::b),
349+
col(SomeClass::b),
350+
col(SomeClass::b),
351+
)
352+
newDS3.show()
353+
354+
val newDS4 = dataset.selectTyped(
355+
col(SomeClass::b),
356+
col(SomeClass::b),
357+
col(SomeClass::b),
358+
col(SomeClass::b),
359+
)
360+
newDS4.show()
361+
362+
val newDS5 = dataset.selectTyped(
363+
col(SomeClass::b),
364+
col(SomeClass::b),
365+
col(SomeClass::b),
366+
col(SomeClass::b),
367+
col(SomeClass::b),
368+
)
369+
newDS5.show()
370+
}
329371
should("Access columns using invoke on datasets") {
330372
val dataset = dsOf(
331373
SomeClass(intArrayOf(1, 2, 3), 4),
@@ -399,6 +441,7 @@ class ApiTest : ShouldSpec({
399441
}
400442
})
401443

444+
402445
data class DataClassWithTuple<T : Product>(val tuple: T)
403446

404447

kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,49 @@ inline operator fun <reified T, reified U> Dataset<T>.invoke(column: KProperty1<
720720
*/
721721
fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply { show(numRows, truncate) }
722722

723+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The two selected values are wrapped in a Kotlin [Pair], so the result stays fully typed
 * instead of falling back to an untyped row.
 *
 * @param c1 typed column whose value becomes the first element of the pair
 * @param c2 typed column whose value becomes the second element of the pair
 * @return a Dataset with one [Pair] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
): Dataset<Pair<U1, U2>> =
    select(c1, c2).map { Pair(it._1(), it._2()) }
731+
732+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The three selected values are wrapped in a Kotlin [Triple], so the result stays fully typed.
 *
 * @param c1 typed column whose value becomes the first element of the triple
 * @param c2 typed column whose value becomes the second element of the triple
 * @param c3 typed column whose value becomes the third element of the triple
 * @return a Dataset with one [Triple] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2, reified U3> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
    c3: TypedColumn<T, U3>,
): Dataset<Triple<U1, U2, U3>> =
    select(c1, c2, c3).map { Triple(it._1(), it._2(), it._3()) }
741+
742+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The four selected values are wrapped in an [Arity4], so the result stays fully typed.
 *
 * @param c1 typed column whose value becomes `_1` of the result
 * @param c2 typed column whose value becomes `_2` of the result
 * @param c3 typed column whose value becomes `_3` of the result
 * @param c4 typed column whose value becomes `_4` of the result
 * @return a Dataset with one [Arity4] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2, reified U3, reified U4> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
    c3: TypedColumn<T, U3>,
    c4: TypedColumn<T, U4>,
): Dataset<Arity4<U1, U2, U3, U4>> =
    select(c1, c2, c3, c4).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
752+
753+
/**
 * Returns a new Dataset by computing the given typed column expressions for each element.
 *
 * The five selected values are wrapped in an [Arity5], so the result stays fully typed.
 *
 * @param c1 typed column whose value becomes `_1` of the result
 * @param c2 typed column whose value becomes `_2` of the result
 * @param c3 typed column whose value becomes `_3` of the result
 * @param c4 typed column whose value becomes `_4` of the result
 * @param c5 typed column whose value becomes `_5` of the result
 * @return a Dataset with one [Arity5] of the computed values per row
 */
inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U5> Dataset<T>.selectTyped(
    c1: TypedColumn<T, U1>,
    c2: TypedColumn<T, U2>,
    c3: TypedColumn<T, U3>,
    c4: TypedColumn<T, U4>,
    c5: TypedColumn<T, U5>,
): Dataset<Arity5<U1, U2, U3, U4, U5>> =
    select(c1, c2, c3, c4, c5).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }
764+
765+
723766
@OptIn(ExperimentalStdlibApi::class)
724767
fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
725768
val primitiveSchema = knownDataTypes[type.classifier]

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
2222
import ch.tutteli.atrium.verbs.expect
2323
import io.kotest.core.spec.style.ShouldSpec
2424
import io.kotest.matchers.shouldBe
25+
import org.apache.spark.sql.streaming.GroupState
26+
import org.apache.spark.sql.streaming.GroupStateTimeout
27+
import scala.Product
2528
import scala.Tuple1
2629
import scala.Tuple2
2730
import scala.Tuple3
28-
import org.apache.spark.sql.streaming.GroupState
29-
import org.apache.spark.sql.streaming.GroupStateTimeout
3031
import scala.collection.Seq
3132
import org.apache.spark.sql.Dataset
3233
import org.apache.spark.sql.TypedColumn
3334
import org.apache.spark.sql.functions.*
34-
import scala.Product
3535
import java.io.Serializable
3636
import java.sql.Date
3737
import java.sql.Timestamp
@@ -350,6 +350,45 @@ class ApiTest : ShouldSpec({
350350
val asList = dataset.takeAsList(2)
351351
asList.first().tuple shouldBe Tuple3(5L, "test", Tuple1(""))
352352
}
353+
@Suppress("UNCHECKED_CAST")
354+
should("support dataset select") {
355+
val dataset = dsOf(
356+
SomeClass(intArrayOf(1, 2, 3), 3),
357+
SomeClass(intArrayOf(1, 2, 4), 5),
358+
)
359+
360+
val typedColumnA: TypedColumn<Any, IntArray> = dataset.col("a").`as`(encoder())
361+
362+
val newDS2 = dataset.selectTyped(
363+
col(SomeClass::a), // NOTE: this only works on 3.0, returning a data class with an array in it
364+
col(SomeClass::b),
365+
)
366+
newDS2.show()
367+
368+
val newDS3 = dataset.selectTyped(
369+
col(SomeClass::a),
370+
col(SomeClass::b),
371+
col(SomeClass::b),
372+
)
373+
newDS3.show()
374+
375+
val newDS4 = dataset.selectTyped(
376+
col(SomeClass::a),
377+
col(SomeClass::b),
378+
col(SomeClass::b),
379+
col(SomeClass::b),
380+
)
381+
newDS4.show()
382+
383+
val newDS5 = dataset.selectTyped(
384+
col(SomeClass::a),
385+
col(SomeClass::b),
386+
col(SomeClass::b),
387+
col(SomeClass::b),
388+
col(SomeClass::b),
389+
)
390+
newDS5.show()
391+
}
353392
should("Access columns using invoke on datasets") {
354393
val dataset = dsOf(
355394
SomeClass(intArrayOf(1, 2, 3), 4),

0 commit comments

Comments
 (0)