diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
index 14945dbe..57d7d682 100644
--- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
@@ -31,6 +31,9 @@ import org.apache.spark.sql.catalyst.JavaTypeInference
 import org.apache.spark.sql.catalyst.KotlinReflection
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.streaming.GroupState
+import org.apache.spark.sql.streaming.GroupStateTimeout
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.*
 import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
 import scala.collection.Seq
@@ -43,6 +46,7 @@ import java.time.Instant
 import java.time.LocalDate
 import java.util.concurrent.ConcurrentHashMap
 import kotlin.reflect.KClass
+import kotlin.reflect.KProperty
 import kotlin.reflect.KType
 import kotlin.reflect.full.findAnnotation
 import kotlin.reflect.full.isSubclassOf
@@ -179,6 +183,85 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
         reduceGroups(ReduceFunction(func))
                 .map { t -> t._1 to t._2 }
 
+/**
+ * Kotlin-lambda overload of `flatMapGroups`: wraps [func] in a [FlatMapGroupsFunction] and passes
+ * the encoder obtained from `encoder<U>()` for the result type.
+ */
+inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
+    noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
+): Dataset<U> = flatMapGroups(
+    FlatMapGroupsFunction(func),
+    encoder<U>()
+)
+
+/** Returns the current state value when [GroupState.exists] is true, otherwise `null`. */
+fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null
+
+/** Property-delegate getter (`val s by state`); reads the state through [getOrNull]. */
+operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()
+
+/**
+ * Property-delegate setter (`var s by state`).
+ *
+ * A non-null [value] is stored with `update`. Assigning `null` calls `remove()` instead, because
+ * `GroupState.update` rejects `null`; this mirrors the getter, which reports absent state as `null`.
+ */
+operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit =
+    if (value == null) remove() else update(value)
+
+
+/**
+ * Kotlin-lambda overload of `mapGroupsWithState`: wraps [func] in a [MapGroupsWithStateFunction] and passes
+ * the encoders obtained from `encoder<S>()` (state) and `encoder<U>()` (result).
+ */
+inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
+    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
+): Dataset<U> = mapGroupsWithState(
+    MapGroupsWithStateFunction(func),
+    encoder<S>(),
+    encoder<U>()
+)
+
+/** Same as the overload without [timeoutConf], additionally forwarding [timeoutConf] to the underlying call. */
+inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
+    timeoutConf: GroupStateTimeout,
+    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
+): Dataset<U> = mapGroupsWithState(
+    MapGroupsWithStateFunction(func),
+    encoder<S>(),
+    encoder<U>(),
+    timeoutConf
+)
+
+/**
+ * Kotlin-lambda overload of `flatMapGroupsWithState`: wraps [func] in a [FlatMapGroupsWithStateFunction] and
+ * forwards [outputMode], the state and result encoders, and [timeoutConf].
+ */
+inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
+): Dataset<U> = flatMapGroupsWithState(
+    FlatMapGroupsWithStateFunction(func),
+    outputMode,
+    encoder<S>(),
+    encoder<U>(),
+    timeoutConf
+)
+
+/**
+ * Kotlin-lambda overload of `cogroup`: wraps [func] in a [CoGroupFunction] and passes the encoder obtained
+ * from `encoder<R>()` for the result type.
+ */
+inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
+    other: KeyValueGroupedDataset<K, U>,
+    noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
+): Dataset<R> = cogroup(
+    other,
+    CoGroupFunction(func),
+    encoder<R>()
+)
+
 inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
 inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
 inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())
diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
index 3d6061f2..34e41482 100644
--- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
@@ -22,6 +22,9 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
 import ch.tutteli.atrium.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
+import org.apache.spark.sql.streaming.GroupState
+import org.apache.spark.sql.streaming.GroupStateTimeout
+import org.apache.spark.sql.streaming.OutputMode
 import scala.collection.Seq
 import org.apache.spark.sql.Dataset
 import java.io.Serializable
@@ -202,6 +205,98 @@ class ApiTest : ShouldSpec({
                 kotlinList.first() shouldBe "a"
                 kotlinList.last() shouldBe "b"
             }
+            should("perform flat map on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val flatMapped = groupedDataset.flatMapGroups { key, values ->
+                    val collected = values.asSequence().toList()
+
+                    // Only key 1 has more than one value, so its two pairs are the only rows emitted.
+                    if (collected.size > 1) collected.iterator()
+                    else emptyList<Pair<Int, String>>().iterator()
+                }
+
+                flatMapped.count() shouldBe 2
+            }
+            should("perform map group with state and timeout conf on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val mappedWithStateTimeoutConf =
+                    groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
+                        var s by state
+                        val collected = values.asSequence().toList()
+
+                        s = key
+                        s shouldBe key
+
+                        s!! to collected.map { it.second }
+                    }
+
+                mappedWithStateTimeoutConf.count() shouldBe 2
+            }
+            should("perform map group with state on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
+                    var s by state
+                    val collected = values.asSequence().toList()
+
+                    s = key
+                    s shouldBe key
+
+                    s!! to collected.map { it.second }
+                }
+
+                mappedWithState.count() shouldBe 2
+            }
+            should("perform flat map group with state on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                // Call flatMapGroupsWithState itself so that this test covers the flat-map wrapper.
+                val flatMappedWithState = groupedDataset.flatMapGroupsWithState(
+                    outputMode = OutputMode.Append(),
+                    timeoutConf = GroupStateTimeout.NoTimeout()
+                ) { key, values, state: GroupState<Int> ->
+                    var s by state
+                    val collected = values.asSequence().toList()
+
+                    s = key
+                    s shouldBe key
+
+                    if (collected.size > 1) collected.iterator()
+                    else emptyList<Pair<Int, String>>().iterator()
+                }
+
+                flatMappedWithState.count() shouldBe 2
+            }
+            should("be able to cogroup grouped datasets") {
+                val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right ->
+                    listOf(
+                        key to (left.asSequence() + right.asSequence())
+                            .map { it.second }
+                            .toList()
+                    ).iterator()
+                }
+
+                // One output row per distinct key across both sides: 1, 2, 3 and 5.
+                cogrouped.count() shouldBe 4
+            }
             should("be able to serialize Date 2.4") { // uses knownDataTypes
                 val dataset: Dataset<Pair<Date, Int>> = dsOf(Date.valueOf("2020-02-10") to 5)
                 dataset.show()
diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
index 5eeccff3..d5adc3bd 100644
--- a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
@@ -28,6 +28,9 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.*
 import org.apache.spark.sql.Encoders.*
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.streaming.GroupState
+import org.apache.spark.sql.streaming.GroupStateTimeout
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.*
 import org.jetbrains.kotinx.spark.extensions.KSparkExtensions
 import scala.reflect.ClassTag
@@ -39,6 +42,7 @@ import java.time.Instant
 import java.time.LocalDate
 import java.util.concurrent.ConcurrentHashMap
 import kotlin.reflect.KClass
+import kotlin.reflect.KProperty
 import kotlin.reflect.KType
 import kotlin.reflect.full.findAnnotation
 import kotlin.reflect.full.isSubclassOf
@@ -172,6 +176,85 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
         reduceGroups(ReduceFunction(func))
                 .map { t -> t._1 to t._2 }
 
+/**
+ * Kotlin-lambda overload of `flatMapGroups`: wraps [func] in a [FlatMapGroupsFunction] and passes
+ * the encoder obtained from `encoder<U>()` for the result type.
+ */
+inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
+    noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
+): Dataset<U> = flatMapGroups(
+    FlatMapGroupsFunction(func),
+    encoder<U>()
+)
+
+/** Returns the current state value when [GroupState.exists] is true, otherwise `null`. */
+fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null
+
+/** Property-delegate getter (`val s by state`); reads the state through [getOrNull]. */
+operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()
+
+/**
+ * Property-delegate setter (`var s by state`).
+ *
+ * A non-null [value] is stored with `update`. Assigning `null` calls `remove()` instead, because
+ * `GroupState.update` rejects `null`; this mirrors the getter, which reports absent state as `null`.
+ */
+operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit =
+    if (value == null) remove() else update(value)
+
+
+/**
+ * Kotlin-lambda overload of `mapGroupsWithState`: wraps [func] in a [MapGroupsWithStateFunction] and passes
+ * the encoders obtained from `encoder<S>()` (state) and `encoder<U>()` (result).
+ */
+inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
+    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
+): Dataset<U> = mapGroupsWithState(
+    MapGroupsWithStateFunction(func),
+    encoder<S>(),
+    encoder<U>()
+)
+
+/** Same as the overload without [timeoutConf], additionally forwarding [timeoutConf] to the underlying call. */
+inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
+    timeoutConf: GroupStateTimeout,
+    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
+): Dataset<U> = mapGroupsWithState(
+    MapGroupsWithStateFunction(func),
+    encoder<S>(),
+    encoder<U>(),
+    timeoutConf
+)
+
+/**
+ * Kotlin-lambda overload of `flatMapGroupsWithState`: wraps [func] in a [FlatMapGroupsWithStateFunction] and
+ * forwards [outputMode], the state and result encoders, and [timeoutConf].
+ */
+inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
+): Dataset<U> = flatMapGroupsWithState(
+    FlatMapGroupsWithStateFunction(func),
+    outputMode,
+    encoder<S>(),
+    encoder<U>(),
+    timeoutConf
+)
+
+/**
+ * Kotlin-lambda overload of `cogroup`: wraps [func] in a [CoGroupFunction] and passes the encoder obtained
+ * from `encoder<R>()` for the result type.
+ */
+inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
+    other: KeyValueGroupedDataset<K, U>,
+    noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
+): Dataset<R> = cogroup(
+    other,
+    CoGroupFunction(func),
+    encoder<R>()
+)
+
 inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
 inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
 inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())
diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
index 14fd8808..3522a68e 100644
--- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
+++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
@@ -22,6 +22,9 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
 import ch.tutteli.atrium.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
+import org.apache.spark.sql.streaming.GroupState
+import org.apache.spark.sql.streaming.GroupStateTimeout
+import org.apache.spark.sql.streaming.OutputMode
 import scala.collection.Seq
 import org.apache.spark.sql.Dataset
 import java.io.Serializable
@@ -216,6 +219,98 @@ class ApiTest : ShouldSpec({
                 kotlinList.first() shouldBe "a"
                 kotlinList.last() shouldBe "b"
             }
+            should("perform flat map on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val flatMapped = groupedDataset.flatMapGroups { key, values ->
+                    val collected = values.asSequence().toList()
+
+                    // Only key 1 has more than one value, so its two pairs are the only rows emitted.
+                    if (collected.size > 1) collected.iterator()
+                    else emptyList<Pair<Int, String>>().iterator()
+                }
+
+                flatMapped.count() shouldBe 2
+            }
+            should("perform map group with state and timeout conf on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val mappedWithStateTimeoutConf =
+                    groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
+                        var s by state
+                        val collected = values.asSequence().toList()
+
+                        s = key
+                        s shouldBe key
+
+                        s!! to collected.map { it.second }
+                    }
+
+                mappedWithStateTimeoutConf.count() shouldBe 2
+            }
+            should("perform map group with state on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
+                    var s by state
+                    val collected = values.asSequence().toList()
+
+                    s = key
+                    s shouldBe key
+
+                    s!! to collected.map { it.second }
+                }
+
+                mappedWithState.count() shouldBe 2
+            }
+            should("perform flat map group with state on grouped datasets") {
+                val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                // Call flatMapGroupsWithState itself so that this test covers the flat-map wrapper.
+                val flatMappedWithState = groupedDataset.flatMapGroupsWithState(
+                    outputMode = OutputMode.Append(),
+                    timeoutConf = GroupStateTimeout.NoTimeout()
+                ) { key, values, state: GroupState<Int> ->
+                    var s by state
+                    val collected = values.asSequence().toList()
+
+                    s = key
+                    s shouldBe key
+
+                    if (collected.size > 1) collected.iterator()
+                    else emptyList<Pair<Int, String>>().iterator()
+                }
+
+                flatMappedWithState.count() shouldBe 2
+            }
+            should("be able to cogroup grouped datasets") {
+                val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f")
+                    .toDS()
+                    .groupByKey { it.first }
+
+                val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right ->
+                    listOf(
+                        key to (left.asSequence() + right.asSequence())
+                            .map { it.second }
+                            .toList()
+                    ).iterator()
+                }
+
+                // One output row per distinct key across both sides: 1, 2, 3 and 5.
+                cogrouped.count() shouldBe 4
+            }
             should("handle LocalDate Datasets") { // uses encoder
                 val dataset: Dataset<LocalDate> = dsOf(LocalDate.now(), LocalDate.now())
                 dataset.show()