
Commit 8070a63

Merge pull request #162 from Kotlin/udt-support
Udt support, DF functions, mllib example
2 parents 585719f + d44bd8c commit 8070a63

File tree: 14 files changed, +626 −15 lines


examples/pom-3.2_2.12.xml

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,11 @@
             <artifactId>spark-streaming_${scala.compat.version}</artifactId>
             <version>${spark3.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-mllib_${scala.compat.version}</artifactId>
+            <version>${spark3.version}</version>
+        </dependency>
         <dependency><!-- Only needed for Qodana -->
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-streaming-kafka-0-10_${scala.compat.version}</artifactId>
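
For anyone wiring up the new MLlib example outside Maven, a rough Gradle Kotlin DSL equivalent of the added dependency could look like the sketch below. The Scala 2.12 suffix and the 3.2.x version are assumptions based on the module name; the real value should match the project's spark3.version property.

dependencies {
    // Spark MLlib, required by the ML example added in this commit.
    // "3.2.1" here is an assumption standing in for ${spark3.version}.
    implementation("org.apache.spark:spark-mllib_2.12:3.2.1")
}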
Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
/*-
 * =LICENSE=
 * Kotlin Spark API: Examples for Spark 3.2+ (Scala 2.12)
 * ----------
 * Copyright (C) 2019 - 2022 JetBrains
 * ----------
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =LICENSEEND=
 */
package org.jetbrains.kotlinx.spark.examples

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.feature.Tokenizer
import org.apache.spark.ml.linalg.Matrix
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.stat.ChiSquareTest
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.ml.stat.Summarizer.*
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.jetbrains.kotlinx.spark.api.KSparkSession
import org.jetbrains.kotlinx.spark.api.to
import org.jetbrains.kotlinx.spark.api.tuples.t
import org.jetbrains.kotlinx.spark.api.tuples.tupleOf
import org.jetbrains.kotlinx.spark.api.withSpark


fun main() = withSpark {
    // https://spark.apache.org/docs/latest/ml-statistics.html
    correlation()
    chiSquare()
    summarizer()

    // https://spark.apache.org/docs/latest/ml-pipeline.html
    estimatorTransformerParam()
    pipeline()
}

private fun KSparkSession.correlation() {
    println("Correlation:")

    val data = listOf(
        Vectors.sparse(4, intArrayOf(0, 3), doubleArrayOf(1.0, -2.0)),
        Vectors.dense(4.0, 5.0, 0.0, 3.0),
        Vectors.dense(6.0, 7.0, 0.0, 8.0),
        Vectors.sparse(4, intArrayOf(0, 3), doubleArrayOf(9.0, 1.0))
    ).map(::tupleOf)

    val df = data.toDF("features")

    val r1 = Correlation.corr(df, "features").head().getAs<Matrix>(0)
    println(
        """
        |Pearson correlation matrix:
        |$r1
        |
        """.trimMargin()
    )

    val r2 = Correlation.corr(df, "features", "spearman").head().getAs<Matrix>(0)
    println(
        """
        |Spearman correlation matrix:
        |$r2
        |
        """.trimMargin()
    )
}

private fun KSparkSession.chiSquare() {
    println("ChiSquare:")

    val data = listOf(
        t(0.0, Vectors.dense(0.5, 10.0)),
        t(0.0, Vectors.dense(1.5, 20.0)),
        t(1.0, Vectors.dense(1.5, 30.0)),
        t(0.0, Vectors.dense(3.5, 30.0)),
        t(0.0, Vectors.dense(3.5, 40.0)),
        t(1.0, Vectors.dense(3.5, 40.0)),
    )

    // while df.getAs<Something>(0) works, it's often easier to just parse the result as a typed Dataset :)
    data class ChiSquareTestResult(
        val pValues: Vector,
        val degreesOfFreedom: List<Int>,
        val statistics: Vector,
    )

    val df: Dataset<Row> = data.toDF("label", "features")
    val chi = ChiSquareTest.test(df, "features", "label")
        .to<ChiSquareTestResult>()
        .head()

    println("pValues: ${chi.pValues}")
    println("degreesOfFreedom: ${chi.degreesOfFreedom}")
    println("statistics: ${chi.statistics}")
    println()
}

private fun KSparkSession.summarizer() {
    println("Summarizer:")

    val data = listOf(
        t(Vectors.dense(2.0, 3.0, 5.0), 1.0),
        t(Vectors.dense(4.0, 6.0, 7.0), 2.0)
    )

    val df = data.toDF("features", "weight")

    val result1 = df
        .select(
            metrics("mean", "variance")
                .summary(col("features"), col("weight")).`as`("summary")
        )
        .select("summary.mean", "summary.variance")
        .first()

    println("with weight: mean = ${result1.getAs<Vector>(0)}, variance = ${result1.getAs<Vector>(1)}")

    val result2 = df
        .select(
            mean(col("features")),
            variance(col("features")),
        )
        .first()

    println("without weight: mean = ${result2.getAs<Vector>(0)}, variance = ${result2.getAs<Vector>(1)}")
    println()
}

private fun KSparkSession.estimatorTransformerParam() {
    println("Estimator, Transformer, and Param")

    // Prepare training data from a list of (label, features) tuples.
    val training = listOf(
        t(1.0, Vectors.dense(0.0, 1.1, 0.1)),
        t(0.0, Vectors.dense(2.0, 1.0, -1.0)),
        t(0.0, Vectors.dense(2.0, 1.3, 1.0)),
        t(1.0, Vectors.dense(0.0, 1.2, -0.5))
    ).toDF("label", "features")

    // Create a LogisticRegression instance. This instance is an Estimator.
    val lr = LogisticRegression()

    // Print out the parameters, documentation, and any default values.
    println("LogisticRegression parameters:\n ${lr.explainParams()}\n")

    // We may set parameters using setter methods.
    lr.apply {
        maxIter = 10
        regParam = 0.01
    }

    // Learn a LogisticRegression model. This uses the parameters stored in lr.
    val model1 = lr.fit(training)
    // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
    // we can view the parameters it used during fit().
    // This prints the parameter (name: value) pairs, where names are unique IDs for this
    // LogisticRegression instance.
    println("Model 1 was fit using parameters: ${model1.parent().extractParamMap()}")

    // We may alternatively specify parameters using a ParamMap.
    val paramMap = ParamMap()
        .put(lr.maxIter().w(20)) // Specify 1 Param.
        .put(lr.maxIter(), 30) // This overwrites the original maxIter.
        .put(lr.regParam().w(0.1), lr.threshold().w(0.55)) // Specify multiple Params.

    // One can also combine ParamMaps.
    val paramMap2 = ParamMap()
        .put(lr.probabilityCol().w("myProbability")) // Change output column name

    val paramMapCombined = paramMap.`$plus$plus`(paramMap2)

    // Now learn a new model using the paramMapCombined parameters.
    // paramMapCombined overrides all parameters set earlier via lr.set* methods.
    val model2: LogisticRegressionModel = lr.fit(training, paramMapCombined)
    println("Model 2 was fit using parameters: ${model2.parent().extractParamMap()}")

    // Prepare test documents.
    val test = listOf(
        t(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
        t(0.0, Vectors.dense(3.0, 2.0, -0.1)),
        t(1.0, Vectors.dense(0.0, 2.2, -1.5)),
    ).toDF("label", "features")

    // Make predictions on test documents using the Transformer.transform() method.
    // LogisticRegression.transform will only use the 'features' column.
    // Note that model2.transform() outputs a 'myProbability' column instead of the usual
    // 'probability' column since we renamed the lr.probabilityCol parameter previously.
    val results = model2.transform(test)
    val rows = results.select("features", "label", "myProbability", "prediction")
    for (r: Row in rows.collectAsList())
        println("(${r[0]}, ${r[1]}) -> prob=${r[2]}, prediction=${r[3]}")

    println()
}

private fun KSparkSession.pipeline() {
    println("Pipeline:")
    // Prepare training documents from a list of (id, text, label) tuples.
    val training = listOf(
        t(0L, "a b c d e spark", 1.0),
        t(1L, "b d", 0.0),
        t(2L, "spark f g h", 1.0),
        t(3L, "hadoop mapreduce", 0.0)
    ).toDF("id", "text", "label")

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = Tokenizer()
        .setInputCol("text")
        .setOutputCol("words")
    val hashingTF = HashingTF()
        .setNumFeatures(1000)
        .setInputCol(tokenizer.outputCol)
        .setOutputCol("features")
    val lr = LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.001)
    val pipeline = Pipeline()
        .setStages(
            arrayOf(
                tokenizer,
                hashingTF,
                lr,
            )
        )

    // Fit the pipeline to training documents.
    val model = pipeline.fit(training)

    // Now we can optionally save the fitted pipeline to disk
    model.write().overwrite().save("/tmp/spark-logistic-regression-model")

    // We can also save this unfit pipeline to disk
    pipeline.write().overwrite().save("/tmp/unfit-lr-model")

    // And load it back in during production
    val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model")

    // Prepare test documents, which are unlabeled (id, text) tuples.
    val test = listOf(
        t(4L, "spark i j k"),
        t(5L, "l m n"),
        t(6L, "spark hadoop spark"),
        t(7L, "apache hadoop"),
    ).toDF("id", "text")

    // Make predictions on test documents.
    val predictions = model.transform(test)
        .select("id", "text", "probability", "prediction")
        .collectAsList()

    for (r in predictions)
        println("(${r[0]}, ${r[1]}) --> prob=${r[2]}, prediction=${r[3]}")

    println()
}
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
/*-
 * =LICENSE=
 * Kotlin Spark API: Examples for Spark 3.2+ (Scala 2.12)
 * ----------
 * Copyright (C) 2019 - 2022 JetBrains
 * ----------
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =LICENSEEND=
 */
import org.apache.hadoop.shaded.com.google.common.base.MoreObjects
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.*
import org.apache.spark.unsafe.types.UTF8String
import org.jetbrains.kotlinx.spark.api.*
import org.jetbrains.kotlinx.spark.api.tuples.tupleOf
import java.io.Serializable
import kotlin.reflect.jvm.jvmName

class CityUserDefinedType : UserDefinedType<City>() {

    override fun sqlType(): DataType = DATA_TYPE

    override fun serialize(city: City): InternalRow = GenericInternalRow(2).apply {
        setInt(DEPT_NUMBER_INDEX, city.departmentNumber)
        update(NAME_INDEX, UTF8String.fromString(city.name))
    }

    override fun deserialize(datum: Any): City =
        if (datum is InternalRow)
            City(
                name = datum.getString(NAME_INDEX),
                departmentNumber = datum.getInt(DEPT_NUMBER_INDEX),
            )
        else throw IllegalStateException("Unsupported conversion")

    override fun userClass(): Class<City> = City::class.java

    companion object {
        private const val DEPT_NUMBER_INDEX = 0
        private const val NAME_INDEX = 1
        private val DATA_TYPE = DataTypes.createStructType(
            arrayOf(
                DataTypes.createStructField(
                    "departmentNumber",
                    DataTypes.IntegerType,
                    false,
                    MetadataBuilder().putLong("maxNumber", 99).build(),
                ),
                DataTypes.createStructField("name", DataTypes.StringType, false)
            )
        )
    }
}

@SQLUserDefinedType(udt = CityUserDefinedType::class)
class City(val name: String, val departmentNumber: Int) : Serializable {

    override fun toString(): String =
        MoreObjects
            .toStringHelper(this)
            .add("name", name)
            .add("departmentNumber", departmentNumber)
            .toString()
}

fun main() = withSpark {

    // UDTRegistration.register(City::class.jvmName, CityUserDefinedType::class.jvmName)

    val items = listOf(
        City("Amsterdam", 1),
        City("Breda", 2),
        City("Oosterhout", 3),
    ).map(::tupleOf)

    val ds = items.toDS()
    ds.showDS()
}
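
A note on the two registration routes visible in this file: the example relies on the @SQLUserDefinedType annotation on City, while the commented-out UDTRegistration.register call in main() hints at the explicit alternative. A minimal sketch of that explicit route is shown below; it assumes the same City and CityUserDefinedType classes and reuses the UDTRegistration and jvmName imports already present above, and the function name is purely illustrative.

// Hypothetical variant of main() that registers the UDT explicitly,
// mirroring the commented-out line above, instead of relying on the annotation.
fun mainWithExplicitRegistration() = withSpark {
    // Map the user class to its UDT by fully qualified class name.
    UDTRegistration.register(City::class.jvmName, CityUserDefinedType::class.jvmName)

    val ds = listOf(
        City("Amsterdam", 1),
        City("Breda", 2),
    ).map(::tupleOf).toDS()

    ds.showDS()
}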
