
Commit 31eed52

Jolanrensen authored and asm0dey committed
ref: improves broadcasting
1 parent bd8e97b commit 31eed52

File tree

5 files changed: +70 -23 lines changed
  • examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples
  • kotlin-spark-api
    • 2.4/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api
    • 3.0/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api


examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/Broadcasting.kt

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ import java.io.Serializable
 data class SomeClass(val a: IntArray, val b: Int) : Serializable
 
 fun main() = withSpark {
-    val broadcastVariable = spark.sparkContext.broadcast(SomeClass(a = intArrayOf(5, 6), b = 3))
+    val broadcastVariable = spark.broadcast(SomeClass(a = intArrayOf(5, 6), b = 3))
     val result = listOf(1, 2, 3, 4, 5)
         .toDS()
         .map {
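
For context, here is a minimal sketch of how the updated example reads end to end with the new extension. Only the broadcast call and the first lines of the pipeline come from the hunk above; the map body, collectAsList, and println are assumptions added to make the sketch self-contained.

import org.jetbrains.kotlinx.spark.api.*
import java.io.Serializable

data class SomeClass(val a: IntArray, val b: Int) : Serializable

fun main() = withSpark {
    // spark.broadcast resolves the ClassTag itself, so the call site no longer
    // needs to go through spark.sparkContext.
    val broadcastVariable = spark.broadcast(SomeClass(a = intArrayOf(5, 6), b = 3))
    val result = listOf(1, 2, 3, 4, 5)
        .toDS()
        .map { it * broadcastVariable.value.b }   // assumed body; the hunk cuts off after `.map {`
        .collectAsList()
    println(result)                               // assumed output: [3, 6, 9, 12, 15]
}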

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 22 additions & 2 deletions
@@ -22,6 +22,7 @@
 package org.jetbrains.kotlinx.spark.api
 
 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.api.java.function.*
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.*
@@ -66,14 +67,33 @@ val ENCODERS = mapOf<KClass<*>, Encoder<*>>(
 
 /**
  * Broadcast a read-only variable to the cluster, returning a
- * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+ * [org.apache.spark.broadcast.Broadcast] object for reading it in distributed functions.
  * The variable will be sent to each cluster only once.
  *
  * @param value value to broadcast to the Spark nodes
 * @return `Broadcast` object, a read-only variable cached on each machine
  */
-inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = broadcast(value, encoder<T>().clsTag())
+inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
+    sparkContext.broadcast(value, encoder<T>().clsTag())
+} catch (e: ClassNotFoundException) {
+    JavaSparkContext(sparkContext).broadcast(value)
+}
 
+/**
+ * Broadcast a read-only variable to the cluster, returning a
+ * [org.apache.spark.broadcast.Broadcast] object for reading it in distributed functions.
+ * The variable will be sent to each cluster only once.
+ *
+ * @param value value to broadcast to the Spark nodes
+ * @return `Broadcast` object, a read-only variable cached on each machine
+ * @see broadcast
+ */
+@Deprecated("You can now use `spark.broadcast()` instead.", ReplaceWith("spark.broadcast(value)"), DeprecationLevel.WARNING)
+inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
+    broadcast(value, encoder<T>().clsTag())
+} catch (e: ClassNotFoundException) {
+    JavaSparkContext(this).broadcast(value)
+}
 
 /**
  * Utility method to create dataset from list
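
The hunk above moves the extension from SparkContext to SparkSession and keeps the old entry point only as a deprecated alias with a ReplaceWith. A hypothetical call site showing the migration; the LookupTable payload and its value are assumptions, not taken from the repository.

import org.jetbrains.kotlinx.spark.api.*
import java.io.Serializable

data class LookupTable(val factor: Double) : Serializable   // assumed example payload

fun main() = withSpark {
    // Before: extension on SparkContext, now deprecated at WARNING level.
    @Suppress("DEPRECATION")
    val old = spark.sparkContext.broadcast(LookupTable(factor = 2.0))

    // After: extension on SparkSession, as suggested by ReplaceWith("spark.broadcast(value)").
    // Internally it tries encoder<T>().clsTag() and falls back to JavaSparkContext
    // when the encoder class cannot be found.
    val new = spark.broadcast(LookupTable(factor = 2.0))

    println(old.value == new.value)   // both are Broadcast<LookupTable> holding equal values
}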

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 17 additions & 14 deletions
@@ -148,21 +148,24 @@ class ApiTest : ShouldSpec({
         @OptIn(ExperimentalStdlibApi::class)
         should("broadcast variables") {
             val largeList = (1..15).map { SomeClass(a = (it..15).toList().toIntArray(), b = it) }
-            val broadcast = spark.sparkContext.broadcast(largeList)
-
-            val result: List<Int> = listOf(1, 2, 3, 4, 5)
-                .toDS()
-                .mapPartitions { iterator ->
-                    val receivedBroadcast = broadcast.value
-                    buildList {
-                        iterator.forEach {
-                            this.add(it + receivedBroadcast[it].b)
-                        }
-                    }.iterator()
-                }
-                .collectAsList()
+            val broadcast = spark.broadcast(largeList)
+            val broadcast2 = spark.broadcast(arrayOf(doubleArrayOf(1.0, 2.0, 3.0, 4.0)))
+
+            val result: List<Double> = listOf(1, 2, 3, 4, 5)
+                .toDS()
+                .mapPartitions { iterator ->
+                    val receivedBroadcast = broadcast.value
+                    val receivedBroadcast2 = broadcast2.value
+
+                    buildList {
+                        iterator.forEach {
+                            this.add(it + receivedBroadcast[it].b * receivedBroadcast2[0][0])
+                        }
+                    }.iterator()
+                }
+                .collectAsList()
 
-            expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11)
+            expect(result).asExpect().contains.inOrder.only.values(3.0, 5.0, 7.0, 9.0, 11.0)
         }
         should("Handle JavaConversions in Kotlin") {
             // Test the iterator conversion
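
The expected values in the assertion change from integers to doubles because each element is now also multiplied by broadcast2[0][0], which is 1.0. A local, Spark-free mirror of the arithmetic in the mapPartitions body, just to show where 3.0, 5.0, 7.0, 9.0 and 11.0 come from:

fun main() {
    // Stand-ins for the broadcast values: b == it for each SomeClass(..., b = it),
    // and broadcast2[0][0] == 1.0.
    val bValues = (1..15).toList()
    val broadcast2 = arrayOf(doubleArrayOf(1.0, 2.0, 3.0, 4.0))

    val result = listOf(1, 2, 3, 4, 5).map { n ->
        // bValues is 0-indexed, so bValues[n] == n + 1; each value is n + (n + 1) * 1.0.
        n + bValues[n] * broadcast2[0][0]
    }
    println(result)   // [3.0, 5.0, 7.0, 9.0, 11.0]
}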

kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 23 additions & 2 deletions
@@ -22,6 +22,7 @@
 package org.jetbrains.kotlinx.spark.api
 
 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.api.java.function.*
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.*
@@ -65,13 +66,33 @@ val ENCODERS = mapOf<KClass<*>, Encoder<*>>(
 
 /**
  * Broadcast a read-only variable to the cluster, returning a
- * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+ * [org.apache.spark.broadcast.Broadcast] object for reading it in distributed functions.
  * The variable will be sent to each cluster only once.
  *
  * @param value value to broadcast to the Spark nodes
 * @return `Broadcast` object, a read-only variable cached on each machine
  */
-inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = broadcast(value, encoder<T>().clsTag())
+inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
+    sparkContext.broadcast(value, encoder<T>().clsTag())
+} catch (e: ClassNotFoundException) {
+    JavaSparkContext(sparkContext).broadcast(value)
+}
+
+/**
+ * Broadcast a read-only variable to the cluster, returning a
+ * [org.apache.spark.broadcast.Broadcast] object for reading it in distributed functions.
+ * The variable will be sent to each cluster only once.
+ *
+ * @param value value to broadcast to the Spark nodes
+ * @return `Broadcast` object, a read-only variable cached on each machine
+ * @see broadcast
+ */
+@Deprecated("You can now use `spark.broadcast()` instead.", ReplaceWith("spark.broadcast(value)"), DeprecationLevel.WARNING)
+inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
+    broadcast(value, encoder<T>().clsTag())
+} catch (e: ClassNotFoundException) {
+    JavaSparkContext(this).broadcast(value)
+}
 
 /**
  * Utility method to create dataset from list
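
The 3.0 module gets the same pair of extensions as 2.4. One practical consequence of the try/catch is that a value without a Kotlin encoder can still be broadcast through the JavaSparkContext fallback, which supplies its own ClassTag. A sketch using the Array<DoubleArray> payload from the updated tests; the map body and the printed result are assumptions for illustration.

import org.apache.spark.broadcast.Broadcast
import org.jetbrains.kotlinx.spark.api.*

fun main() = withSpark {
    // Same payload shape the updated tests broadcast. Whether this goes through
    // encoder<T>().clsTag() or the JavaSparkContext fallback is an implementation
    // detail of the try/catch above; callers get a Broadcast<Array<DoubleArray>> either way.
    val weights: Broadcast<Array<DoubleArray>> = spark.broadcast(arrayOf(doubleArrayOf(1.0, 2.0, 3.0, 4.0)))

    val scaled = listOf(1, 2, 3)
        .toDS()
        .map { it * weights.value[0][1] }   // assumed usage: scale each element by the second weight (2.0)
        .collectAsList()

    println(scaled)                          // [2.0, 4.0, 6.0]
}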

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 7 additions & 4 deletions
@@ -162,21 +162,24 @@ class ApiTest : ShouldSpec({
         @OptIn(ExperimentalStdlibApi::class)
         should("broadcast variables") {
             val largeList = (1..15).map { SomeClass(a = (it..15).toList().toIntArray(), b = it) }
-            val broadcast = spark.sparkContext.broadcast(largeList)
+            val broadcast = spark.broadcast(largeList)
+            val broadcast2 = spark.broadcast(arrayOf(doubleArrayOf(1.0, 2.0, 3.0, 4.0)))
 
-            val result: List<Int> = listOf(1, 2, 3, 4, 5)
+            val result: List<Double> = listOf(1, 2, 3, 4, 5)
                 .toDS()
                 .mapPartitions { iterator ->
                     val receivedBroadcast = broadcast.value
+                    val receivedBroadcast2 = broadcast2.value
+
                     buildList {
                         iterator.forEach {
-                            this.add(it + receivedBroadcast[it].b)
+                            this.add(it + receivedBroadcast[it].b * receivedBroadcast2[0][0])
                         }
                     }.iterator()
                 }
                 .collectAsList()
 
-            expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11)
+            expect(result).asExpect().contains.inOrder.only.values(3.0, 5.0, 7.0, 9.0, 11.0)
         }
         should("Handle JavaConversions in Kotlin") {
             // Test the iterator conversion
