Skip to content

Commit 4d0612e

Browse files
tptopper-123
authored andcommitted
make Int8/16/32Engine work with Int64HashTable
1 parent 91ee55d commit 4d0612e

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

pandas/_libs/index.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import cython
66
import numpy as np
77
cimport numpy as cnp
88
from numpy cimport (ndarray, float64_t, int32_t,
9-
int64_t, uint8_t, uint64_t, intp_t,
9+
int8_t, int16_t, int32_t, int64_t,
10+
uint8_t, uint64_t,
11+
intp_t,
1012
# Note: NPY_DATETIME, NPY_TIMEDELTA are only available
1113
# for cimport in cython>=0.27.3
1214
NPY_DATETIME, NPY_TIMEDELTA)
@@ -242,6 +244,8 @@ cdef class IndexEngine:
242244
if not self.is_mapping_populated:
243245

244246
values = self._get_index_values()
247+
if values.dtype in {'int8', 'int16', 'int32'}:
248+
values = algos.ensure_int64(values)
245249
self.mapping = self._make_hash_table(len(values))
246250
self._call_map_locations(values)
247251

pandas/_libs/index_class_helper.pxi.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ cdef class {{name}}Engine(IndexEngine):
4040
cdef _make_hash_table(self, n):
4141
{{if name == 'Object'}}
4242
return _hash.PyObjectHashTable(n)
43+
{{elif name in {'Int8', 'Int16', 'Int32'} }}
44+
return _hash.Int64HashTable(n)
4345
{{else}}
4446
return _hash.{{name}}HashTable(n)
4547
{{endif}}

pandas/core/indexes/category.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -448,18 +448,12 @@ def get_loc(self, key, method=None):
448448
>>> non_monotonic_index.get_loc('b')
449449
array([False, True, False, True], dtype=bool)
450450
"""
451-
codes = self.categories.get_loc(key)
452-
if (codes == -1):
453-
raise KeyError(key)
454-
455-
if self.is_monotonic_increasing and not self.is_unique:
456-
if codes not in self._engine:
457-
raise KeyError(key)
458-
codes = self.codes.dtype.type(codes)
459-
lhs = self.codes.searchsorted(codes, side='left')
460-
rhs = self.codes.searchsorted(codes, side='right')
461-
return slice(lhs, rhs)
462-
return self._engine.get_loc(codes)
451+
code = self.categories.get_loc(key)
452+
453+
# dtype must be same as dtype for self.codes else searchsorted is slow
454+
code = self.codes.dtype.type(code)
455+
456+
return self._engine.get_loc(code)
463457

464458
def get_value(self, series, key):
465459
"""

0 commit comments

Comments
 (0)