Skip to content

Feature: more KeyValueGroupedDataset wrapper functions #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ import org.apache.spark.sql.catalyst.JavaTypeInference
import org.apache.spark.sql.catalyst.KotlinReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.*
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
import scala.collection.Seq
Expand All @@ -43,6 +46,7 @@ import java.time.Instant
import java.time.LocalDate
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.isSubclassOf
Expand Down Expand Up @@ -179,6 +183,58 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
reduceGroups(ReduceFunction(func))
.map { t -> t._1 to t._2 }

inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
): Dataset<U> = flatMapGroups(
FlatMapGroupsFunction(func),
encoder<U>()
)

fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null

operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()
operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit = update(value)


inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>()
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>(),
timeoutConf
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
): Dataset<U> = flatMapGroupsWithState(
FlatMapGroupsWithStateFunction(func),
outputMode,
encoder<S>(),
encoder<U>(),
timeoutConf
)

inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
other: KeyValueGroupedDataset<K, U>,
noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
): Dataset<R> = cogroup(
other,
CoGroupFunction(func),
encoder<R>()
)

inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
import ch.tutteli.atrium.verbs.expect
import io.kotest.core.spec.style.ShouldSpec
import io.kotest.matchers.shouldBe
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import scala.collection.Seq
import org.apache.spark.sql.Dataset
import java.io.Serializable
Expand Down Expand Up @@ -202,6 +204,92 @@ class ApiTest : ShouldSpec({
kotlinList.first() shouldBe "a"
kotlinList.last() shouldBe "b"
}
should("perform flat map on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val flatMapped = groupedDataset.flatMapGroups { key, values ->
val collected = values.asSequence().toList()

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMapped.count() shouldBe 2
}
should("perform map group with state and timeout conf on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val mappedWithStateTimeoutConf =
groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithStateTimeoutConf.count() shouldBe 2
}
should("perform map group with state on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithState.count() shouldBe 2
}
should("perform flat map group with state on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val flatMappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't these be moved to separate tests?

var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMappedWithState.count() shouldBe 2
}
should("be able to cogroup grouped datasets") {
val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f")
.toDS()
.groupByKey { it.first }

val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right ->
listOf(
key to (left.asSequence() + right.asSequence())
.map { it.second }
.toList()
).iterator()
}

cogrouped.count() shouldBe 4
}
should("be able to serialize Date 2.4") { // uses knownDataTypes
val dataset: Dataset<Pair<Date, Int>> = dsOf(Date.valueOf("2020-02-10") to 5)
dataset.show()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.*
import org.apache.spark.sql.Encoders.*
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.*
import org.jetbrains.kotinx.spark.extensions.KSparkExtensions
import scala.reflect.ClassTag
Expand All @@ -39,6 +42,7 @@ import java.time.Instant
import java.time.LocalDate
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.isSubclassOf
Expand Down Expand Up @@ -172,6 +176,58 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
reduceGroups(ReduceFunction(func))
.map { t -> t._1 to t._2 }

inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
): Dataset<U> = flatMapGroups(
FlatMapGroupsFunction(func),
encoder<U>()
)

fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null

operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()
operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit = update(value)


inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>()
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>(),
timeoutConf
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
): Dataset<U> = flatMapGroupsWithState(
FlatMapGroupsWithStateFunction(func),
outputMode,
encoder<S>(),
encoder<U>(),
timeoutConf
)

inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
other: KeyValueGroupedDataset<K, U>,
noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
): Dataset<R> = cogroup(
other,
CoGroupFunction(func),
encoder<R>()
)

inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
import ch.tutteli.atrium.verbs.expect
import io.kotest.core.spec.style.ShouldSpec
import io.kotest.matchers.shouldBe
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import scala.collection.Seq
import org.apache.spark.sql.Dataset
import java.io.Serializable
Expand Down Expand Up @@ -216,6 +218,92 @@ class ApiTest : ShouldSpec({
kotlinList.first() shouldBe "a"
kotlinList.last() shouldBe "b"
}
should("perform flat map on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val flatMapped = groupedDataset.flatMapGroups { key, values ->
val collected = values.asSequence().toList()

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMapped.count() shouldBe 2
}
should("perform map group with state and timeout conf on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val mappedWithStateTimeoutConf =
groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithStateTimeoutConf.count() shouldBe 2
}
should("perform map group with state on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithState.count() shouldBe 2
}
should("perform flat map group with state on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val flatMappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMappedWithState.count() shouldBe 2
}
should("be able to cogroup grouped datasets") {
val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f")
.toDS()
.groupByKey { it.first }

val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right ->
listOf(
key to (left.asSequence() + right.asSequence())
.map { it.second }
.toList()
).iterator()
}

cogrouped.count() shouldBe 4
}
should("handle LocalDate Datasets") { // uses encoder
val dataset: Dataset<LocalDate> = dsOf(LocalDate.now(), LocalDate.now())
dataset.show()
Expand Down