-
-
Notifications
You must be signed in to change notification settings - Fork 19k
Parametrized NA sentinel for factorize #20473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
872c24a
3c18428
703ab8a
ab32e0f
62fa538
28fad50
8580754
cf14ee1
8141131
a23d451
b25f3d4
dfcda85
eaff342
c05c807
e786253
465d458
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -250,13 +250,13 @@ cdef class HashTable: | |
|
||
{{py: | ||
|
||
# name, dtype, null_condition, float_group | ||
dtypes = [('Float64', 'float64', 'val != val', True), | ||
('UInt64', 'uint64', 'False', False), | ||
('Int64', 'int64', 'val == iNaT', False)] | ||
# name, dtype, null_condition, float_group, default_na_value | ||
dtypes = [('Float64', 'float64', 'val != val', True, 'nan'), | ||
('UInt64', 'uint64', 'False', False, 0), | ||
('Int64', 'int64', 'val == iNaT', False, 'iNaT')] | ||
|
||
def get_dispatch(dtypes): | ||
for (name, dtype, null_condition, float_group) in dtypes: | ||
for (name, dtype, null_condition, float_group, default_na_value) in dtypes: | ||
unique_template = """\ | ||
cdef: | ||
Py_ssize_t i, n = len(values) | ||
|
@@ -300,16 +300,19 @@ def get_dispatch(dtypes): | |
|
||
unique_template = unique_template.format(name=name, dtype=dtype, null_condition=null_condition, float_group=float_group) | ||
|
||
yield (name, dtype, null_condition, float_group, unique_template) | ||
yield (name, dtype, null_condition, float_group, default_na_value, unique_template) | ||
}} | ||
|
||
|
||
{{for name, dtype, null_condition, float_group, unique_template in get_dispatch(dtypes)}} | ||
{{for name, dtype, null_condition, float_group, default_na_value, unique_template in get_dispatch(dtypes)}} | ||
|
||
cdef class {{name}}HashTable(HashTable): | ||
|
||
def __cinit__(self, size_hint=1): | ||
def __cinit__(self, size_hint=1, {{dtype}}_t na_value={{default_na_value}}, | ||
bint use_na_value=False): | ||
|
||
self.table = kh_init_{{dtype}}() | ||
self.na_value = na_value | ||
self.use_na_value = use_na_value | ||
|
||
if size_hint is not None: | ||
kh_resize_{{dtype}}(self.table, size_hint) | ||
|
||
|
@@ -414,18 +417,22 @@ cdef class {{name}}HashTable(HashTable): | |
int64_t[:] labels | ||
Py_ssize_t idx, count = count_prior | ||
int ret = 0 | ||
{{dtype}}_t val | ||
{{dtype}}_t val, na_value | ||
khiter_t k | ||
{{name}}VectorData *ud | ||
bint use_na_value | ||
|
||
labels = np.empty(n, dtype=np.int64) | ||
ud = uniques.data | ||
na_value = self.na_value | ||
use_na_value = self.use_na_value | ||
|
||
with nogil: | ||
for i in range(n): | ||
val = values[i] | ||
|
||
if check_null and {{null_condition}}: | ||
if ((check_null and {{null_condition}}) or | ||
|
||
(use_na_value and val == na_value)): | ||
labels[i] = na_sentinel | ||
continue | ||
|
||
|
@@ -519,8 +526,11 @@ cdef class StringHashTable(HashTable): | |
# or a sentinel np.nan / None missing value | ||
na_string_sentinel = '__nan__' | ||
|
||
def __init__(self, int size_hint=1): | ||
def __init__(self, int size_hint=1, object na_value=na_string_sentinel, | ||
bint use_na_value=False): | ||
self.table = kh_init_str() | ||
self.na_value = na_value | ||
self.use_na_value = use_na_value | ||
if size_hint is not None: | ||
kh_resize_str(self.table, size_hint) | ||
|
||
|
@@ -706,18 +716,23 @@ cdef class StringHashTable(HashTable): | |
char *v | ||
char **vecs | ||
khiter_t k | ||
bint use_na_value | ||
|
||
# these by-definition *must* be strings | ||
labels = np.zeros(n, dtype=np.int64) | ||
uindexer = np.empty(n, dtype=np.int64) | ||
|
||
na_value = self.na_value | ||
use_na_value = self.use_na_value | ||
|
||
# pre-filter out missing | ||
# and assign pointers | ||
vecs = <char **> malloc(n * sizeof(char *)) | ||
for i in range(n): | ||
val = values[i] | ||
|
||
if PyUnicode_Check(val) or PyString_Check(val): | ||
if ((PyUnicode_Check(val) or PyString_Check(val)) and | ||
not (use_na_value and val == na_value)): | ||
v = util.get_c_string(val) | ||
vecs[i] = v | ||
else: | ||
|
@@ -753,8 +768,11 @@ na_sentinel = object | |
|
||
cdef class PyObjectHashTable(HashTable): | ||
|
||
def __init__(self, size_hint=1): | ||
def __init__(self, size_hint=1, object na_value=na_sentinel, | ||
bint use_na_value=False): | ||
self.table = kh_init_pymap() | ||
self.na_value = na_value | ||
self.use_na_value = use_na_value | ||
kh_resize_pymap(self.table, size_hint) | ||
|
||
def __dealloc__(self): | ||
|
@@ -876,14 +894,18 @@ cdef class PyObjectHashTable(HashTable): | |
int ret = 0 | ||
object val | ||
khiter_t k | ||
bint use_na_value | ||
|
||
labels = np.empty(n, dtype=np.int64) | ||
na_value = self.na_value | ||
use_na_value = self.use_na_value | ||
|
||
for i in range(n): | ||
val = values[i] | ||
hash(val) | ||
|
||
if check_null and val != val or val is None: | ||
if ((check_null and val != val or val is None) or | ||
(use_na_value and val == na_value)): | ||
labels[i] = na_sentinel | ||
continue | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -435,7 +435,8 @@ def isin(comps, values): | |
return f(comps, values) | ||
|
||
|
||
def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None): | ||
def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None, | ||
na_value=None): | ||
"""Factorize an array-like to labels and uniques. | ||
|
||
This doesn't do any coercion of types or unboxing before factorization. | ||
|
@@ -455,7 +456,13 @@ def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None): | |
""" | ||
(hash_klass, vec_klass), values = _get_data_algo(values, _hashtables) | ||
|
||
table = hash_klass(size_hint or len(values)) | ||
use_na_value = na_value is not None | ||
kwargs = dict(use_na_value=use_na_value) | ||
|
||
if use_na_value: | ||
kwargs['na_value'] = na_value | ||
|
||
table = hash_klass(size_hint or len(values), **kwargs) | ||
uniques = vec_klass() | ||
labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls) | ||
|
||
|
@@ -465,7 +472,8 @@ def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None): | |
|
||
|
||
@deprecate_kwarg(old_arg_name='order', new_arg_name=None) | ||
def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None): | ||
def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None, | ||
na_value=None): | ||
""" | ||
Encode input values as an enumerated type or categorical variable | ||
|
||
|
@@ -479,6 +487,8 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None): | |
na_sentinel : int, default -1 | ||
Value to mark "not found" | ||
size_hint : hint to the hashtable sizer | ||
na_value : object, optional | ||
A value in `values` to consider missing. | ||
|
||
Returns | ||
------- | ||
|
@@ -509,9 +519,11 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None): | |
else: | ||
values, dtype, _ = _ensure_data(values) | ||
check_nulls = not is_integer_dtype(original) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you know why |
||
labels, uniques = _factorize_array(values, check_nulls, | ||
na_sentinel=na_sentinel, | ||
size_hint=size_hint) | ||
size_hint=size_hint, | ||
na_value=na_value) | ||
|
||
if sort and len(uniques) > 0: | ||
from pandas.core.sorting import safe_sort | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realize you didn't set this, but is the
uint64
track for bool hashing? If so, any reason why we don't useuint8
to save space?