@@ -512,6 +512,13 @@ def steffen_interp(x, y, xi):
class NotConvergedError(Exception):
    """Exception subclass that adds no behavior of its own.

    NOTE(review): the name suggests it signals that an iterative routine
    failed to converge; no raise site is visible in this part of the
    file -- confirm against callers.
    """
514514
def update_cluster_means(data, label, C):
    """Recompute cluster centroids, leaving empty clusters where they were.

    Parameters
    ----------
    data : ndarray
        Observations, passed straight through to
        ``_vq.update_cluster_means``.
        (assumes one observation per row -- TODO confirm against callers)
    label : ndarray
        Cluster index assigned to each observation; also counted with
        ``np.bincount``, so it must hold non-negative integers.
    C : ndarray
        Current centroids.  ``C.shape[0]`` is taken as the number of
        clusters, and its rows are the fallback positions for clusters
        that currently have no members.

    Returns
    -------
    new_C : ndarray
        Updated centroids.  Rows flagged as having no members are copied
        from ``C``.
    counts : ndarray
        Number of observations per cluster, padded to at least
        ``C.shape[0]`` entries.
    """
    # Take the cluster count from the code book rather than hard-coding it,
    # so the centroid update, the empty-cluster mask and the member counts
    # all agree on the same number of clusters.
    n_clusters = C.shape[0]
    new_C, has_members = _vq.update_cluster_means(data, label, n_clusters)
    if not has_members.all():
        # Set the empty clusters to their previous positions
        empty = ~has_members
        new_C[empty] = C[empty]
    return new_C, np.bincount(label, minlength=n_clusters)
521+
515522def kmeans2 (data ):
516523 # n points in p dimensional space
517524 n = data .shape [0 ]
@@ -529,18 +536,20 @@ def kmeans2(data):
529536 # second cluster
530537 D = cdist (C [:1 ,:], data , metric = 'sqeuclidean' ).min (axis = 0 )
531538 probs = D / D .sum ()
532- cumprobs = probs .cumsum ()
533- r = np .random .rand ()
534- C [1 , :] = data [np .searchsorted (cumprobs , r )]
539+ edges = np .minimum (np .concatenate (([0 ], np .cumsum (probs ))), 1 ) # protect against accumulated round-off
540+ edges [np .isnan (edges )] = 1. # protect against equidistant points
541+ edges [- 1 ] = 1 # ensure upper edge is exactly 1
542+ ps = np .random .rand ()
543+ C [1 , :] = data [(edges [:- 1 ] <= ps ) & (ps < edges [1 :])][0 ]
544+
535545
536546 # Compute the distance from every point to each cluster centroid and the
537547 # initial assignment of points to clusters
538548 D = cdist (C , data , metric = 'sqeuclidean' )
539549 # Compute the nearest neighbor for each obs using the current code book
540550 label = vq (data , C )[0 ]
541551 # Update the code book by computing centroids
542- C = _vq .update_cluster_means (data , label , 2 )[0 ]
543- m = np .bincount (label )
552+ C , m = update_cluster_means (data , label , C )
544553
545554 ## Begin phase one: batch reassignments
546555 #-----------------------------------------------------
@@ -574,8 +583,7 @@ def kmeans2(data):
574583 label [lonely ] = i
575584
576585 # Update clusters from which points are taken
577- C = _vq .update_cluster_means (data , label , 2 )[0 ]
578- m = np .bincount (label )
586+ C , m = update_cluster_means (data , label , C )
579587 D = cdist (C , data , metric = 'sqeuclidean' )
580588
581589 # Compute the total sum of distances for the current configuration.
@@ -605,8 +613,7 @@ def kmeans2(data):
605613 break
606614 label = new_label
607615 # update centers
608- C = _vq .update_cluster_means (data , label , 2 )[0 ]
609- m = np .bincount (label )
616+ C , m = update_cluster_means (data , label , C )
610617
611618
612619 #------------------------------------------------------------------
0 commit comments