ENH/API: ExtensionArray.factorize #20361
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -248,6 +248,41 @@ def unique(self): | |
uniques = unique(self.astype(object)) | ||
return self._constructor_from_sequence(uniques) | ||
|
||
def factorize(self, na_sentinel=-1): | ||
you need a sort=False arg here
Do we want / need that? It complicates the implementation a bit. Any idea what it's actually used for? |
||
"""Encode the extension array as an enumerated type. | ||
Ah, need to merge this one too maybe. Will have to check on import order... |
||
|
||
Parameters | ||
---------- | ||
na_sentinel : int, default -1 | ||
Value to use in the `labels` array to indicate missing values. | ||
|
||
Returns | ||
------- | ||
labels : ndarray | ||
An integer NumPy array that's an indexer into the original | ||
ExtensionArray | ||
uniques : ExtensionArray | ||
An ExtensionArray containing the unique values of `self`. | ||
|
||
See Also | ||
-------- | ||
pandas.factorize : top-level factorize method that dispatches here. | ||
|
||
Notes | ||
----- | ||
:meth:`pandas.factorize` offers a `sort` keyword as well. | ||
""" | ||
from pandas.core.algorithms import _factorize_array | ||
We're OK with using this private API here? Because extension authors might want to copy-paste this method and change the
Quite similar to this. https://github.com/ContinuumIO/cyberpandas/blob/468644bcbdc9320a1a33b0df393d4fa4bef57dd7/cyberpandas/base.py#L72 In that case I think going to object dtypes is unavoidable, since there's no easy way to factorize a 2-D array, and I didn't want to write a new hashtable implementation :)
w.r.t. using We might consider making it public / semi-public (keep the |
||
|
||
mask = self.isna() | ||
arr = self.astype(object) | ||
arr[mask] = np.nan | ||
|
||
|
||
labels, uniques = _factorize_array(arr, check_nulls=True, | ||
|
||
na_sentinel=na_sentinel) | ||
uniques = self._constructor_from_sequence(uniques) | ||
return labels, uniques | ||
|
||
# ------------------------------------------------------------------------ | ||
# Indexing methods | ||
# ------------------------------------------------------------------------ | ||
|
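For readers following the `labels` / `uniques` contract described in the docstring above, here is a minimal pure-Python sketch of the semantics. It is not the pandas implementation (which goes through `_factorize_array` and a hashtable); `factorize_sketch` and its `is_missing` argument are hypothetical names used only for illustration.

```python
def factorize_sketch(values, is_missing, na_sentinel=-1):
    """Hypothetical stand-in illustrating the factorize contract.

    Returns (labels, uniques): uniques holds each non-missing value once,
    in order of first appearance, and labels[i] is the position of
    values[i] in uniques, or na_sentinel when values[i] is missing.
    """
    code_by_value = {}
    labels = []
    uniques = []
    for value in values:
        if is_missing(value):
            # Missing values never enter uniques; they only get the sentinel.
            labels.append(na_sentinel)
            continue
        code = code_by_value.get(value)
        if code is None:
            code = len(uniques)
            code_by_value[value] = code
            uniques.append(value)
        labels.append(code)
    return labels, uniques


# Same shape as the data_for_grouping fixture: [B, B, NA, NA, A, A, B, C]
values = ["b", "b", None, None, "a", "a", "b", "c"]
labels, uniques = factorize_sketch(values, lambda v: v is None)
alt_labels, _ = factorize_sketch(values, lambda v: v is None, na_sentinel=-2)
```

With the default sentinel this yields labels `[0, 0, -1, -1, 1, 1, 0, 2]` and uniques `["b", "a", "c"]`, which mirrors the pattern the new `test_factorize` expects.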
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
import numpy as np | ||
|
||
import pandas as pd | ||
import pandas.util.testing as tm | ||
|
||
from .base import BaseExtensionTests | ||
|
||
|
@@ -42,3 +43,22 @@ def test_unique(self, data, box, method): | |
assert len(result) == 1 | ||
assert isinstance(result, type(data)) | ||
assert result[0] == duplicated[0] | ||
|
||
@pytest.mark.parametrize('na_sentinel', [-1, -2]) | ||
def test_factorize(self, data_for_grouping, na_sentinel): | ||
labels, uniques = pd.factorize(data_for_grouping, | ||
na_sentinel=na_sentinel) | ||
expected_labels = np.array([0, 0, na_sentinel, | ||
na_sentinel, 1, 1, 0, 2], | ||
dtype='int64') | ||
expected_uniques = data_for_grouping.take([0, 4, 7]) | ||
|
||
tm.assert_numpy_array_equal(labels, expected_labels) | ||
self.assert_extension_array_equal(uniques, expected_uniques) | ||
|
||
def test_factorize_equivalence(self, data_for_grouping): | ||
l1, u1 = pd.factorize(data_for_grouping) | ||
l2, u2 = pd.factorize(data_for_grouping) | ||
|
||
|
||
tm.assert_numpy_array_equal(l1, l2) | ||
self.assert_extension_array_equal(u1, u2) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
|
||
import numpy as np | ||
|
||
import pandas as pd | ||
from pandas.core.dtypes.base import ExtensionDtype | ||
from pandas.core.arrays import ExtensionArray | ||
|
||
|
@@ -104,6 +105,21 @@ def _concat_same_type(cls, to_concat): | |
data = list(itertools.chain.from_iterable([x.data for x in to_concat])) | ||
return cls(data) | ||
|
||
def factorize(self, na_sentinel=-1): | ||
frozen = tuple(tuple(x.items()) for x in self) | ||
labels, uniques = pd.factorize(frozen) | ||
|
||
# fixup NA | ||
|
||
if self.isna().any(): | ||
na_code = labels[self.isna()][0] | ||
|
||
labels[labels == na_code] = na_sentinel | ||
labels[labels > na_code] -= 1 | ||
|
||
uniques = JSONArray([collections.UserDict(x) | ||
for x in uniques if x != ()]) | ||
return labels, uniques | ||
|
||
|
||
def make_data(): | ||
# TODO: Use a regular dict. See _NDFrameIndexer._setitem_with_indexer | ||
|
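The "fixup NA" step in `JSONArray.factorize` above can be read in isolation: the raw factorization assigns the missing value an ordinary code, which is then replaced by the sentinel, and every larger code is shifted down by one to close the gap. A small sketch of that remapping on plain lists follows; `fix_na_codes` is a hypothetical helper, not part of pandas or this PR, and it does the remapping in a single pass rather than with two boolean-mask assignments.

```python
def fix_na_codes(labels, na_mask, na_sentinel=-1):
    """Hypothetical helper: remap the code given to missing values.

    labels  : raw integer codes, where the missing value received a normal code
    na_mask : booleans of the same length, True where the value is missing
    """
    labels = list(labels)
    if not any(na_mask):
        # Nothing is missing, so the raw codes are already final.
        return labels
    # The code that the raw factorization assigned to the missing value.
    na_code = next(code for code, is_na in zip(labels, na_mask) if is_na)
    fixed = []
    for code in labels:
        if code == na_code:
            fixed.append(na_sentinel)
        elif code > na_code:
            # Shift down so the remaining codes stay contiguous.
            fixed.append(code - 1)
        else:
            fixed.append(code)
    return fixed


raw_codes = [0, 0, 1, 1, 2, 2, 0, 3]
na_mask = [False, False, True, True, False, False, False, False]
fixed_codes = fix_na_codes(raw_codes, na_mask)
```

Here `fixed_codes` is `[0, 0, -1, -1, 1, 1, 0, 2]`, so the codes index directly into a uniques array from which the missing entry has been dropped.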
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
import numpy as np | ||
|
||
import pandas as pd | ||
from pandas.core.arrays.base import ExtensionArray | ||
from pandas.core.dtypes.missing import array_equivalent | ||
from pandas.core.dtypes.common import ( | ||
is_datetimelike_v_numeric, | ||
|
@@ -1083,6 +1084,32 @@ def _raise(left, right, err_msg): | |
return True | ||
|
||
|
||
def assert_extension_array_equal(left, right): | ||
"""Check that left and right ExtensionArrays are equal. | ||
|
||
Parameters | ||
---------- | ||
left, right : ExtensionArray | ||
The two arrays to compare | ||
|
||
Notes | ||
----- | ||
Missing values are checked separately from valid values. | ||
A mask of missing values is computed for each and checked to match. | ||
The remaining all-valid values are cast to object dtype and checked. | ||
""" | ||
Why was this needed only now? (wasn't the missing values the reason you added
I'll see if any old tests can make use of this.
I don't see any other cases where we could apply The closest would be the |
||
assert isinstance(left, ExtensionArray) | ||
assert left.dtype == right.dtype | ||
left_na = left.isna() | ||
right_na = right.isna() | ||
assert_numpy_array_equal(left_na, right_na) | ||
|
||
left_valid = left[~left_na].astype(object) | ||
right_valid = right[~right_na].astype(object) | ||
|
||
assert_numpy_array_equal(left_valid, right_valid) | ||
|
||
|
||
# This could be refactored to use the NDFrame.equals method | ||
def assert_series_equal(left, right, check_dtype=True, | ||
check_index_type='equiv', | ||
|
Note: this was a bug in #19938 where I forgot to pass this through. It's covered by our extension tests.
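As a reading aid for the `assert_extension_array_equal` helper in the last diff, the sketch below shows the same two-step comparison on plain Python lists: first the missing-value masks must match, then the remaining valid values must match. One plausible motivation (an assumption, not stated in the thread) is that NaN-like values do not compare equal to themselves, so a naive element-wise comparison would fail. `assert_equal_with_missing` and `is_nan` are hypothetical names.

```python
import math


def assert_equal_with_missing(left, right, is_missing):
    """Hypothetical analogue of the helper: compare masks, then valid values."""
    left_na = [is_missing(v) for v in left]
    right_na = [is_missing(v) for v in right]
    assert left_na == right_na, "missing-value masks differ"

    # Only the non-missing values are compared by equality.
    left_valid = [v for v, na in zip(left, left_na) if not na]
    right_valid = [v for v, na in zip(right, right_na) if not na]
    assert left_valid == right_valid, "valid values differ"


def is_nan(value):
    return isinstance(value, float) and math.isnan(value)


left = [1.0, float("nan"), 3.0]
right = [1.0, float("nan"), 3.0]

# Distinct NaN objects never compare equal, so plain list equality is False.
naive_equal = left == right

# The mask-then-values comparison treats the two lists as equal.
assert_equal_with_missing(left, right, is_nan)
```

If the missing values sit at different positions, the first assertion fails before any values are compared, which keeps the error message pointed at the actual mismatch.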