
JavaRDD extension functions + iterators #174


Merged · 18 commits · Jul 23, 2022
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -59,7 +59,7 @@ jobs:
-Pspark=${{ matrix.spark }}
-Pscala=${{ matrix.scala }}
clean
- build
+ test
--scan

# qodana:
@@ -21,8 +21,8 @@ package org.jetbrains.kotlinx.spark.extensions

import org.apache.spark.SparkContext
import org.apache.spark.sql._

import java.util
import scala.reflect.ClassTag

object KSparkExtensions {

@@ -58,4 +58,17 @@ object KSparkExtensions {
}

def sparkContext(s: SparkSession): SparkContext = s.sparkContext

/**
* Produces a ClassTag[T], which is actually just a casted ClassTag[AnyRef].
*
* This method is used to keep ClassTags out of the external Java API, as the Java compiler
* cannot produce them automatically. While this ClassTag-faking does please the compiler,
* it can cause problems at runtime if the Scala API relies on ClassTags for correctness.
*
* Often, though, a ClassTag[AnyRef] will not lead to incorrect behavior, just worse performance
* or security issues. For instance, an Array[AnyRef] can hold any type T, but may lose primitive
* specialization.
*/
def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}
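
For context, here is a minimal Kotlin sketch of the pattern this helper enables: bridging into a Scala API that demands a ClassTag without asking the Kotlin/Java caller for one. The asJavaRDD name is illustrative only; Spark's own Java API relies on the same fakeClassTag trick internally.

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions

// Bridge a Scala RDD into the Java API without requiring a ClassTag from the caller.
// JavaRDD.fromRDD wants a ClassTag[T]; the faked AnyRef tag suffices here because
// JavaRDDs only ever deal in boxed references.
fun <T> RDD<T>.asJavaRDD(): JavaRDD<T> =
    JavaRDD.fromRDD(this, KSparkExtensions.fakeClassTag())
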
@@ -3,8 +3,8 @@
{
"cell_type": "markdown",
"source": [
"By default the latest version of the API and the latest supported Spark version is chosen.\n",
"To specify your own: `%use spark(spark=3.2, v=1.1.0)`"
"By default, the latest version of the API and the latest supported Spark version is chosen.\n",
"To specify your own: `%use spark(spark=3.3.0, scala=2.13, v=1.2.0)`"
],
"metadata": {
"collapsed": false,
@@ -35,6 +35,18 @@
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
@@ -312,14 +324,13 @@
}
],
"source": [
"val rdd: JavaRDD<Tuple2<Int, String>> = sc.parallelize(\n",
" listOf(\n",
" 1 X \"aaa\",\n",
" t(2, \"bbb\"),\n",
" tupleOf(3, \"ccc\"),\n",
" )\n",
"val rdd: JavaRDD<Tuple2<Int, String>> = rddOf(\n",
" 1 X \"aaa\",\n",
" t(2, \"bbb\"),\n",
" tupleOf(3, \"ccc\"),\n",
")\n",
"\n",
"\n",
"rdd"
],
"metadata": {
@@ -0,0 +1,223 @@
package org.jetbrains.kotlinx.spark.examples

import org.apache.spark.sql.Dataset
import org.jetbrains.kotlinx.spark.api.*
import org.jetbrains.kotlinx.spark.api.tuples.X
import org.jetbrains.kotlinx.spark.examples.GroupCalculation.getAllPossibleGroups
import scala.Tuple2
import kotlin.math.pow

/**
* Gets all the possible, unique, non-repeating groups of indices for a list.
*
* Example by Jolanrensen.
*/

fun main() = withSpark {
val groupIndices = getAllPossibleGroups(listSize = 10, groupSize = 4)
.sort("value")

groupIndices.showDS(numRows = groupIndices.count().toInt())
}

object GroupCalculation {

/**
* Get all the possible, unique, non-repeating groups (of size [groupSize]) of indices for a list of
* size [listSize].
*
* The workload is evenly distributed based on [listSize] and [groupSize].
*
* @param listSize the size of the list for which to calculate the indices
* @param groupSize the size of a group of indices
* @return all the possible, unique, non-repeating groups of indices
*/
fun KSparkSession.getAllPossibleGroups(
listSize: Int,
groupSize: Int,
): Dataset<IntArray> {
val indices = (0 until listSize).toList().toRDD() // Easy RDD creation!

// for a groupSize of 1, no pairing up is needed, so just return the indices converted to IntArrays
if (groupSize == 1) {
return indices
.mapPartitions {
it.map { intArrayOf(it) }
}
.toDS()
}

// this converts all indices to (number in table, index)
val keys = indices.mapPartitions {

// _1 is key (item in table), _2 is index in list
it.transformAsSequence {
flatMap { listIndex ->

// for each dimension loop over the other dimensions using addTuples
(0 until groupSize).asSequence().flatMap { dimension ->
addTuples(
groupSize = groupSize,
value = listIndex,
listSize = listSize,
skipDimension = dimension,
)
}
}
}
}

// Since we have a JavaRDD<Tuple2>, we can use aggregateByKey!
// Each number in the table occurs once per dimension as a key;
// aggregating by key collects that key's list indices into a single IntArray (the group).
val allPossibleGroups = keys.aggregateByKey(
zeroValue = IntArray(groupSize) { -1 },
seqFunc = { base: IntArray, listIndex: Int ->
// put listIndex in the first empty spot in base
base[base.indexOfFirst { it < 0 }] = listIndex

base
},

// how to merge partially filled up int arrays
combFunc = { a: IntArray, b: IntArray ->
// merge a and b
var j = 0
for (i in a.indices) {
if (a[i] < 0) {
while (b[j] < 0) {
j++
if (j == b.size) return@aggregateByKey a
}
a[i] = b[j]
j++
}
}
a
},
)
.values() // finally just take the values

return allPossibleGroups.toDS()
}

/**
* Simple method to give each index of x dimensions a unique number.
*
* @param indexTuple List (can be seen as a tuple) of size x with all values < listSize; the index for which to return the number
* @param listSize The size of the list, aka the max width, height etc. of the table
* @return the unique number for this [indexTuple]
*/
private fun getTupleValue(indexTuple: List<Int>, listSize: Int): Int =
indexTuple.indices.sumOf {
indexTuple[it] * listSize.toDouble().pow(it).toInt()
}
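
// For illustration (derived from the formula above): with listSize = 10,
// getTupleValue(listOf(3, 1), 10) == 3 * 10^0 + 1 * 10^1 == 13,
// which is the "13" cell at row b, column d in the table documented at addTuples below.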


/**
* To make sure that every tuple is only picked once, this method returns true only if the indices lie in the
* upper-right triangle of the matrix (strictly above the diagonal). This works for any number of dimensions > 1. Here is an example for 2-D:
*
*
* - 0 1 2 3 4 5 6 7 8 9
* --------------------------------
* 0| x ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓
* 1| x x ✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓
* 2| x x x ✓ ✓ ✓ ✓ ✓ ✓ ✓
* 3| x x x x ✓ ✓ ✓ ✓ ✓ ✓
* 4| x x x x x ✓ ✓ ✓ ✓ ✓
* 5| x x x x x x ✓ ✓ ✓ ✓
* 6| x x x x x x x ✓ ✓ ✓
* 7| x x x x x x x x ✓ ✓
* 8| x x x x x x x x x ✓
* 9| x x x x x x x x x x
*
* @param indexTuple a tuple of indices in the form of an IntArray
* @return true if this tuple is in the right corner and should be included
*/
private fun isValidIndexTuple(indexTuple: List<Int>): Boolean {
// x - y > 0; 2d
// (x - y) > 0 && (x - z) > 0 && (y - z) > 0; 3d
// (x - y) > 0 && (x - z) > 0 && (x - a) > 0 && (y - z) > 0 && (y - a) > 0 && (z - a) > 0; 4d
require(indexTuple.size >= 2) { "not a tuple" }
for (i in 0 until indexTuple.size - 1) {
for (j in i + 1 until indexTuple.size) {
if (indexTuple[i] - indexTuple[j] <= 0) return false
}
}
return true
}

/**
* Recursive method that keeps [skipDimension] fixed at [value], loops over all the other dimensions, and returns
* every result of [getTupleValue] as a key paired with [value] as its value.
* Across all calls, each key in the table below ends up paired with each of its coordinates (its column, row, etc.).
*
* This is an example for 2D. The letters are int indices as well (a = 0, b = 1, ..., [listSize]), but help with clarification.
* The numbers we don't want are filtered out using [isValidIndexTuple].
* The actual value of the number in the table comes from [getTupleValue].
*
* - a b c d e f g h i j
* --------------------------------
* a| - 1 2 3 4 5 6 7 8 9
* b| - - 12 13 14 15 16 17 18 19
* c| - - - 23 24 25 26 27 28 29
* d| - - - - 34 35 36 37 38 39
* e| - - - - - 45 46 47 48 49
* f| - - - - - - 56 57 58 59
* g| - - - - - - - 67 68 69
* h| - - - - - - - - 78 79
* i| - - - - - - - - - 89
* j| - - - - - - - - - -
*
*
* @param groupSize the size of index tuples to form
* @param value the current index to work from (can be seen as a letter in the table above)
* @param listSize the size of the list for which the indices are calculated (the range of every looped dimension)
* @param skipDimension the current dimension that will have a set value [value] while looping over the other dimensions
*/
private fun addTuples(
groupSize: Int,
value: Int,
listSize: Int,
skipDimension: Int,
): List<Tuple2<Int, Int>> {

/**
* @param currentDimension the dimension currently being calculated (and how deep in the recursion we are)
* @param indexTuple the list (or tuple) in which to store the current indices
*/
fun recursiveCall(
currentDimension: Int = 0,
indexTuple: List<Int> = emptyList(),
): List<Tuple2<Int, Int>> = when {
// base case
currentDimension >= groupSize ->
if (isValidIndexTuple(indexTuple))
listOf(getTupleValue(indexTuple, listSize) X value)
else
emptyList()

currentDimension == skipDimension ->
recursiveCall(
currentDimension = currentDimension + 1,
indexTuple = indexTuple + value,
)

else ->
(0 until listSize).flatMap { i ->
recursiveCall(
currentDimension = currentDimension + 1,
indexTuple = indexTuple + i,
)
}
}

return recursiveCall()
}
}
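
The example above leans on the JavaRDD helpers introduced in this PR. As a minimal stand-alone sketch of the same pattern, assuming List.toRDD(), the iterator map extension, and the X tuple builder behave inside withSpark exactly as they are used above:

package org.jetbrains.kotlinx.spark.examples

import org.jetbrains.kotlinx.spark.api.*
import org.jetbrains.kotlinx.spark.api.tuples.X

fun main() = withSpark {
    // List -> JavaRDD via the new extension, then a per-partition iterator map
    // pairing every element with its square, collected back to the driver.
    val pairs = (1..5).toList().toRDD()
        .mapPartitions { it.map { n -> n X n * n } }
        .collect()

    println(pairs)
}
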
@@ -24,11 +24,8 @@ import org.apache.spark.api.java.StorageLevels
import org.apache.spark.streaming.Durations
import org.apache.spark.streaming.State
import org.apache.spark.streaming.StateSpec
- import org.jetbrains.kotlinx.spark.api.getOrElse
- import org.jetbrains.kotlinx.spark.api.mapWithState
- import org.jetbrains.kotlinx.spark.api.toPairRDD
+ import org.jetbrains.kotlinx.spark.api.*
import org.jetbrains.kotlinx.spark.api.tuples.X
- import org.jetbrains.kotlinx.spark.api.withSparkStreaming
import java.util.regex.Pattern
import kotlin.system.exitProcess

@@ -71,8 +68,8 @@ object KotlinStatefulNetworkCount {
) {

// Initial state RDD input to mapWithState
- val tuples = listOf("hello" X 1, "world" X 1)
- val initialRDD = ssc.sparkContext().parallelize(tuples)
+ val tuples = arrayOf("hello" X 1, "world" X 1)
+ val initialRDD = ssc.sparkContext().rddOf(*tuples)

val lines = ssc.socketTextStream(
args.getOrElse(0) { DEFAULT_HOSTNAME },
@@ -95,7 +92,7 @@ object KotlinStatefulNetworkCount {
val stateDstream = wordsDstream.mapWithState(
StateSpec
.function(mappingFunc)
- .initialState(initialRDD.toPairRDD())
+ .initialState(initialRDD.toJavaPairRDD())
)

stateDstream.print()
@@ -77,6 +77,10 @@ internal class SparkIntegration : Integration() {
inline fun <reified T> RDD<T>.toDF(vararg colNames: String): Dataset<Row> = toDF(spark, *colNames)""".trimIndent(),
"""
inline fun <reified T> JavaRDDLike<T, *>.toDF(vararg colNames: String): Dataset<Row> = toDF(spark, *colNames)""".trimIndent(),
"""
fun <T> List<T>.toRDD(numSlices: Int = sc.defaultParallelism()): JavaRDD<T> = sc.toRDD(this, numSlices)""".trimIndent(),
"""
fun <T> rddOf(vararg elements: T, numSlices: Int = sc.defaultParallelism()): JavaRDD<T> = sc.toRDD(elements.toList(), numSlices)""".trimIndent(),
"""
val udf: UDFRegistration get() = spark.udf()""".trimIndent(),
).map(::execute)
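
With these two helpers injected into the notebook session, a cell can build RDDs directly. A hypothetical cell, mirroring the notebook change earlier in this PR:

// vararg form registered above
val rdd: JavaRDD<Tuple2<Int, String>> = rddOf(
    1 X "aaa",
    2 X "bbb",
    3 X "ccc",
)

// List form, with an explicit number of partitions
val fromList: JavaRDD<Int> = listOf(10, 20, 30).toRDD(numSlices = 2)

rdd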