diff --git a/setup.py b/setup.py index c691cd6d..d3029956 100755 --- a/setup.py +++ b/setup.py @@ -66,6 +66,11 @@ ["sklearn_extra/utils/_cyfht.pyx"], include_dirs=[np.get_include()], ), + Extension( + "sklearn_extra.cluster._k_medoids_helper", + ["sklearn_extra/cluster/_k_medoids_helper.pyx"], + include_dirs=[np.get_include()], + ), Extension( "sklearn_extra.robust._robust_weighted_estimator_helper", ["sklearn_extra/robust/_robust_weighted_estimator_helper.pyx"], diff --git a/sklearn_extra/cluster/_k_medoids.py b/sklearn_extra/cluster/_k_medoids.py index 4ed3ae8e..18fb987d 100644 --- a/sklearn_extra/cluster/_k_medoids.py +++ b/sklearn_extra/cluster/_k_medoids.py @@ -20,6 +20,9 @@ from sklearn.utils.validation import check_is_fitted from sklearn.exceptions import ConvergenceWarning +# cython implementation of swap step in PAM algorithm. +from ._k_medoids_helper import _compute_optimal_swap, _build + class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin): """k-medoids clustering. @@ -35,17 +38,27 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin): metric : string, or callable, optional, default: 'euclidean' What distance metric to use. See :func:metrics.pairwise_distances - init : {'random', 'heuristic', 'k-medoids++'}, optional, default: 'heuristic' + method : {'alternate', 'pam'}, default: 'alternate' + Which algorithm to use. + + init : {'random', 'heuristic', 'k-medoids++', 'build'}, optional, default: 'build' Specify medoid initialization method. 'random' selects n_clusters elements from the dataset. 'heuristic' picks the n_clusters points with the smallest sum distance to every other point. 'k-medoids++' follows an approach based on k-means++_, and in general, gives initial medoids which are more separated than those generated by the other methods. + 'build' is a greedy initialization of the medoids used in the original PAM + algorithm. Often 'build' is more efficient but slower than other + initializations on big datasets and it is also very non-robust, + if there are outliers in the dataset, use another initialization. .. _k-means++: https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf max_iter : int, optional, default : 300 - Specify the maximum number of iterations when fitting. + Specify the maximum number of iterations when fitting. It can be zero in + which case only the initialization is computed which may be suitable for + large datasets when the initialization is sufficiently efficient + (i.e. for 'build' init). random_state : int, RandomState instance or None, optional Specify random state for the random number generator. Used to @@ -112,24 +125,25 @@ def __init__( self, n_clusters=8, metric="euclidean", + method="alternate", init="heuristic", max_iter=300, random_state=None, ): self.n_clusters = n_clusters self.metric = metric + self.method = method self.init = init self.max_iter = max_iter self.random_state = random_state - def _check_nonnegative_int(self, value, desc): + def _check_nonnegative_int(self, value, desc, strict=True): """Validates if value is a valid integer > 0""" - - if ( - value is None - or value <= 0 - or not isinstance(value, (int, np.integer)) - ): + if strict: + negative = (value is None) or (value <= 0) + else: + negative = (value is None) or (value < 0) + if negative or not isinstance(value, (int, np.integer)): raise ValueError( "%s should be a nonnegative integer. " "%s was given" % (desc, value) @@ -140,10 +154,10 @@ def _check_init_args(self): # Check n_clusters and max_iter self._check_nonnegative_int(self.n_clusters, "n_clusters") - self._check_nonnegative_int(self.max_iter, "max_iter") + self._check_nonnegative_int(self.max_iter, "max_iter", False) # Check init - init_methods = ["random", "heuristic", "k-medoids++"] + init_methods = ["random", "heuristic", "k-medoids++", "build"] if self.init not in init_methods: raise ValueError( "init needs to be one of " @@ -183,15 +197,44 @@ def fit(self, X, y=None): ) labels = None + if self.method == "pam": + # Compute the distance to the first and second closest points + # among medoids. + Djs, Ejs = np.sort(D[medoid_idxs], axis=0)[[0, 1]] + # Continue the algorithm as long as # the medoids keep changing and the maximum number # of iterations is not exceeded + for self.n_iter_ in range(0, self.max_iter): old_medoid_idxs = np.copy(medoid_idxs) labels = np.argmin(D[medoid_idxs, :], axis=0) - # Update medoids with the new cluster indices - self._update_medoid_idxs_in_place(D, labels, medoid_idxs) + if self.method == "alternate": + # Update medoids with the new cluster indices + self._update_medoid_idxs_in_place(D, labels, medoid_idxs) + elif self.method == "pam": + not_medoid_idxs = np.delete(np.arange(len(D)), medoid_idxs) + optimal_swap = _compute_optimal_swap( + D, + medoid_idxs.astype(np.intc), + not_medoid_idxs.astype(np.intc), + Djs, + Ejs, + self.n_clusters, + ) + if optimal_swap is not None: + i, j, _ = optimal_swap + medoid_idxs[medoid_idxs == i] = j + + # update Djs and Ejs with new medoids + Djs, Ejs = np.sort(D[medoid_idxs], axis=0)[[0, 1]] + else: + raise ValueError( + f"method={self.method} is not supported. Supported methods " + f"are 'pam' and 'alternate'." + ) + if np.all(old_medoid_idxs == medoid_idxs): break elif self.n_iter_ == self.max_iter - 1: @@ -210,7 +253,7 @@ def fit(self, X, y=None): # Expose labels_ which are the assignments of # the training data to clusters - self.labels_ = labels + self.labels_ = np.argmin(D[medoid_idxs, :], axis=0) self.medoid_indices_ = medoid_idxs self.inertia_ = self._compute_inertia(self.transform(X)) @@ -252,6 +295,10 @@ def _update_medoid_idxs_in_place(self, D, labels, medoid_idxs): if min_cost < curr_cost: medoid_idxs[k] = cluster_k_idxs[min_cost_idx] + def _compute_cost(self, D, medoid_idxs): + """ Compute the cose for a given configuration of the medoids""" + return self._compute_inertia(D[:, medoid_idxs]) + def transform(self, X): """Transforms X to cluster-distance space. @@ -339,6 +386,8 @@ def _initialize_medoids(self, D, n_clusters, random_state_): medoids = np.argpartition(np.sum(D, axis=1), n_clusters - 1)[ :n_clusters ] + elif self.init == "build": # Build initialization + medoids = _build(D, n_clusters).astype(np.int64) else: raise ValueError(f"init value '{self.init}' not recognized") diff --git a/sklearn_extra/cluster/_k_medoids_helper.pyx b/sklearn_extra/cluster/_k_medoids_helper.pyx new file mode 100644 index 00000000..65a99234 --- /dev/null +++ b/sklearn_extra/cluster/_k_medoids_helper.pyx @@ -0,0 +1,110 @@ +# cython: infer_types=True +# Fast swap step and build step in PAM algorithm for k_medoid. +# Author: Timothée Mathieu +# License: 3-clause BSD + +cimport cython + +import numpy as np +cimport numpy as np +from cython cimport floating, integral + +@cython.boundscheck(False) # Deactivate bounds checking +def _compute_optimal_swap( floating[:,:] D, + int[:] medoid_idxs, + int[:] not_medoid_idxs, + floating[:] Djs, + floating[:] Ejs, + int n_clusters): + """Compute best cost change for all the possible swaps.""" + + # Initialize best cost change and the associated swap couple. + cdef (int, int, floating) best_cost_change = (1, 1, 0.0) + cdef int sample_size = len(D) + cdef int i, j, h, id_i, id_h, id_j + cdef floating cost_change + cdef int not_medoid_shape = sample_size - n_clusters + cdef bint cluster_i_bool, not_cluster_i_bool, second_best_medoid + cdef bint not_second_best_medoid + + # Compute the change in cost for each swap. + for h in range(not_medoid_shape): + # id of the potential new medoid. + id_h = not_medoid_idxs[h] + for i in range(n_clusters): + # id of the medoid we want to replace. + id_i = medoid_idxs[i] + cost_change = 0.0 + # compute for all not-selected points the change in cost + for j in range(not_medoid_shape): + id_j = not_medoid_idxs[j] + cluster_i_bool = D[id_i, id_j] == Djs[id_j] + not_cluster_i_bool = D[id_i, id_j] != Djs[id_j] + second_best_medoid = D[id_h, id_j] < Ejs[id_j] + not_second_best_medoid = D[id_h, id_j] >= Ejs[id_j] + + if cluster_i_bool & second_best_medoid: + cost_change += D[id_j, id_h] - Djs[id_j] + elif cluster_i_bool & not_second_best_medoid: + cost_change += Ejs[id_j] - Djs[id_j] + elif not_cluster_i_bool & (D[id_j, id_h] < Djs[id_j]): + cost_change += D[id_j, id_h] - Djs[id_j] + + # same for i + second_best_medoid = D[id_h, id_i] < Ejs[id_i] + if second_best_medoid: + cost_change += D[id_i, id_h] + else: + cost_change += Ejs[id_i] + + if cost_change < best_cost_change[2]: + best_cost_change = (id_i, id_h, cost_change) + + # If one of the swap decrease the objective, return that swap. + if best_cost_change[2] < 0: + return best_cost_change + else: + return None + + + + +def _build( floating[:, :] D, int n_clusters): + """Compute BUILD initialization, a greedy medoid initialization.""" + + cdef int[:] medoid_idxs = np.zeros(n_clusters, dtype = np.intc) + cdef int sample_size = len(D) + cdef int[:] not_medoid_idxs = np.zeros(sample_size, dtype = np.intc) + cdef int i, j, id_i, id_j + + medoid_idxs[0] = np.argmin(np.sum(D,axis=0)) + not_medoid_idxs = np.delete(not_medoid_idxs, medoid_idxs[0]) + + cdef int n_medoids_current = 1 + + cdef floating[:] Dj = D[medoid_idxs[0]].copy() + cdef floating cost_change + cdef (int, int) new_medoid = (medoid_idxs[0], 0) + cdef floating cost_change_max + + for _ in range(n_clusters -1): + cost_change_max = 0 + for i in range(sample_size - n_medoids_current): + id_i = not_medoid_idxs[i] + cost_change = 0 + for j in range(sample_size - n_medoids_current): + id_j = not_medoid_idxs[j] + cost_change += max(0, Dj[id_j] - D[id_i, id_j]) + if cost_change >= cost_change_max: + cost_change_max = cost_change + new_medoid = (id_i, i) + + + medoid_idxs[n_medoids_current] = new_medoid[0] + n_medoids_current += 1 + not_medoid_idxs = np.delete(not_medoid_idxs, new_medoid[1]) + + + for id_j in range(sample_size): + Dj[id_j] = min(Dj[id_j], D[id_j, new_medoid[0]]) + return np.array(medoid_idxs) diff --git a/sklearn_extra/cluster/tests/test_k_medoids.py b/sklearn_extra/cluster/tests/test_k_medoids.py index 919aab5a..f6854ee5 100644 --- a/sklearn_extra/cluster/tests/test_k_medoids.py +++ b/sklearn_extra/cluster/tests/test_k_medoids.py @@ -12,10 +12,46 @@ from sklearn_extra.cluster import KMedoids from sklearn.cluster import KMeans +from sklearn.datasets import make_blobs + seed = 0 X = np.random.RandomState(seed).rand(100, 5) +# test kmedoid's results +rng = np.random.RandomState(seed) +X_cc, y_cc = make_blobs( + n_samples=100, + centers=np.array([[-1, -1], [1, 1]]), + random_state=rng, + shuffle=False, +) + + +@pytest.mark.parametrize("method", ["alternate", "pam"]) +@pytest.mark.parametrize( + "init", ["random", "heuristic", "build", "k-medoids++"] +) +def test_kmedoid_results(method, init): + expected = np.hstack([np.zeros(50), np.ones(50)]) + km = KMedoids(n_clusters=2, init=init, method=method) + km.fit(X_cc) + # This test use data that are not perfectly separable so the + # accuracy is not 1. Accuracy around 0.85 + assert (np.mean(km.labels_ == expected) > 0.8) or ( + 1 - np.mean(km.labels_ == expected) > 0.8 + ) + + +def test_medoids_invalid_method(): + with pytest.raises(ValueError, match="invalid is not supported"): + KMedoids(n_clusters=1, method="invalid").fit([[0, 1], [1, 1]]) + + +def test_medoids_invalid_init(): + with pytest.raises(ValueError, match="init needs to be one of"): + KMedoids(n_clusters=1, init="invalid").fit([[0, 1], [1, 1]]) + def test_kmedoids_input_validation_and_fit_check(): rng = np.random.RandomState(seed) @@ -28,9 +64,9 @@ def test_kmedoids_input_validation_and_fit_check(): with pytest.raises(ValueError, match=msg): KMedoids(n_clusters=None).fit(X) - msg = "max_iter should be a nonnegative integer. 0 was given" + msg = "max_iter should be a nonnegative integer. -1 was given" with pytest.raises(ValueError, match=msg): - KMedoids(n_clusters=1, max_iter=0).fit(X) + KMedoids(n_clusters=1, max_iter=-1).fit(X) msg = "max_iter should be a nonnegative integer. None was given" with pytest.raises(ValueError, match=msg):