From a3093037834596a637b4b9647c73a5371f4c87dc Mon Sep 17 00:00:00 2001 From: Jolanrensen Date: Fri, 26 Feb 2021 00:57:10 +0100 Subject: [PATCH 1/2] added missing function wrappers for KeyValueGroupedDataset and cogroup including some helper functions for GroupState. added tests --- .../org/jetbrains/kotlinx/spark/api/ApiV1.kt | 56 ++++++++++++++ .../jetbrains/kotlinx/spark/api/ApiTest.kt | 73 +++++++++++++++++++ .../org/jetbrains/kotlinx/spark/api/ApiV1.kt | 56 ++++++++++++++ .../jetbrains/kotlinx/spark/api/ApiTest.kt | 73 +++++++++++++++++++ 4 files changed, 258 insertions(+) diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt index 4dfa0e02..2fba11d8 100644 --- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt +++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt @@ -30,6 +30,9 @@ import org.apache.spark.sql.catalyst.JavaTypeInference import org.apache.spark.sql.catalyst.KotlinReflection import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.streaming.GroupState +import org.apache.spark.sql.streaming.GroupStateTimeout +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.* import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions import scala.collection.Seq @@ -42,6 +45,7 @@ import java.time.Instant import java.time.LocalDate import java.util.concurrent.ConcurrentHashMap import kotlin.reflect.KClass +import kotlin.reflect.KProperty import kotlin.reflect.KType import kotlin.reflect.full.findAnnotation import kotlin.reflect.full.isSubclassOf @@ -159,6 +163,58 @@ inline fun KeyValueGroupedDataset.reduc reduceGroups(ReduceFunction(func)) .map { t -> t._1 to t._2 } +inline fun KeyValueGroupedDataset.flatMapGroups( + noinline func: (key: K, values: Iterator) -> Iterator +): Dataset = flatMapGroups( + FlatMapGroupsFunction(func), + encoder() +) + +fun GroupState.getOrNull(): S? = if (exists()) get() else null + +operator fun GroupState.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull() +operator fun GroupState.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit = update(value) + + +inline fun KeyValueGroupedDataset.mapGroupsWithState( + noinline func: (key: K, values: Iterator, state: GroupState) -> U +): Dataset = mapGroupsWithState( + MapGroupsWithStateFunction(func), + encoder(), + encoder() +) + +inline fun KeyValueGroupedDataset.mapGroupsWithState( + timeoutConf: GroupStateTimeout, + noinline func: (key: K, values: Iterator, state: GroupState) -> U +): Dataset = mapGroupsWithState( + MapGroupsWithStateFunction(func), + encoder(), + encoder(), + timeoutConf +) + +inline fun KeyValueGroupedDataset.flatMapGroupsWithState( + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + noinline func: (key: K, values: Iterator, state: GroupState) -> Iterator +): Dataset = flatMapGroupsWithState( + FlatMapGroupsWithStateFunction(func), + outputMode, + encoder(), + encoder(), + timeoutConf +) + +inline fun KeyValueGroupedDataset.cogroup( + other: KeyValueGroupedDataset, + noinline func: (key: K, left: Iterator, right: Iterator) -> Iterator +): Dataset = cogroup( + other, + CoGroupFunction(func), + encoder() +) + inline fun Dataset.downcast(): Dataset = `as`(encoder()) inline fun Dataset<*>.`as`(): Dataset = `as`(encoder()) inline fun Dataset<*>.to(): Dataset = `as`(encoder()) diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt index 74c30637..5161c38a 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt @@ -21,6 +21,9 @@ import ch.tutteli.atrium.api.fluent.en_GB.* import ch.tutteli.atrium.domain.builders.migration.asExpect import ch.tutteli.atrium.verbs.expect import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.shouldBe +import org.apache.spark.sql.streaming.GroupState +import org.apache.spark.sql.streaming.GroupStateTimeout import java.io.Serializable import java.time.LocalDate @@ -156,6 +159,76 @@ class ApiTest : ShouldSpec({ expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11) } + should("perform operations on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } + + val flatMapped = groupedDataset.flatMapGroups { key, values -> + val collected = values.asSequence().toList() + + if (collected.size > 1) collected.iterator() + else emptyList>().iterator() + } + + flatMapped.count() shouldBe 2 + + val mappedWithStateTimeoutConf = groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() + + s = key + s shouldBe key + + s!! to collected.map { it.second } + } + + mappedWithStateTimeoutConf.count() shouldBe 2 + + val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() + + s = key + s shouldBe key + + s!! to collected.map { it.second } + } + + mappedWithState.count() shouldBe 2 + + val flatMappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() + + s = key + s shouldBe key + + if (collected.size > 1) collected.iterator() + else emptyList>().iterator() + } + + flatMappedWithState.count() shouldBe 2 + } + should("be able to cogroup grouped datasets") { + val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } + + val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f") + .toDS() + .groupByKey { it.first } + + val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right -> + listOf( + key to (left.asSequence() + right.asSequence()) + .map { it.second } + .toList() + ).iterator() + } + + cogrouped.count() shouldBe 4 + } } } }) diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt index 63c6bd54..d4a19599 100644 --- a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt +++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt @@ -27,6 +27,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.* import org.apache.spark.sql.Encoders.* import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.streaming.GroupState +import org.apache.spark.sql.streaming.GroupStateTimeout +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.* import org.jetbrains.kotinx.spark.extensions.KSparkExtensions import scala.reflect.ClassTag @@ -38,6 +41,7 @@ import java.time.Instant import java.time.LocalDate import java.util.concurrent.ConcurrentHashMap import kotlin.reflect.KClass +import kotlin.reflect.KProperty import kotlin.reflect.KType import kotlin.reflect.full.findAnnotation import kotlin.reflect.full.isSubclassOf @@ -149,6 +153,58 @@ inline fun KeyValueGroupedDataset.reduc reduceGroups(ReduceFunction(func)) .map { t -> t._1 to t._2 } +inline fun KeyValueGroupedDataset.flatMapGroups( + noinline func: (key: K, values: Iterator) -> Iterator +): Dataset = flatMapGroups( + FlatMapGroupsFunction(func), + encoder() +) + +fun GroupState.getOrNull(): S? = if (exists()) get() else null + +operator fun GroupState.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull() +operator fun GroupState.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit = update(value) + + +inline fun KeyValueGroupedDataset.mapGroupsWithState( + noinline func: (key: K, values: Iterator, state: GroupState) -> U +): Dataset = mapGroupsWithState( + MapGroupsWithStateFunction(func), + encoder(), + encoder() +) + +inline fun KeyValueGroupedDataset.mapGroupsWithState( + timeoutConf: GroupStateTimeout, + noinline func: (key: K, values: Iterator, state: GroupState) -> U +): Dataset = mapGroupsWithState( + MapGroupsWithStateFunction(func), + encoder(), + encoder(), + timeoutConf +) + +inline fun KeyValueGroupedDataset.flatMapGroupsWithState( + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + noinline func: (key: K, values: Iterator, state: GroupState) -> Iterator +): Dataset = flatMapGroupsWithState( + FlatMapGroupsWithStateFunction(func), + outputMode, + encoder(), + encoder(), + timeoutConf +) + +inline fun KeyValueGroupedDataset.cogroup( + other: KeyValueGroupedDataset, + noinline func: (key: K, left: Iterator, right: Iterator) -> Iterator +): Dataset = cogroup( + other, + CoGroupFunction(func), + encoder() +) + inline fun Dataset.downcast(): Dataset = `as`(encoder()) inline fun Dataset<*>.`as`(): Dataset = `as`(encoder()) inline fun Dataset<*>.to(): Dataset = `as`(encoder()) diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt index f312dd7d..c643ad47 100644 --- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt +++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt @@ -21,6 +21,9 @@ import ch.tutteli.atrium.api.fluent.en_GB.* import ch.tutteli.atrium.domain.builders.migration.asExpect import ch.tutteli.atrium.verbs.expect import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.shouldBe +import org.apache.spark.sql.streaming.GroupState +import org.apache.spark.sql.streaming.GroupStateTimeout import java.io.Serializable import java.time.LocalDate @@ -169,6 +172,76 @@ class ApiTest : ShouldSpec({ expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11) } + should("perform operations on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } + + val flatMapped = groupedDataset.flatMapGroups { key, values -> + val collected = values.asSequence().toList() + + if (collected.size > 1) collected.iterator() + else emptyList>().iterator() + } + + flatMapped.count() shouldBe 2 + + val mappedWithStateTimeoutConf = groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() + + s = key + s shouldBe key + + s!! to collected.map { it.second } + } + + mappedWithStateTimeoutConf.count() shouldBe 2 + + val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() + + s = key + s shouldBe key + + s!! to collected.map { it.second } + } + + mappedWithState.count() shouldBe 2 + + val flatMappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() + + s = key + s shouldBe key + + if (collected.size > 1) collected.iterator() + else emptyList>().iterator() + } + + flatMappedWithState.count() shouldBe 2 + } + should("be able to cogroup grouped datasets") { + val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } + + val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f") + .toDS() + .groupByKey { it.first } + + val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right -> + listOf( + key to (left.asSequence() + right.asSequence()) + .map { it.second } + .toList() + ).iterator() + } + + cogrouped.count() shouldBe 4 + } } } }) From 7db51eec65e06168029b67df20644a07caa0fc31 Mon Sep 17 00:00:00 2001 From: Jolanrensen Date: Mon, 26 Apr 2021 17:32:48 +0200 Subject: [PATCH 2/2] split up the tests into separate ones --- .../jetbrains/kotlinx/spark/api/ApiTest.kt | 32 ++++++++++++++----- .../jetbrains/kotlinx/spark/api/ApiTest.kt | 32 ++++++++++++++----- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt index 3d2969ec..42ac3d0f 100644 --- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt +++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt @@ -162,7 +162,7 @@ class ApiTest : ShouldSpec({ expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11) } - should("perform operations on grouped datasets") { + should("perform flat map on grouped datasets") { val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") .toDS() .groupByKey { it.first } @@ -175,18 +175,29 @@ class ApiTest : ShouldSpec({ } flatMapped.count() shouldBe 2 + } + should("perform map group with state and timeout conf on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } - val mappedWithStateTimeoutConf = groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState -> - var s by state - val collected = values.asSequence().toList() + val mappedWithStateTimeoutConf = + groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() - s = key - s shouldBe key + s = key + s shouldBe key - s!! to collected.map { it.second } - } + s!! to collected.map { it.second } + } mappedWithStateTimeoutConf.count() shouldBe 2 + } + should("perform map group with state on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> var s by state @@ -199,6 +210,11 @@ class ApiTest : ShouldSpec({ } mappedWithState.count() shouldBe 2 + } + should("perform flat map group with state on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } val flatMappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> var s by state diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt index b6dfaf58..f2f76c0e 100644 --- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt +++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt @@ -176,7 +176,7 @@ class ApiTest : ShouldSpec({ expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11) } - should("perform operations on grouped datasets") { + should("perform flat map on grouped datasets") { val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") .toDS() .groupByKey { it.first } @@ -189,18 +189,29 @@ class ApiTest : ShouldSpec({ } flatMapped.count() shouldBe 2 + } + should("perform map group with state and timeout conf on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } - val mappedWithStateTimeoutConf = groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState -> - var s by state - val collected = values.asSequence().toList() + val mappedWithStateTimeoutConf = + groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState -> + var s by state + val collected = values.asSequence().toList() - s = key - s shouldBe key + s = key + s shouldBe key - s!! to collected.map { it.second } - } + s!! to collected.map { it.second } + } mappedWithStateTimeoutConf.count() shouldBe 2 + } + should("perform map group with state on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> var s by state @@ -213,6 +224,11 @@ class ApiTest : ShouldSpec({ } mappedWithState.count() shouldBe 2 + } + should("perform flat map group with state on grouped datasets") { + val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c") + .toDS() + .groupByKey { it.first } val flatMappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState -> var s by state