Skip to content

Commit 8e22443

Browse files
thomasjpfan, glemaitre, and ogrisel
committed
FIX Fixes encoders for string dtypes (scikit-learn#15763)
Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 954b9bc commit 8e22443

File tree

4 files changed

+33
-2
lines changed

4 files changed

+33
-2
lines changed

doc/whats_new/v0.24.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,10 @@ Changelog
662662
encoded as all zeros. :pr:`14982` by
663663
:user:`Kevin Winata <kwinata>`.
664664

665+
- |Fix| Fix incorrect encoding when using unicode string dtypes in
666+
:class:`preprocessing.OneHotEncoder` and
667+
:class:`preprocessing.OrdinalEncoder`. :pr:`15763` by `Thomas Fan`_.
668+
665669
:mod:`sklearn.svm`
666670
..................
667671

sklearn/preprocessing/_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _fit(self, X, handle_unknown='error', force_all_finite=True):
9090
cats = _unique(Xi)
9191
else:
9292
cats = np.array(self.categories[i], dtype=Xi.dtype)
93-
if Xi.dtype != object:
93+
if Xi.dtype.kind not in 'OU':
9494
sorted_cats = np.sort(cats)
9595
error_msg = ("Unsorted categories are not "
9696
"supported for numerical categories")

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,33 @@ def test_encoders_has_categorical_tags(Encoder):
830830
assert 'categorical' in Encoder()._get_tags()['X_types']
831831

832832

833+
@pytest.mark.parametrize('input_dtype', ['O', 'U'])
834+
@pytest.mark.parametrize('category_dtype', ['O', 'U'])
835+
@pytest.mark.parametrize('array_type', ['list', 'array', 'dataframe'])
836+
def test_encoders_unicode_categories(input_dtype, category_dtype, array_type):
837+
"""Check that encoding works with string and object dtypes.
838+
Non-regression test for:
839+
https://github.com/scikit-learn/scikit-learn/issues/15616
840+
https://github.com/scikit-learn/scikit-learn/issues/15726
841+
"""
842+
843+
X = np.array([['b'], ['a']], dtype=input_dtype)
844+
categories = [np.array(['b', 'a'], dtype=category_dtype)]
845+
ohe = OneHotEncoder(categories=categories, sparse=False).fit(X)
846+
847+
X_test = _convert_container([['a'], ['a'], ['b'], ['a']], array_type)
848+
X_trans = ohe.transform(X_test)
849+
850+
expected = np.array([[0, 1], [0, 1], [1, 0], [0, 1]])
851+
assert_allclose(X_trans, expected)
852+
853+
oe = OrdinalEncoder(categories=categories).fit(X)
854+
X_trans = oe.transform(X_test)
855+
856+
expected = np.array([[1], [1], [0], [1]])
857+
assert_array_equal(X_trans, expected)
858+
859+
833860
@pytest.mark.parametrize("missing_value", [np.nan, None])
834861
def test_ohe_missing_values_get_feature_names(missing_value):
835862
# encoder with missing values with object dtypes

sklearn/utils/_encode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _encode(values, *, uniques, check_unknown=True):
173173
encoded : ndarray
174174
Encoded values
175175
"""
176-
if values.dtype == object:
176+
if values.dtype.kind in 'OU':
177177
try:
178178
return _map_to_integer(values, uniques)
179179
except KeyError as e:

0 commit comments

Comments
 (0)