Skip to content

Commit 255bfab

Browse files
committed
Consistent CategoricalDtype use in Categorical init
Get a valid instance of `CategoricalDtype` as early as possible, and use that throughout.
1 parent bb25f3b commit 255bfab

File tree

4 files changed

+112
-15
lines changed

4 files changed

+112
-15
lines changed

pandas/core/categorical.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,21 @@ class Categorical(PandasObject):
235235
def __init__(self, values, categories=None, ordered=None, dtype=None,
236236
fastpath=False):
237237

238+
# Ways of specifying the dtype (in priority order)
239+
# 1. dtype is a CategoricalDtype
240+
# a.) with known categories, use dtype.categories
241+
# b.) else with Categorical values, use values.dtype
242+
# c.) else, infer from values
243+
# d.) specifying dtype=CategoricalDtype and categories is an error
244+
# 2. dtype is a string 'category'
245+
# a.) use categories, ordered
246+
# b.) use values.dtype
247+
# c.) infer from values
248+
# 3. dtype is None
249+
# a.) use categories, ordered
250+
# b.) use values.dtype
251+
# c.) infer from values
252+
238253
if dtype is not None:
239254
if isinstance(dtype, compat.string_types):
240255
if dtype == 'category':
@@ -248,20 +263,24 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
248263
categories = dtype.categories
249264
ordered = dtype.ordered
250265

251-
if ordered is None:
252-
ordered = False
266+
elif is_categorical(values):
267+
dtype = values.dtype._from_categorical_dtype(values.dtype,
268+
categories, ordered)
269+
else:
270+
dtype = CategoricalDtype(categories, ordered)
271+
272+
# At this point, dtype is always a CategoricalDtype
273+
# if dtype.categories is None, we are inferring
253274

254275
if fastpath:
255-
if dtype is None:
256-
dtype = CategoricalDtype(categories, ordered)
257276
self._codes = coerce_indexer_dtype(values, categories)
258277
self._dtype = dtype
259278
return
260279

261280
# sanitize input
262281
if is_categorical_dtype(values):
263282

264-
# we are either a Series, CategoricalIndex
283+
# we are either a Series or a CategoricalIndex
265284
if isinstance(values, (ABCSeries, ABCCategoricalIndex)):
266285
values = values._values
267286

@@ -272,6 +291,7 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
272291
values = values.get_values()
273292

274293
elif isinstance(values, (ABCIndexClass, ABCSeries)):
294+
# we'll do inference later
275295
pass
276296

277297
else:
@@ -289,12 +309,12 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
289309
# "object" dtype to prevent this. In the end objects will be
290310
# cast to int/... in the category assignment step.
291311
if len(values) == 0 or isna(values).any():
292-
dtype = 'object'
312+
sanitize_dtype = 'object'
293313
else:
294-
dtype = None
295-
values = _sanitize_array(values, None, dtype=dtype)
314+
sanitize_dtype = None
315+
values = _sanitize_array(values, None, dtype=sanitize_dtype)
296316

297-
if categories is None:
317+
if dtype.categories is None:
298318
try:
299319
codes, categories = factorize(values, sort=True)
300320
except TypeError:
@@ -311,7 +331,8 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
311331
raise NotImplementedError("> 1 ndim Categorical are not "
312332
"supported at this time")
313333

314-
if dtype is None or isinstance(dtype, str):
334+
if dtype.categories is None:
335+
# we're inferring from values
315336
dtype = CategoricalDtype(categories, ordered)
316337

317338
else:
@@ -322,11 +343,6 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
322343
# - the new one, where each value is also in the categories array
323344
# (or np.nan)
324345

325-
# make sure that we always have the same type here, no matter what
326-
# we get passed in
327-
if dtype is None or isinstance(dtype, str):
328-
dtype = CategoricalDtype(categories, ordered)
329-
330346
codes = _get_codes_for_values(values, dtype.categories)
331347

332348
# TODO: check for old style usage. These warnings should be removed

pandas/core/dtypes/dtypes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,22 @@ def _from_fastpath(cls, categories=None, ordered=False):
160160
self._finalize(categories, ordered, fastpath=True)
161161
return self
162162

163+
@classmethod
164+
def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
165+
if categories is ordered is None:
166+
return dtype
167+
if categories is None:
168+
categories = dtype.categories
169+
if ordered is None:
170+
ordered = dtype.ordered
171+
return cls(categories, ordered)
172+
163173
def _finalize(self, categories, ordered, fastpath=False):
164174
from pandas.core.indexes.base import Index
165175

176+
if ordered is None:
177+
ordered = False
178+
166179
if categories is not None:
167180
categories = Index(categories, tupleize_cols=False)
168181
# validation

pandas/tests/dtypes/test_dtypes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,30 @@ def test_mixed(self):
622622
a = CategoricalDtype(['a', 'b', 1, 2])
623623
b = CategoricalDtype(['a', 'b', '1', '2'])
624624
assert hash(a) != hash(b)
625+
626+
def test_from_categorical_dtype_identity(self):
627+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
628+
# Identity test for no changes
629+
c2 = CategoricalDtype._from_categorical_dtype(c1)
630+
assert c2 is c1
631+
632+
def test_from_categorical_dtype_categories(self):
633+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
634+
# override categories
635+
result = CategoricalDtype._from_categorical_dtype(
636+
c1, categories=[2, 3])
637+
assert result == CategoricalDtype([2, 3], ordered=True)
638+
639+
def test_from_categorical_dtype_ordered(self):
640+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
641+
# override ordered
642+
result = CategoricalDtype._from_categorical_dtype(
643+
c1, ordered=False)
644+
assert result == CategoricalDtype([1, 2, 3], ordered=False)
645+
646+
def test_from_categorical_dtype_both(self):
647+
c1 = Categorical([1, 2], categories=[1, 2, 3], ordered=True)
648+
# override both categories and ordered
649+
result = CategoricalDtype._from_categorical_dtype(
650+
c1, categories=[1, 2], ordered=False)
651+
assert result == CategoricalDtype([1, 2], ordered=False)

pandas/tests/test_categorical.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,37 @@ def test_constructor_str_unknown(self):
488488
with tm.assert_raises_regex(ValueError, "Unknown `dtype`"):
489489
Categorical([1, 2], dtype="foo")
490490

491+
def test_constructor_from_categorical_with_dtype(self):
492+
dtype = CategoricalDtype(['a', 'b', 'c'], ordered=True)
493+
values = Categorical(['a', 'b', 'd'])
494+
result = Categorical(values, dtype=dtype)
495+
# We use dtype.categories, not values.categories
496+
expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'c'],
497+
ordered=True)
498+
tm.assert_categorical_equal(result, expected)
499+
500+
def test_constructor_from_categorical_with_unknown_dtype(self):
501+
dtype = CategoricalDtype(None, ordered=True)
502+
values = Categorical(['a', 'b', 'd'])
503+
result = Categorical(values, dtype=dtype)
504+
# We use values.categories, not dtype.categories
505+
expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'd'],
506+
ordered=True)
507+
tm.assert_categorical_equal(result, expected)
508+
509+
def test_contructor_from_categorical_string(self):
510+
values = Categorical(['a', 'b', 'd'])
511+
# use categories, ordered
512+
result = Categorical(values, categories=['a', 'b', 'c'], ordered=True,
513+
dtype='category')
514+
expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'c'],
515+
ordered=True)
516+
tm.assert_categorical_equal(result, expected)
517+
518+
# No string
519+
result = Categorical(values, categories=['a', 'b', 'c'], ordered=True)
520+
tm.assert_categorical_equal(result, expected)
521+
491522
def test_from_codes(self):
492523

493524
# too few categories
@@ -932,6 +963,16 @@ def test_set_dtype_nans(self):
932963
tm.assert_numpy_array_equal(result.codes, np.array([0, -1, -1],
933964
dtype='int8'))
934965

966+
def test_set_categories(self):
967+
cat = Categorical(['a', 'b', 'c'], categories=['a', 'b', 'c', 'd'])
968+
result = cat._set_categories(['a', 'b', 'c', 'd', 'e'])
969+
expected = Categorical(['a', 'b', 'c'], categories=list('abcde'))
970+
tm.assert_categorical_equal(result, expected)
971+
972+
# fastpath
973+
result = cat._set_categories(['a', 'b', 'c', 'd', 'e'], fastpath=True)
974+
tm.assert_categorical_equal(result, expected)
975+
935976
@pytest.mark.parametrize('values, categories, new_categories', [
936977
# No NaNs, same cats, same order
937978
(['a', 'b', 'a'], ['a', 'b'], ['a', 'b'],),

0 commit comments

Comments
 (0)