
Commit 5c47ba2

TimotheeMathieu and rth authored
Add CLARA Clustering algorithm (#83)
* add CLARA
* add example
* fix typo
* add doc
* fix docstring
* add CLARA to test_common
* add size check to pass tests
* fix tests
* update doc
* add test consistency clara kmedoids
* black
* handle types KMedoids
* Apply suggestions from code review
  Co-authored-by: Roman Yurchak <[email protected]>
* correct 32 bit
* change name variables
* create private function inertia and changelog

Co-authored-by: Roman Yurchak <[email protected]>
1 parent 445aaf8 commit 5c47ba2

File tree

8 files changed: +470 -33 lines changed


doc/api.rst

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ Clustering

     cluster.KMedoids
     cluster.CommonNNClustering
+    cluster.CLARA

 Robust
 ====================

doc/changelog.rst

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@ Changelog
 Unreleased
 ----------

+- Add `CLARA` (Clustering for Large Applications) which extends k-medoids to
+  be more scalable using a sampling approach.
+  [`#83 <https://github.com/scikit-learn-contrib/scikit-learn-extra/pull/83>`_].
 - Fix `_estimator_type` for :class:`~sklearn_extra.robust` estimators. Fix
   misbehavior of scikit-learn's :class:`~sklearn.model_selection.cross_val_score` and
   :class:`~sklearn.grid_search.GridSearchCV` for :class:`~sklearn_extra.robust.RobustWeightedClassifier`

doc/modules/cluster.rst

Lines changed: 35 additions & 3 deletions
@@ -1,8 +1,8 @@
 .. _cluster:

-=====================================================
-Clustering with KMedoids and Common-nearest-neighbors
-=====================================================
+============================================================
+Clustering with KMedoids, CLARA and Common-nearest-neighbors
+============================================================
 .. _k_medoids:

 K-Medoids
@@ -82,6 +82,38 @@ when speed is an issue.
     for performing face recognition. International Journal of Soft Computing,
     Mathematics and Control, 3(3), pp 1-12.

+
+CLARA
+=====
+
+:class:`CLARA` is related to the :class:`KMedoids` algorithm. CLARA
+(Clustering for Large Applications) extends k-medoids to be more scalable by
+using a sampling approach.
+
+.. topic:: Examples:
+
+  * :ref:`sphx_glr_auto_examples_plot_clara_digits.py`: A comparison of
+    K-Medoids and CLARA on the handwritten digits data with various
+    distance metrics.
+
+**Algorithm description:**
+CLARA uses random samples of the dataset, each of size `sampling_size`.
+The algorithm is iterative: first one sub-sample is selected, and CLARA applies
+KMedoids to it to obtain `n_clusters` medoids. At the next step, CLARA draws
+`sampling_size` - `n_clusters` new points from the dataset, and the next
+sub-sample is composed of the best medoids found so far (with respect to the
+inertia on the whole dataset, not only on the sub-sample) plus the newly drawn
+points. K-Medoids is then applied to this new sub-sample, and we loop back
+until `sample` sub-samples have been used.
+
+.. topic:: References:
+
+  * Kaufman, L. and Rousseeuw, P.J. (2008). Clustering Large Applications
+    (Program CLARA). In Finding Groups in Data (eds L. Kaufman and
+    P.J. Rousseeuw). doi:10.1002/9780470316801.ch2
+
 .. _commonnn:

 Common-nearest-neighbors clustering
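To make the algorithm description in the new CLARA section above more concrete, here is a minimal Python sketch of the sampling loop. It is illustrative only and is not the code added in this commit: the names `sampling_size` and `n_sampling_iter`, the helper structure, and the use of `KMedoids` together with `pairwise_distances` for scoring are assumptions made for the sketch (the real estimator also handles metrics, initialization, and duplicate medoids more carefully).

# Illustrative sketch of the CLARA sampling loop described above -- not the
# implementation added in this commit. `sampling_size` and `n_sampling_iter`
# are assumed parameter names for the sketch.
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn_extra.cluster import KMedoids


def clara_sketch(X, n_clusters, sampling_size, n_sampling_iter, random_state=0):
    rng = np.random.RandomState(random_state)
    best_medoids, best_inertia = None, np.inf
    for _ in range(n_sampling_iter):
        if best_medoids is None:
            # First iteration: a plain random sub-sample of the dataset.
            sub_idx = rng.choice(len(X), size=sampling_size, replace=False)
        else:
            # Later iterations: keep the best medoids found so far and top the
            # sub-sample up with freshly drawn points (duplicates merged here).
            new_idx = rng.choice(
                len(X), size=sampling_size - n_clusters, replace=False
            )
            sub_idx = np.unique(np.concatenate([best_medoids, new_idx]))
        km = KMedoids(n_clusters=n_clusters).fit(X[sub_idx])
        medoids = sub_idx[km.medoid_indices_]
        # Score the candidate medoids on the whole dataset, not the sub-sample.
        inertia = pairwise_distances(X, X[medoids]).min(axis=1).sum()
        if inertia < best_inertia:
            best_medoids, best_inertia = medoids, inertia
    return X[best_medoids], best_inertia

This loop is what makes the approach scale: each KMedoids fit only ever sees about `sampling_size` points, while the scoring step keeps the choice of medoids honest with respect to the full dataset.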

examples/plot_clara_digits.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@ (new file)
"""
======================================================================
A demo of K-Medoids vs CLARA clustering on the handwritten digits data
======================================================================
In this example we compare the computation time of K-Medoids and CLARA on
the handwritten digits data.
"""
import numpy as np
import matplotlib.pyplot as plt
import time

from sklearn_extra.cluster import KMedoids, CLARA
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

print(__doc__)

# Authors: Timo Erkkilä <[email protected]>
#          Antti Lehmussola <[email protected]>
#          Kornel Kiełczewski <[email protected]>
# License: BSD 3 clause

np.random.seed(42)

digits = load_digits()
data = scale(digits.data)
n_digits = len(np.unique(digits.target))

reduced_data = PCA(n_components=2).fit_transform(data)

# Step size of the mesh. Decrease to increase the quality of the VQ.
h = 0.02  # point in the mesh [x_min, x_max] x [y_min, y_max].

# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh.
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

plt.figure()
plt.clf()

plt.suptitle(
    "Comparing KMedoids and CLARA",
    fontsize=14,
)


selected_models = [
    (
        KMedoids(metric="cosine", n_clusters=n_digits),
        "KMedoids (cosine)",
    ),
    (
        KMedoids(metric="manhattan", n_clusters=n_digits),
        "KMedoids (manhattan)",
    ),
    (
        CLARA(
            metric="cosine",
            n_clusters=n_digits,
            init="heuristic",
            n_sampling=50,
        ),
        "CLARA (cosine)",
    ),
    (
        CLARA(
            metric="manhattan",
            n_clusters=n_digits,
            init="heuristic",
            n_sampling=50,
        ),
        "CLARA (manhattan)",
    ),
]

plot_rows = int(np.ceil(len(selected_models) / 2.0))
plot_cols = 2

for i, (model, description) in enumerate(selected_models):

    # Fit the model and obtain labels for each point in the mesh.
    init_time = time.time()
    model.fit(reduced_data)
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    computation_time = time.time() - init_time

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.subplot(plot_rows, plot_cols, i + 1)
    plt.imshow(
        Z,
        interpolation="nearest",
        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
        cmap=plt.cm.Paired,
        aspect="auto",
        origin="lower",
    )

    plt.plot(
        reduced_data[:, 0], reduced_data[:, 1], "k.", markersize=2, alpha=0.3
    )
    # Plot the medoids as white X marks
    centroids = model.cluster_centers_
    plt.scatter(
        centroids[:, 0],
        centroids[:, 1],
        marker="x",
        s=169,
        linewidths=3,
        color="w",
        zorder=10,
    )
    plt.title(description + ": %.2fs" % (computation_time))
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xticks(())
    plt.yticks(())

plt.show()

sklearn_extra/cluster/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from ._k_medoids import KMedoids
+from ._k_medoids import KMedoids, CLARA
 from ._commonnn import commonnn, CommonNNClustering

-__all__ = ["KMedoids", "CommonNNClustering", "commonnn"]
+__all__ = ["KMedoids", "CLARA", "CommonNNClustering", "commonnn"]
