Commit 20dbd01

committed
more samples, can probably safely say all works
1 parent 5de11c6 commit 20dbd01

File tree

  • examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/MLlib.kt

1 file changed: +145, -5 lines

examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/MLlib.kt

Lines changed: 145 additions & 5 deletions
@@ -19,26 +19,38 @@
  */
 package org.jetbrains.kotlinx.spark.examples
 
-import org.apache.spark.*
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.PipelineModel
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.classification.LogisticRegressionModel
+import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.feature.Tokenizer
 import org.apache.spark.ml.linalg.Matrix
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.stat.ChiSquareTest
 import org.apache.spark.ml.stat.Correlation
-import org.apache.spark.ml.stat.Summarizer
 import org.apache.spark.ml.stat.Summarizer.*
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col
-import org.jetbrains.kotlinx.spark.api.*
-import org.jetbrains.kotlinx.spark.api.tuples.*
-import scala.collection.mutable.WrappedArray
+import org.jetbrains.kotlinx.spark.api.KSparkSession
+import org.jetbrains.kotlinx.spark.api.to
+import org.jetbrains.kotlinx.spark.api.tuples.t
+import org.jetbrains.kotlinx.spark.api.tuples.tupleOf
+import org.jetbrains.kotlinx.spark.api.withSpark
 
 
 fun main() = withSpark {
+    // https://spark.apache.org/docs/latest/ml-statistics.html
     correlation()
     chiSquare()
     summarizer()
+
+    // https://spark.apache.org/docs/latest/ml-pipeline.html
+    estimatorTransformerParam()
+    pipeline()
 }
 
 private fun KSparkSession.correlation() {
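
For context on the entry point: withSpark from kotlin-spark-api starts a SparkSession and runs the block as an extension of KSparkSession, which is why the example functions in this file can be declared as KSparkSession extensions and call toDF directly. A minimal standalone sketch (the appName parameter reflects the library's API as commonly documented; treat the exact signature as an assumption):

    import org.jetbrains.kotlinx.spark.api.withSpark

    fun main() = withSpark(appName = "MLlib demo") { // assumed parameter name
        // Inside the block, `this` is a KSparkSession, so helpers such as
        // dsOf(...) and List.toDF(...) are in scope.
        dsOf(1, 2, 3).show()
    }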
@@ -131,4 +143,132 @@ private fun KSparkSession.summarizer() {
 
     println("without weight: mean = ${result2.getAs<Vector>(0)}, variance = ${result2.getAs<Vector>(1)}")
     println()
+}
+
+private fun KSparkSession.estimatorTransformerParam() {
+    println("Estimator, Transformer, and Param")
+
+    // Prepare training data from a list of (label, features) tuples.
+    val training = listOf(
+        t(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+        t(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+        t(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+        t(1.0, Vectors.dense(0.0, 1.2, -0.5))
+    ).toDF("label", "features")
+
+    // Create a LogisticRegression instance. This instance is an Estimator.
+    val lr = LogisticRegression()
+
+    // Print out the parameters, documentation, and any default values.
+    println("LogisticRegression parameters:\n ${lr.explainParams()}\n")
+
+    // We may set parameters using setter methods.
+    lr.apply {
+        maxIter = 10
+        regParam = 0.01
+    }
+
+    // Learn a LogisticRegression model. This uses the parameters stored in lr.
+    val model1 = lr.fit(training)
+
+    // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
+    // we can view the parameters it used during fit().
+    // This prints the parameter (name: value) pairs, where names are unique IDs for this
+    // LogisticRegression instance.
+    println("Model 1 was fit using parameters: ${model1.parent().extractParamMap()}")
+
+    // We may alternatively specify parameters using a ParamMap.
+    val paramMap = ParamMap()
+        .put(lr.maxIter().w(20)) // Specify 1 Param.
+        .put(lr.maxIter(), 30) // This overwrites the original maxIter.
+        .put(lr.regParam().w(0.1), lr.threshold().w(0.55)) // Specify multiple Params.
+
+    // One can also combine ParamMaps.
+    val paramMap2 = ParamMap()
+        .put(lr.probabilityCol().w("myProbability")) // Change output column name
+
+    val paramMapCombined = paramMap.`$plus$plus`(paramMap2)
+
+    // Now learn a new model using the paramMapCombined parameters.
+    // paramMapCombined overrides all parameters set earlier via lr.set* methods.
+    val model2: LogisticRegressionModel = lr.fit(training, paramMapCombined)
+    println("Model 2 was fit using parameters: ${model2.parent().extractParamMap()}")
+
+    // Prepare test documents.
+    val test = listOf(
+        t(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+        t(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+        t(1.0, Vectors.dense(0.0, 2.2, -1.5)),
+    ).toDF("label", "features")
+
+    // Make predictions on test documents using the Transformer.transform() method.
+    // LogisticRegression.transform will only use the 'features' column.
+    // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+    // 'probability' column since we renamed the lr.probabilityCol parameter previously.
+    val results = model2.transform(test)
+    val rows = results.select("features", "label", "myProbability", "prediction")
+    for (r: Row in rows.collectAsList())
+        println("(${r[0]}, ${r[1]}) -> prob=${r[2]}, prediction=${r[3]}")
+
+    println()
+}
+
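
A quick way to see the precedence used in estimatorTransformerParam above: an explicit ParamMap passed to fit() wins over values set through setters, and extractParamMap(extra) merges with the same precedence. A minimal sketch, not part of this commit (it assumes plain Spark ML calls, no kotlin-spark-api extensions):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    fun main() {
        val lr = LogisticRegression().setMaxIter(10)
        val extra = ParamMap().put(lr.maxIter().w(30))
        // Precedence: defaults < setter values < explicit ParamMap, so this prints 30.
        println(lr.extractParamMap(extra).apply(lr.maxIter()))
    }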
+private fun KSparkSession.pipeline() {
+    println("Pipeline:")
+    // Prepare training documents from a list of (id, text, label) tuples.
+    val training = listOf(
+        t(0L, "a b c d e spark", 1.0),
+        t(1L, "b d", 0.0),
+        t(2L, "spark f g h", 1.0),
+        t(3L, "hadoop mapreduce", 0.0)
+    ).toDF("id", "text", "label")
+
+    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+    val tokenizer = Tokenizer()
+        .setInputCol("text")
+        .setOutputCol("words")
+    val hashingTF = HashingTF()
+        .setNumFeatures(1000)
+        .setInputCol(tokenizer.outputCol)
+        .setOutputCol("features")
+    val lr = LogisticRegression()
+        .setMaxIter(10)
+        .setRegParam(0.001)
+    val pipeline = Pipeline()
+        .setStages(
+            arrayOf(
+                tokenizer,
+                hashingTF,
+                lr,
+            )
+        )
+
+    // Fit the pipeline to training documents.
+    val model = pipeline.fit(training)
+
+    // Now we can optionally save the fitted pipeline to disk
+    model.write().overwrite().save("/tmp/spark-logistic-regression-model")
+
+    // We can also save this unfit pipeline to disk
+    pipeline.write().overwrite().save("/tmp/unfit-lr-model")
+
+    // And load it back in during production
+    val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model")
+
+    // Prepare test documents, which are unlabeled (id, text) tuples.
+    val test = listOf(
+        t(4L, "spark i j k"),
+        t(5L, "l m n"),
+        t(6L, "spark hadoop spark"),
+        t(7L, "apache hadoop"),
+    ).toDF("id", "text")
+
+    // Make predictions on test documents.
+    val predictions = model.transform(test)
+        .select("id", "text", "probability", "prediction")
+        .collectAsList()
+
+    for (r in predictions)
+        println("(${r[0]}, ${r[1]}) --> prob=${r[2]}, prediction=${r[3]}")
+
+    println()
 }
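
Note that pipeline() loads sameModel but never exercises it. A round-trip check that could be appended at the end of the function, sketched here rather than taken from the commit (it reuses model, sameModel, and test from the code above):

    // The reloaded PipelineModel should behave exactly like the in-memory one,
    // so identical predictions are expected on the same test data.
    val fresh = model.transform(test).select("prediction").collectAsList()
    val reloaded = sameModel.transform(test).select("prediction").collectAsList()
    println("round-trip predictions match: ${fresh == reloaded}")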
