76 changes: 74 additions & 2 deletions python/pyspark/ml/evaluation.py
@@ -20,12 +20,13 @@
from pyspark import since, keyword_only
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol, \
HasFeaturesCol
from pyspark.ml.common import inherit_doc
from pyspark.ml.util import JavaMLReadable, JavaMLWritable

__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
'MulticlassClassificationEvaluator']
'MulticlassClassificationEvaluator', 'ClusteringEvaluator']


@inherit_doc
@@ -328,6 +329,77 @@ def setParams(self, predictionCol="prediction", labelCol="label",
kwargs = self._input_kwargs
return self._set(**kwargs)


@inherit_doc
class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental

Evaluator for clustering results, which expects two input
columns: prediction and features.

>>> from pyspark.ml.linalg import Vectors
>>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
... [([0.0, 0.5], 0.0), ([0.5, 0.0], 0.0), ([10.0, 11.0], 1.0),
... ([10.5, 11.5], 1.0), ([1.0, 1.0], 0.0), ([8.0, 6.0], 1.0)])
>>> dataset = spark.createDataFrame(
...     featureAndPredictions, ["features", "prediction"])
>>> evaluator = ClusteringEvaluator(predictionCol="prediction")
>>> evaluator.evaluate(dataset)
0.9079...
>>> ce_path = temp_path + "/ce"
>>> evaluator.save(ce_path)
>>> evaluator2 = ClusteringEvaluator.load(ce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'

.. versionadded:: 2.3.0
"""
metricName = Param(Params._dummy(), "metricName",
"metric name in evaluation (silhouette)",
typeConverter=TypeConverters.toString)
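
# For reference: the silhouette of a point i is
#     s(i) = (b(i) - a(i)) / max(a(i), b(i))
# where a(i) is the mean distance from i to the other points of its own
# cluster and b(i) is the mean distance from i to the points of the nearest
# other cluster. The evaluator reports the mean s(i) over all points (the
# Scala implementation computes it with squared Euclidean distances), so
# values near 1.0 indicate tight, well-separated clusters.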

@keyword_only
def __init__(self, predictionCol="prediction", featuresCol="features",
metricName="silhouette"):
"""
__init__(self, predictionCol="prediction", featuresCol="features", \
metricName="silhouette")
"""
super(ClusteringEvaluator, self).__init__()
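# Instantiate the JVM-side ClusteringEvaluator backing this wrapper;
# the actual metric computation is delegated to it through Py4J.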
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
self._setDefault(metricName="silhouette")
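# @keyword_only stored the constructor arguments in _input_kwargs;
# apply them over the defaults set above.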
kwargs = self._input_kwargs
self._set(**kwargs)

@since("2.3.0")
def setMetricName(self, value):
"""
Sets the value of :py:attr:`metricName`.
"""
return self._set(metricName=value)

@since("2.3.0")
def getMetricName(self):
"""
Gets the value of metricName or its default value.
"""
return self.getOrDefault(self.metricName)

@keyword_only
@since("2.3.0")
def setParams(self, predictionCol="prediction", featuresCol="features",
metricName="silhouette"):
"""
setParams(self, predictionCol="prediction", featuresCol="features", \
metricName="silhouette")
Sets params for clustering evaluator.
"""
kwargs = self._input_kwargs
return self._set(**kwargs)

if __name__ == "__main__":
import doctest
import tempfile
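
For context, a minimal end-to-end sketch of how the new evaluator slots into a clustering pipeline (not part of this diff; it assumes an active SparkSession bound to `spark`, and uses the existing `KMeans` estimator from `pyspark.ml.clustering` with the default column names):

from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.linalg import Vectors

# Toy dataset with an obvious two-cluster structure.
data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
        (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
df = spark.createDataFrame(data, ["features"])

# Fit a model and score its cluster assignments with the new evaluator.
model = KMeans(k=2, seed=1).fit(df)
predictions = model.transform(df)  # adds the default "prediction" column
silhouette = ClusteringEvaluator().evaluate(predictions)
print(silhouette)  # close to 1.0 for well-separated clusters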