|
1 | 1 | """ miscellaneous sorting / groupby utilities """
|
2 |
| -from typing import Callable, Optional |
| 2 | +from collections import defaultdict |
| 3 | +from typing import TYPE_CHECKING, Callable, DefaultDict, Iterable, List, Optional, Tuple |
3 | 4 |
|
4 | 5 | import numpy as np
|
5 | 6 |
|
|
18 | 19 | import pandas.core.algorithms as algorithms
|
19 | 20 | from pandas.core.construction import extract_array
|
20 | 21 |
|
| 22 | +if TYPE_CHECKING: |
| 23 | + from pandas.core.indexes.base import Index # noqa:F401 |
| 24 | + |
21 | 25 | _INT64_MAX = np.iinfo(np.int64).max
|
22 | 26 |
|
23 | 27 |
|
@@ -409,7 +413,7 @@ def ensure_key_mapped(values, key: Optional[Callable], levels=None):
|
409 | 413 | levels : Optional[List], if values is a MultiIndex, list of levels to
|
410 | 414 | apply the key to.
|
411 | 415 | """
|
412 |
| - from pandas.core.indexes.api import Index |
| 416 | + from pandas.core.indexes.api import Index # noqa:F811 |
413 | 417 |
|
414 | 418 | if not key:
|
415 | 419 | return values
|
@@ -440,36 +444,21 @@ def ensure_key_mapped(values, key: Optional[Callable], levels=None):
|
440 | 444 | return result
|
441 | 445 |
|
442 | 446 |
|
443 |
| -class _KeyMapper: |
444 |
| - """ |
445 |
| - Map compressed group id -> key tuple. |
446 |
| - """ |
447 |
| - |
448 |
| - def __init__(self, comp_ids, ngroups: int, levels, labels): |
449 |
| - self.levels = levels |
450 |
| - self.labels = labels |
451 |
| - self.comp_ids = comp_ids.astype(np.int64) |
452 |
| - |
453 |
| - self.k = len(labels) |
454 |
| - self.tables = [hashtable.Int64HashTable(ngroups) for _ in range(self.k)] |
455 |
| - |
456 |
| - self._populate_tables() |
457 |
| - |
458 |
| - def _populate_tables(self): |
459 |
| - for labs, table in zip(self.labels, self.tables): |
460 |
| - table.map(self.comp_ids, labs.astype(np.int64)) |
461 |
| - |
462 |
| - def get_key(self, comp_id): |
463 |
| - return tuple( |
464 |
| - level[table.get_item(comp_id)] |
465 |
| - for table, level in zip(self.tables, self.levels) |
466 |
| - ) |
467 |
| - |
468 |
| - |
469 |
def get_flattened_iterator(comp_ids, ngroups, levels, labels):
    """
    Provide the "flattened" group keys for the multi-group setting.

    Builds a ``_KeyMapper`` from the arguments and collects its key tuple for
    every group id in ``range(ngroups)``. Despite the name, the result is a
    fully materialised list, not a lazy iterator.
    """
    key_mapper = _KeyMapper(comp_ids, ngroups, levels, labels)
    flattened = []
    for group_id in range(ngroups):
        flattened.append(key_mapper.get_key(group_id))
    return flattened
def get_flattened_list(
    comp_ids: np.ndarray,
    ngroups: int,
    levels: Iterable["Index"],
    labels: Iterable[np.ndarray],
) -> List[Tuple]:
    """
    Map compressed group id -> key tuple.

    Parameters
    ----------
    comp_ids : np.ndarray
        Compressed group ids; cast to int64 (without copying when possible).
    ngroups : int
        Number of groups; ids ``0 .. ngroups - 1`` are looked up.
    levels : iterable of Index
        One container of key values per grouping.
    labels : iterable of np.ndarray
        One array per grouping, paired positionally with ``levels``.

    Returns
    -------
    list of tuple
        One tuple per group id, in increasing id order, with one entry per
        level. Empty when ``levels``/``labels`` yield nothing or when
        ``ngroups`` is 0.
    """
    group_ids = comp_ids.astype(np.int64, copy=False)

    # Build one column of key values per level, then transpose the columns
    # into per-group tuples.
    columns = []
    for codes, level in zip(labels, levels):
        lookup = hashtable.Int64HashTable(ngroups)
        lookup.map(group_ids, codes.astype(np.int64, copy=False))
        column = []
        for group_id in range(ngroups):
            column.append(level[lookup.get_item(group_id)])
        columns.append(column)
    return list(zip(*columns))
473 | 462 |
|
474 | 463 |
|
475 | 464 | def get_indexer_dict(label_list, keys):
|
|
0 commit comments