Skip to content

Commit b18f889

Browse files
Jolanrensen and asm0dey
authored and committed
feat: more KeyValueGroupedDataset wrapper functions (#81)
1 parent 31eed52 commit b18f889

File tree

4 files changed

+288
-0
lines changed
  • kotlin-spark-api
    • 2.4/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api
    • 3.0/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api

4 files changed

+288
-0
lines changed

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ import org.apache.spark.sql.catalyst.JavaTypeInference
3131
import org.apache.spark.sql.catalyst.KotlinReflection
3232
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
3333
import org.apache.spark.sql.catalyst.expressions.Expression
34+
import org.apache.spark.sql.streaming.GroupState
35+
import org.apache.spark.sql.streaming.GroupStateTimeout
36+
import org.apache.spark.sql.streaming.OutputMode
3437
import org.apache.spark.sql.types.*
3538
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
3639
import scala.collection.Seq
@@ -43,6 +46,7 @@ import java.time.Instant
4346
import java.time.LocalDate
4447
import java.util.concurrent.ConcurrentHashMap
4548
import kotlin.reflect.KClass
49+
import kotlin.reflect.KProperty
4650
import kotlin.reflect.KType
4751
import kotlin.reflect.full.findAnnotation
4852
import kotlin.reflect.full.isSubclassOf
@@ -179,6 +183,58 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
179183
reduceGroups(ReduceFunction(func))
180184
.map { t -> t._1 to t._2 }
181185

186+
/**
 * (Kotlin-specific) Applies [func] to each group and flattens the returned
 * iterators into a single [Dataset].
 *
 * Wraps [func] in a [FlatMapGroupsFunction] and delegates to Spark's
 * `flatMapGroups`, supplying the encoder inferred from the reified [U].
 */
inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
    noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
): Dataset<U> {
    return flatMapGroups(FlatMapGroupsFunction(func), encoder<U>())
}
192+
193+
/** Returns the current state value if one exists, or `null` otherwise. */
fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null

/** Property-delegate read: `val s by state` yields the current value or `null` when absent. */
operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()

/**
 * Property-delegate write: `s = value` updates the state.
 *
 * Assigning `null` removes the state instead of calling [GroupState.update],
 * because Spark's `update` rejects null values; this keeps the delegate
 * symmetric with [getValue], which reports absent state as `null`.
 */
operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?) {
    if (value == null) remove() else update(value)
}
197+
198+
199+
/**
 * (Kotlin-specific) Applies [func] to each group while maintaining a
 * per-group state of type [S], producing one result of type [U] per group.
 *
 * Wraps [func] in a [MapGroupsWithStateFunction] and delegates to Spark's
 * `mapGroupsWithState` with encoders inferred from the reified [S] and [U].
 */
inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> {
    val stateEncoder = encoder<S>()
    val resultEncoder = encoder<U>()
    return mapGroupsWithState(MapGroupsWithStateFunction(func), stateEncoder, resultEncoder)
}
206+
207+
/**
 * (Kotlin-specific) Applies [func] to each group while maintaining a
 * per-group state of type [S], using [timeoutConf] to control state timeouts.
 *
 * Wraps [func] in a [MapGroupsWithStateFunction] and delegates to Spark's
 * `mapGroupsWithState` with encoders inferred from the reified [S] and [U].
 */
inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
    timeoutConf: GroupStateTimeout,
    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> {
    val stateEncoder = encoder<S>()
    val resultEncoder = encoder<U>()
    return mapGroupsWithState(MapGroupsWithStateFunction(func), stateEncoder, resultEncoder, timeoutConf)
}
216+
217+
/**
 * (Kotlin-specific) Applies [func] to each group while maintaining a
 * per-group state of type [S], flattening the returned iterators into a
 * single [Dataset].
 *
 * Wraps [func] in a [FlatMapGroupsWithStateFunction] and delegates to
 * Spark's `flatMapGroupsWithState` with the given [outputMode] and
 * [timeoutConf], and encoders inferred from the reified [S] and [U].
 */
inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
    outputMode: OutputMode,
    timeoutConf: GroupStateTimeout,
    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
): Dataset<U> {
    val stateEncoder = encoder<S>()
    val resultEncoder = encoder<U>()
    return flatMapGroupsWithState(
        FlatMapGroupsWithStateFunction(func), outputMode, stateEncoder, resultEncoder, timeoutConf
    )
}
228+
229+
/**
 * (Kotlin-specific) Cogroups this dataset with [other] by key and applies
 * [func] to each pair of per-key value iterators.
 *
 * Wraps [func] in a [CoGroupFunction] and delegates to Spark's `cogroup`,
 * supplying the encoder inferred from the reified [R].
 */
inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
    other: KeyValueGroupedDataset<K, U>,
    noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
): Dataset<R> {
    return cogroup(other, CoGroupFunction(func), encoder<R>())
}
237+
182238
inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
183239
inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
184240
inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
2222
import ch.tutteli.atrium.verbs.expect
2323
import io.kotest.core.spec.style.ShouldSpec
2424
import io.kotest.matchers.shouldBe
25+
import org.apache.spark.sql.streaming.GroupState
26+
import org.apache.spark.sql.streaming.GroupStateTimeout
2527
import scala.collection.Seq
2628
import org.apache.spark.sql.Dataset
2729
import java.io.Serializable
@@ -202,6 +204,92 @@ class ApiTest : ShouldSpec({
202204
kotlinList.first() shouldBe "a"
203205
kotlinList.last() shouldBe "b"
204206
}
207+
should("perform flat map on grouped datasets") {
    val grouped = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // Only the key-1 group (two elements) survives the size filter.
    val flatMapped = grouped.flatMapGroups { key, values ->
        val items = values.asSequence().toList()
        val kept =
            if (items.size > 1) items
            else emptyList<Pair<Int, String>>()
        kept.iterator()
    }

    flatMapped.count() shouldBe 2
}
should("perform map group with state and timeout conf on grouped datasets") {
    val grouped = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // One output row per key, with the state delegate holding the key.
    val mappedWithStateTimeoutConf =
        grouped.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
            var s by state
            val items = values.asSequence().toList()

            s = key
            s shouldBe key

            s!! to items.map { it.second }
        }

    mappedWithStateTimeoutConf.count() shouldBe 2
}
should("perform map group with state on grouped datasets") {
    val grouped = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // Same as above but using the overload without a timeout configuration.
    val mappedWithState = grouped.mapGroupsWithState { key, values, state: GroupState<Int> ->
        var s by state
        val items = values.asSequence().toList()

        s = key
        s shouldBe key

        s!! to items.map { it.second }
    }

    mappedWithState.count() shouldBe 2
}
256+
should("perform flat map group with state on grouped datasets") {
    val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // Fix: this test previously called mapGroupsWithState, so the new
    // flatMapGroupsWithState wrapper was never exercised. Call the function
    // the test is named after; the expected count is unchanged (only the
    // two-element key-1 group emits rows).
    val flatMappedWithState = groupedDataset.flatMapGroupsWithState(
        org.apache.spark.sql.streaming.OutputMode.Append(),
        GroupStateTimeout.NoTimeout()
    ) { key, values, state: GroupState<Int> ->
        var s by state
        val collected = values.asSequence().toList()

        s = key
        s shouldBe key

        if (collected.size > 1) collected.iterator()
        else emptyList<Pair<Int, String>>().iterator()
    }

    flatMappedWithState.count() shouldBe 2
}
274+
should("be able to cogroup grouped datasets") {
    val grouped1 = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    val grouped2 = listOf(1 to "d", 5 to "e", 3 to "f")
        .toDS()
        .groupByKey { it.first }

    // Union of keys {1, 2} and {1, 5, 3} yields four cogrouped rows.
    val cogrouped = grouped1.cogroup(grouped2) { key, left, right ->
        val merged = (left.asSequence() + right.asSequence())
            .map { it.second }
            .toList()
        listOf(key to merged).iterator()
    }

    cogrouped.count() shouldBe 4
}
205293
should("be able to serialize Date 2.4") { // uses knownDataTypes
206294
val dataset: Dataset<Pair<Date, Int>> = dsOf(Date.valueOf("2020-02-10") to 5)
207295
dataset.show()

kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ import org.apache.spark.broadcast.Broadcast
2828
import org.apache.spark.sql.*
2929
import org.apache.spark.sql.Encoders.*
3030
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
31+
import org.apache.spark.sql.streaming.GroupState
32+
import org.apache.spark.sql.streaming.GroupStateTimeout
33+
import org.apache.spark.sql.streaming.OutputMode
3134
import org.apache.spark.sql.types.*
3235
import org.jetbrains.kotinx.spark.extensions.KSparkExtensions
3336
import scala.reflect.ClassTag
@@ -39,6 +42,7 @@ import java.time.Instant
3942
import java.time.LocalDate
4043
import java.util.concurrent.ConcurrentHashMap
4144
import kotlin.reflect.KClass
45+
import kotlin.reflect.KProperty
4246
import kotlin.reflect.KType
4347
import kotlin.reflect.full.findAnnotation
4448
import kotlin.reflect.full.isSubclassOf
@@ -172,6 +176,58 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
172176
reduceGroups(ReduceFunction(func))
173177
.map { t -> t._1 to t._2 }
174178

179+
/**
 * (Kotlin-specific) Applies [func] to each group and flattens the returned
 * iterators into a single [Dataset].
 *
 * Wraps [func] in a [FlatMapGroupsFunction] and delegates to Spark's
 * `flatMapGroups`, supplying the encoder inferred from the reified [U].
 */
inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
    noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
): Dataset<U> {
    return flatMapGroups(FlatMapGroupsFunction(func), encoder<U>())
}
185+
186+
/** Returns the current state value if one exists, or `null` otherwise. */
fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null

/** Property-delegate read: `val s by state` yields the current value or `null` when absent. */
operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()

/**
 * Property-delegate write: `s = value` updates the state.
 *
 * Assigning `null` removes the state instead of calling [GroupState.update],
 * because Spark's `update` rejects null values; this keeps the delegate
 * symmetric with [getValue], which reports absent state as `null`.
 */
operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?) {
    if (value == null) remove() else update(value)
}
190+
191+
192+
/**
 * (Kotlin-specific) Applies [func] to each group while maintaining a
 * per-group state of type [S], producing one result of type [U] per group.
 *
 * Wraps [func] in a [MapGroupsWithStateFunction] and delegates to Spark's
 * `mapGroupsWithState` with encoders inferred from the reified [S] and [U].
 */
inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> {
    val stateEncoder = encoder<S>()
    val resultEncoder = encoder<U>()
    return mapGroupsWithState(MapGroupsWithStateFunction(func), stateEncoder, resultEncoder)
}
199+
200+
/**
 * (Kotlin-specific) Applies [func] to each group while maintaining a
 * per-group state of type [S], using [timeoutConf] to control state timeouts.
 *
 * Wraps [func] in a [MapGroupsWithStateFunction] and delegates to Spark's
 * `mapGroupsWithState` with encoders inferred from the reified [S] and [U].
 */
inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
    timeoutConf: GroupStateTimeout,
    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> {
    val stateEncoder = encoder<S>()
    val resultEncoder = encoder<U>()
    return mapGroupsWithState(MapGroupsWithStateFunction(func), stateEncoder, resultEncoder, timeoutConf)
}
209+
210+
/**
 * (Kotlin-specific) Applies [func] to each group while maintaining a
 * per-group state of type [S], flattening the returned iterators into a
 * single [Dataset].
 *
 * Wraps [func] in a [FlatMapGroupsWithStateFunction] and delegates to
 * Spark's `flatMapGroupsWithState` with the given [outputMode] and
 * [timeoutConf], and encoders inferred from the reified [S] and [U].
 */
inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
    outputMode: OutputMode,
    timeoutConf: GroupStateTimeout,
    noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
): Dataset<U> {
    val stateEncoder = encoder<S>()
    val resultEncoder = encoder<U>()
    return flatMapGroupsWithState(
        FlatMapGroupsWithStateFunction(func), outputMode, stateEncoder, resultEncoder, timeoutConf
    )
}
221+
222+
/**
 * (Kotlin-specific) Cogroups this dataset with [other] by key and applies
 * [func] to each pair of per-key value iterators.
 *
 * Wraps [func] in a [CoGroupFunction] and delegates to Spark's `cogroup`,
 * supplying the encoder inferred from the reified [R].
 */
inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
    other: KeyValueGroupedDataset<K, U>,
    noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
): Dataset<R> {
    return cogroup(other, CoGroupFunction(func), encoder<R>())
}
230+
175231
inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
176232
inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
177233
inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
2222
import ch.tutteli.atrium.verbs.expect
2323
import io.kotest.core.spec.style.ShouldSpec
2424
import io.kotest.matchers.shouldBe
25+
import org.apache.spark.sql.streaming.GroupState
26+
import org.apache.spark.sql.streaming.GroupStateTimeout
2527
import scala.collection.Seq
2628
import org.apache.spark.sql.Dataset
2729
import java.io.Serializable
@@ -216,6 +218,92 @@ class ApiTest : ShouldSpec({
216218
kotlinList.first() shouldBe "a"
217219
kotlinList.last() shouldBe "b"
218220
}
221+
should("perform flat map on grouped datasets") {
    val grouped = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // Only the key-1 group (two elements) survives the size filter.
    val flatMapped = grouped.flatMapGroups { key, values ->
        val items = values.asSequence().toList()
        val kept =
            if (items.size > 1) items
            else emptyList<Pair<Int, String>>()
        kept.iterator()
    }

    flatMapped.count() shouldBe 2
}
should("perform map group with state and timeout conf on grouped datasets") {
    val grouped = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // One output row per key, with the state delegate holding the key.
    val mappedWithStateTimeoutConf =
        grouped.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
            var s by state
            val items = values.asSequence().toList()

            s = key
            s shouldBe key

            s!! to items.map { it.second }
        }

    mappedWithStateTimeoutConf.count() shouldBe 2
}
should("perform map group with state on grouped datasets") {
    val grouped = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // Same as above but using the overload without a timeout configuration.
    val mappedWithState = grouped.mapGroupsWithState { key, values, state: GroupState<Int> ->
        var s by state
        val items = values.asSequence().toList()

        s = key
        s shouldBe key

        s!! to items.map { it.second }
    }

    mappedWithState.count() shouldBe 2
}
270+
should("perform flat map group with state on grouped datasets") {
    val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    // Fix: this test previously called mapGroupsWithState, so the new
    // flatMapGroupsWithState wrapper was never exercised. Call the function
    // the test is named after; the expected count is unchanged (only the
    // two-element key-1 group emits rows).
    val flatMappedWithState = groupedDataset.flatMapGroupsWithState(
        org.apache.spark.sql.streaming.OutputMode.Append(),
        GroupStateTimeout.NoTimeout()
    ) { key, values, state: GroupState<Int> ->
        var s by state
        val collected = values.asSequence().toList()

        s = key
        s shouldBe key

        if (collected.size > 1) collected.iterator()
        else emptyList<Pair<Int, String>>().iterator()
    }

    flatMappedWithState.count() shouldBe 2
}
288+
should("be able to cogroup grouped datasets") {
    val grouped1 = listOf(1 to "a", 1 to "b", 2 to "c")
        .toDS()
        .groupByKey { it.first }

    val grouped2 = listOf(1 to "d", 5 to "e", 3 to "f")
        .toDS()
        .groupByKey { it.first }

    // Union of keys {1, 2} and {1, 5, 3} yields four cogrouped rows.
    val cogrouped = grouped1.cogroup(grouped2) { key, left, right ->
        val merged = (left.asSequence() + right.asSequence())
            .map { it.second }
            .toList()
        listOf(key to merged).iterator()
    }

    cogrouped.count() shouldBe 4
}
219307
should("handle LocalDate Datasets") { // uses encoder
220308
val dataset: Dataset<LocalDate> = dsOf(LocalDate.now(), LocalDate.now())
221309
dataset.show()

0 commit comments

Comments
 (0)