@@ -65,17 +65,6 @@ class CoercionBase:
65
65
def method (self ):
66
66
raise NotImplementedError (self )
67
67
68
- def _assert (self , left , right , dtype ):
69
- # explicitly check dtype to avoid any unexpected result
70
- if isinstance (left , pd .Series ):
71
- tm .assert_series_equal (left , right )
72
- elif isinstance (left , pd .Index ):
73
- tm .assert_index_equal (left , right )
74
- else :
75
- raise NotImplementedError
76
- assert left .dtype == dtype
77
- assert right .dtype == dtype
78
-
79
68
80
69
class TestSetitemCoercion (CoercionBase ):
81
70
@@ -91,6 +80,7 @@ def _assert_setitem_series_conversion(
91
80
# check dtype explicitly for sure
92
81
assert temp .dtype == expected_dtype
93
82
83
+ # FIXME: dont leave commented-out
94
84
# .loc works different rule, temporary disable
95
85
# temp = original_series.copy()
96
86
# temp.loc[1] = loc_value
@@ -565,7 +555,8 @@ def _assert_where_conversion(
565
555
""" test coercion triggered by where """
566
556
target = original .copy ()
567
557
res = target .where (cond , values )
568
- self ._assert (res , expected , expected_dtype )
558
+ tm .assert_equal (res , expected )
559
+ assert res .dtype == expected_dtype
569
560
570
561
@pytest .mark .parametrize (
571
562
"fill_val,exp_dtype" ,
@@ -588,7 +579,7 @@ def test_where_object(self, index_or_series, fill_val, exp_dtype):
588
579
if fill_val is True :
589
580
values = klass ([True , False , True , True ])
590
581
else :
591
- values = klass (fill_val * x for x in [5 , 6 , 7 , 8 ])
582
+ values = klass (x * fill_val for x in [5 , 6 , 7 , 8 ])
592
583
593
584
exp = klass (["a" , values [1 ], "c" , values [3 ]])
594
585
self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
@@ -647,38 +638,40 @@ def test_where_float64(self, index_or_series, fill_val, exp_dtype):
647
638
],
648
639
)
649
640
def test_where_series_complex128 (self , fill_val , exp_dtype ):
650
- obj = pd .Series ([1 + 1j , 2 + 2j , 3 + 3j , 4 + 4j ])
641
+ klass = pd .Series
642
+ obj = klass ([1 + 1j , 2 + 2j , 3 + 3j , 4 + 4j ])
651
643
assert obj .dtype == np .complex128
652
- cond = pd . Series ([True , False , True , False ])
644
+ cond = klass ([True , False , True , False ])
653
645
654
- exp = pd . Series ([1 + 1j , fill_val , 3 + 3j , fill_val ])
646
+ exp = klass ([1 + 1j , fill_val , 3 + 3j , fill_val ])
655
647
self ._assert_where_conversion (obj , cond , fill_val , exp , exp_dtype )
656
648
657
649
if fill_val is True :
658
- values = pd . Series ([True , False , True , True ])
650
+ values = klass ([True , False , True , True ])
659
651
else :
660
- values = pd . Series (x * fill_val for x in [5 , 6 , 7 , 8 ])
661
- exp = pd . Series ([1 + 1j , values [1 ], 3 + 3j , values [3 ]])
652
+ values = klass (x * fill_val for x in [5 , 6 , 7 , 8 ])
653
+ exp = klass ([1 + 1j , values [1 ], 3 + 3j , values [3 ]])
662
654
self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
663
655
664
656
@pytest .mark .parametrize (
665
657
"fill_val,exp_dtype" ,
666
658
[(1 , object ), (1.1 , object ), (1 + 1j , object ), (True , np .bool_ )],
667
659
)
668
660
def test_where_series_bool (self , fill_val , exp_dtype ):
661
+ klass = pd .Series
669
662
670
- obj = pd . Series ([True , False , True , False ])
663
+ obj = klass ([True , False , True , False ])
671
664
assert obj .dtype == np .bool_
672
- cond = pd . Series ([True , False , True , False ])
665
+ cond = klass ([True , False , True , False ])
673
666
674
- exp = pd . Series ([True , fill_val , True , fill_val ])
667
+ exp = klass ([True , fill_val , True , fill_val ])
675
668
self ._assert_where_conversion (obj , cond , fill_val , exp , exp_dtype )
676
669
677
670
if fill_val is True :
678
- values = pd . Series ([True , False , True , True ])
671
+ values = klass ([True , False , True , True ])
679
672
else :
680
- values = pd . Series (x * fill_val for x in [5 , 6 , 7 , 8 ])
681
- exp = pd . Series ([True , values [1 ], True , values [3 ]])
673
+ values = klass (x * fill_val for x in [5 , 6 , 7 , 8 ])
674
+ exp = klass ([True , values [1 ], True , values [3 ]])
682
675
self ._assert_where_conversion (obj , cond , values , exp , exp_dtype )
683
676
684
677
@pytest .mark .parametrize (
@@ -871,7 +864,8 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype):
871
864
""" test coercion triggered by fillna """
872
865
target = original .copy ()
873
866
res = target .fillna (value )
874
- self ._assert (res , expected , expected_dtype )
867
+ tm .assert_equal (res , expected )
868
+ assert res .dtype == expected_dtype
875
869
876
870
@pytest .mark .parametrize (
877
871
"fill_val, fill_dtype" ,
@@ -1040,10 +1034,12 @@ class TestReplaceSeriesCoercion(CoercionBase):
1040
1034
1041
1035
rep ["timedelta64[ns]" ] = [pd .Timedelta ("1 day" ), pd .Timedelta ("2 day" )]
1042
1036
1043
- @pytest .mark .parametrize ("how" , ["dict" , "series" ])
1044
- @pytest .mark .parametrize (
1045
- "to_key" ,
1046
- [
1037
+ @pytest .fixture (params = ["dict" , "series" ])
1038
+ def how (self , request ):
1039
+ return request .param
1040
+
1041
+ @pytest .fixture (
1042
+ params = [
1047
1043
"object" ,
1048
1044
"int64" ,
1049
1045
"float64" ,
@@ -1053,34 +1049,52 @@ class TestReplaceSeriesCoercion(CoercionBase):
1053
1049
"datetime64[ns, UTC]" ,
1054
1050
"datetime64[ns, US/Eastern]" ,
1055
1051
"timedelta64[ns]" ,
1056
- ],
1057
- ids = [
1052
+ ]
1053
+ )
1054
+ def from_key (self , request ):
1055
+ return request .param
1056
+
1057
+ @pytest .fixture (
1058
+ params = [
1058
1059
"object" ,
1059
1060
"int64" ,
1060
1061
"float64" ,
1061
1062
"complex128" ,
1062
1063
"bool" ,
1063
- "datetime64" ,
1064
- "datetime64tz " ,
1065
- "datetime64tz " ,
1066
- "timedelta64" ,
1064
+ "datetime64[ns] " ,
1065
+ "datetime64[ns, UTC] " ,
1066
+ "datetime64[ns, US/Eastern] " ,
1067
+ "timedelta64[ns] " ,
1067
1068
],
1068
- )
1069
- @pytest .mark .parametrize (
1070
- "from_key" ,
1071
- [
1069
+ ids = [
1072
1070
"object" ,
1073
1071
"int64" ,
1074
1072
"float64" ,
1075
1073
"complex128" ,
1076
1074
"bool" ,
1077
- "datetime64[ns] " ,
1078
- "datetime64[ns, UTC] " ,
1079
- "datetime64[ns, US/Eastern] " ,
1080
- "timedelta64[ns] " ,
1075
+ "datetime64" ,
1076
+ "datetime64tz " ,
1077
+ "datetime64tz " ,
1078
+ "timedelta64" ,
1081
1079
],
1082
1080
)
1083
- def test_replace_series (self , how , to_key , from_key ):
1081
+ def to_key (self , request ):
1082
+ return request .param
1083
+
1084
+ @pytest .fixture
1085
+ def replacer (self , how , from_key , to_key ):
1086
+ """
1087
+ Object we will pass to `Series.replace`
1088
+ """
1089
+ if how == "dict" :
1090
+ replacer = dict (zip (self .rep [from_key ], self .rep [to_key ]))
1091
+ elif how == "series" :
1092
+ replacer = pd .Series (self .rep [to_key ], index = self .rep [from_key ])
1093
+ else :
1094
+ raise ValueError
1095
+ return replacer
1096
+
1097
+ def test_replace_series (self , how , to_key , from_key , replacer ):
1084
1098
index = pd .Index ([3 , 4 ], name = "xxx" )
1085
1099
obj = pd .Series (self .rep [from_key ], index = index , name = "yyy" )
1086
1100
assert obj .dtype == from_key
@@ -1092,13 +1106,6 @@ def test_replace_series(self, how, to_key, from_key):
1092
1106
# tested below
1093
1107
return
1094
1108
1095
- if how == "dict" :
1096
- replacer = dict (zip (self .rep [from_key ], self .rep [to_key ]))
1097
- elif how == "series" :
1098
- replacer = pd .Series (self .rep [to_key ], index = self .rep [from_key ])
1099
- else :
1100
- raise ValueError
1101
-
1102
1109
result = obj .replace (replacer )
1103
1110
1104
1111
if (from_key == "float64" and to_key in ("int64" )) or (
@@ -1117,53 +1124,40 @@ def test_replace_series(self, how, to_key, from_key):
1117
1124
1118
1125
tm .assert_series_equal (result , exp )
1119
1126
1120
- @pytest .mark .parametrize ("how" , ["dict" , "series" ])
1121
1127
@pytest .mark .parametrize (
1122
1128
"to_key" ,
1123
1129
["timedelta64[ns]" , "bool" , "object" , "complex128" , "float64" , "int64" ],
1130
+ indirect = True ,
1124
1131
)
1125
1132
@pytest .mark .parametrize (
1126
- "from_key" , ["datetime64[ns, UTC]" , "datetime64[ns, US/Eastern]" ]
1133
+ "from_key" , ["datetime64[ns, UTC]" , "datetime64[ns, US/Eastern]" ], indirect = True
1127
1134
)
1128
- def test_replace_series_datetime_tz (self , how , to_key , from_key ):
1135
+ def test_replace_series_datetime_tz (self , how , to_key , from_key , replacer ):
1129
1136
index = pd .Index ([3 , 4 ], name = "xyz" )
1130
1137
obj = pd .Series (self .rep [from_key ], index = index , name = "yyy" )
1131
1138
assert obj .dtype == from_key
1132
1139
1133
- if how == "dict" :
1134
- replacer = dict (zip (self .rep [from_key ], self .rep [to_key ]))
1135
- elif how == "series" :
1136
- replacer = pd .Series (self .rep [to_key ], index = self .rep [from_key ])
1137
- else :
1138
- raise ValueError
1139
-
1140
1140
result = obj .replace (replacer )
1141
1141
exp = pd .Series (self .rep [to_key ], index = index , name = "yyy" )
1142
1142
assert exp .dtype == to_key
1143
1143
1144
1144
tm .assert_series_equal (result , exp )
1145
1145
1146
- @pytest .mark .parametrize ("how" , ["dict" , "series" ])
1147
1146
@pytest .mark .parametrize (
1148
1147
"to_key" ,
1149
1148
["datetime64[ns]" , "datetime64[ns, UTC]" , "datetime64[ns, US/Eastern]" ],
1149
+ indirect = True ,
1150
1150
)
1151
1151
@pytest .mark .parametrize (
1152
1152
"from_key" ,
1153
1153
["datetime64[ns]" , "datetime64[ns, UTC]" , "datetime64[ns, US/Eastern]" ],
1154
+ indirect = True ,
1154
1155
)
1155
- def test_replace_series_datetime_datetime (self , how , to_key , from_key ):
1156
+ def test_replace_series_datetime_datetime (self , how , to_key , from_key , replacer ):
1156
1157
index = pd .Index ([3 , 4 ], name = "xyz" )
1157
1158
obj = pd .Series (self .rep [from_key ], index = index , name = "yyy" )
1158
1159
assert obj .dtype == from_key
1159
1160
1160
- if how == "dict" :
1161
- replacer = dict (zip (self .rep [from_key ], self .rep [to_key ]))
1162
- elif how == "series" :
1163
- replacer = pd .Series (self .rep [to_key ], index = self .rep [from_key ])
1164
- else :
1165
- raise ValueError
1166
-
1167
1161
result = obj .replace (replacer )
1168
1162
exp = pd .Series (self .rep [to_key ], index = index , name = "yyy" )
1169
1163
assert exp .dtype == to_key
0 commit comments