Skip to content

Commit e1dc3c9

Browse files
Parameterize dtypes
1 parent 289ad35 commit e1dc3c9

File tree

1 file changed

+10
-64
lines changed

1 file changed

+10
-64
lines changed

pandas/tests/test_categorical.py

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -798,75 +798,19 @@ def test_set_categories_inplace(self):
798798
tm.assert_index_equal(cat.categories, pd.Index(['a', 'b', 'c', 'd']))
799799

800800
@pytest.mark.parametrize(
801-
"input1, input2, cat_array",
802-
[
803-
(
804-
np.array([1, 2, 3, 3], dtype=np.dtype('int_')),
805-
np.array([1, 2, 3, 5, 3, 2, 4], dtype=np.dtype('int_')),
806-
np.array([1, 2, 3, 4, 5], dtype=np.dtype('int_'))
807-
),
808-
(
809-
np.array([1, 2, 3, 3], dtype=np.dtype('uint')),
810-
np.array([1, 2, 3, 5, 3, 2, 4], dtype=np.dtype('uint')),
811-
np.array([1, 2, 3, 4, 5], dtype=np.dtype('uint'))
812-
),
813-
(
814-
np.array([1, 2, 3, 3], dtype=np.dtype('float_')),
815-
np.array([1, 2, 3, 5, 3, 2, 4], dtype=np.dtype('float_')),
816-
np.array([1, 2, 3, 4, 5], dtype=np.dtype('float_'))
817-
),
818-
(
819-
np.array(
820-
[1, 2, 3, 3], dtype=np.dtype('unicode_')
821-
),
822-
np.array(
823-
[1, 2, 3, 5, 3, 2, 4], dtype=np.dtype('unicode_')
824-
),
825-
np.array(
826-
[1, 2, 3, 4, 5], dtype=np.dtype('unicode_')
827-
)
828-
),
829-
(
830-
np.array(
831-
[
832-
'2017-01-01 10:00:00', '2017-02-01 10:00:00',
833-
'2017-03-01 10:00:00', '2017-03-01 10:00:00'
834-
],
835-
dtype='datetime64'
836-
),
837-
np.array(
838-
[
839-
'2017-01-01 10:00:00', '2017-02-01 10:00:00',
840-
'2017-03-01 10:00:00', '2017-05-01 10:00:00',
841-
'2017-03-01 10:00:00', '2017-02-01 10:00:00',
842-
'2017-04-01 10:00:00'
843-
],
844-
dtype='datetime64'
845-
),
846-
np.array(
847-
[
848-
'2017-01-01 10:00:00', '2017-02-01 10:00:00',
849-
'2017-03-01 10:00:00', '2017-04-01 10:00:00',
850-
'2017-05-01 10:00:00'
851-
],
852-
dtype='datetime64'
853-
)
854-
),
855-
(
856-
pd.to_timedelta(['1 days', '2 days', '3 days', '3 days'],
857-
unit="D"),
858-
pd.to_timedelta(['1 days', '2 days', '3 days', '5 days',
859-
'3 days', '2 days', '4 days'], unit="D"),
860-
pd.timedelta_range("1 days", periods=5, freq="D")
861-
)
862-
]
801+
"dtype",
802+
["int_", "uint", "float_",
803+
"unicode_", "datetime64[h]", "timedelta64[h]"]
863804
)
864805
@pytest.mark.parametrize("is_ordered", [True, False])
865-
def test_drop_duplicates_non_bool(self, input1, input2,
866-
cat_array, is_ordered):
806+
def test_drop_duplicates_non_bool(self, dtype, is_ordered):
807+
cat_array = np.array([1, 2, 3, 4, 5], dtype=np.dtype(dtype))
808+
867809
# Test case 1
810+
input1 = np.array([1, 2, 3, 3], dtype=np.dtype(dtype))
868811
tc1 = Series(Categorical(input1, categories=cat_array,
869812
ordered=is_ordered))
813+
870814
expected = Series([False, False, False, True])
871815
tm.assert_series_equal(tc1.duplicated(), expected)
872816
tm.assert_series_equal(tc1.drop_duplicates(), tc1[~expected])
@@ -890,8 +834,10 @@ def test_drop_duplicates_non_bool(self, input1, input2,
890834
tm.assert_series_equal(sc, tc1[~expected])
891835

892836
# Test case 2
837+
input2 = np.array([1, 2, 3, 5, 3, 2, 4], dtype=np.dtype(dtype))
893838
tc2 = Series(Categorical(input2, categories=cat_array,
894839
ordered=is_ordered))
840+
895841
expected = Series([False, False, False, False, True, True, False])
896842
tm.assert_series_equal(tc2.duplicated(), expected)
897843
tm.assert_series_equal(tc2.drop_duplicates(), tc2[~expected])

0 commit comments

Comments
 (0)