@@ -373,7 +373,7 @@ def _get_group_levels(self, mask, obs_ids):
373
373
374
374
name_list = []
375
375
for ping , labels in zip (self .groupings , recons_labels ):
376
- labels = _check_platform_int (labels )
376
+ labels = _ensure_platform_int (labels )
377
377
name_list .append ((ping .name , ping .group_index .take (labels )))
378
378
379
379
return name_list
@@ -1327,6 +1327,8 @@ def cython_aggregate(values, group_index, ngroups, how='add'):
1327
1327
def _compress_group_index (group_index , sort = True ):
1328
1328
uniques = []
1329
1329
table = lib .Int64HashTable (len (group_index ))
1330
+
1331
+ group_index = _ensure_int64 (group_index )
1330
1332
comp_ids = table .get_labels_groupby (group_index , uniques )
1331
1333
max_group = len (uniques )
1332
1334
@@ -1356,11 +1358,16 @@ def _group_labels(values):
1356
1358
values = values .astype ('O' )
1357
1359
return lib .group_labels (values )
1358
1360
1359
- def _check_platform_int (labels ):
1361
+ def _ensure_platform_int (labels ):
1360
1362
if labels .dtype != np .int_ :
1361
1363
labels = labels .astype (np .int_ )
1362
1364
return labels
1363
1365
1366
+ def _ensure_int64 (labels ):
1367
+ if labels .dtype != np .int64 :
1368
+ labels = labels .astype (np .int64 )
1369
+ return labels
1370
+
1364
1371
def sort_group_labels (ids , labels , counts ):
1365
1372
n = len (ids )
1366
1373
0 commit comments