Skip to content

Commit 4efb0fe

Browse files
committed
kmeans2 implementation: now matches MATLAB behavior when the input consists of two points that are exactly the same (importantly, it no longer crashes)
1 parent 09e58ba commit 4efb0fe

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

src/I2MC/I2MC.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,13 @@ def steffen_interp(x, y, xi):
512512
class NotConvergedError(Exception):
513513
pass
514514

515+
def update_cluster_means(data, label, C):
    """Recompute cluster centroids, keeping the old centroid of empty clusters.

    Parameters
    ----------
    data : ndarray
        Observations, forwarded unchanged to ``_vq.update_cluster_means``.
    label : ndarray of int
        Cluster index assigned to each observation.
    C : ndarray
        Current centroids, one row per cluster. Its number of rows defines
        the number of clusters.

    Returns
    -------
    new_C : ndarray
        Updated centroids. Rows belonging to clusters that received no
        observations are copied from ``C`` instead of being recomputed.
    m : ndarray of int
        Number of observations assigned to each cluster; always has
        ``C.shape[0]`` entries, so empty clusters are reported as 0.
    """
    # Derive the cluster count from the codebook itself so that the centroid
    # update and the member counts below always agree on the number of
    # clusters (previously the update call hard-coded 2).
    n_clusters = C.shape[0]
    new_C, has_members = _vq.update_cluster_means(data, label, n_clusters)
    if not has_members.all():
        # Set the empty clusters to their previous positions
        new_C[~has_members] = C[~has_members]
    return new_C, np.bincount(label, minlength=n_clusters)
521+
515522
def kmeans2(data):
516523
# n points in p dimensional space
517524
n = data.shape[0]
@@ -529,18 +536,20 @@ def kmeans2(data):
529536
# second cluster
530537
D = cdist(C[:1,:], data, metric='sqeuclidean').min(axis=0)
531538
probs = D/D.sum()
532-
cumprobs = probs.cumsum()
533-
r = np.random.rand()
534-
C[1, :] = data[np.searchsorted(cumprobs, r)]
539+
edges = np.minimum(np.concatenate(([0], np.cumsum(probs))), 1) # protect against accumulated round-off
540+
edges[np.isnan(edges)] = 1. # protect against equidistant points
541+
edges[-1] = 1 # ensure upper edge is exactly 1
542+
ps = np.random.rand()
543+
C[1, :] = data[(edges[:-1] <= ps) & (ps < edges[1:])][0]
544+
535545

536546
# Compute the distance from every point to each cluster centroid and the
537547
# initial assignment of points to clusters
538548
D = cdist(C, data, metric='sqeuclidean')
539549
# Compute the nearest neighbor for each obs using the current code book
540550
label = vq(data, C)[0]
541551
# Update the code book by computing centroids
542-
C = _vq.update_cluster_means(data, label, 2)[0]
543-
m = np.bincount(label)
552+
C, m = update_cluster_means(data, label, C)
544553

545554
## Begin phase one: batch reassignments
546555
#-----------------------------------------------------
@@ -574,8 +583,7 @@ def kmeans2(data):
574583
label[lonely] = i
575584

576585
# Update clusters from which points are taken
577-
C = _vq.update_cluster_means(data, label, 2)[0]
578-
m = np.bincount(label)
586+
C, m = update_cluster_means(data, label, C)
579587
D = cdist(C, data, metric='sqeuclidean')
580588

581589
# Compute the total sum of distances for the current configuration.
@@ -605,8 +613,7 @@ def kmeans2(data):
605613
break
606614
label = new_label
607615
# update centers
608-
C = _vq.update_cluster_means(data, label, 2)[0]
609-
m = np.bincount(label)
616+
C, m = update_cluster_means(data, label, C)
610617

611618

612619
#------------------------------------------------------------------

0 commit comments

Comments
 (0)