Skip to content

Commit fb9d532

Browse files
committed
review edits
1 parent b5e803c commit fb9d532

File tree

3 files changed

+130
-68
lines changed

3 files changed

+130
-68
lines changed

pandas/conftest.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ def join_type(request):
129129
return request.param
130130

131131

132+
@pytest.fixture(params=['nlargest', 'nsmallest'])
133+
def nselect_method(request):
134+
"""
135+
Fixture for trying all nselect methods
136+
"""
137+
return request.param
138+
139+
132140
@pytest.fixture(params=[None, np.nan, pd.NaT, float('nan'), np.float('NaN')])
133141
def nulls_fixture(request):
134142
"""
@@ -170,3 +178,66 @@ def string_dtype(request):
170178
* 'U'
171179
"""
172180
return request.param
181+
182+
183+
@pytest.fixture(params=["float32", "float64"])
184+
def float_dtype(request):
185+
"""
186+
Parameterized fixture for float dtypes.
187+
188+
* float32
189+
* float64
190+
"""
191+
192+
return request.param
193+
194+
195+
UNSIGNED_INT_DTYPES = ["uint8", "uint16", "uint32", "uint64"]
196+
SIGNED_INT_DTYPES = ["int8", "int16", "int32", "int64"]
197+
ALL_INT_DTYPES = UNSIGNED_INT_DTYPES + SIGNED_INT_DTYPES
198+
199+
200+
@pytest.fixture(params=SIGNED_INT_DTYPES)
201+
def sint_dtype(request):
202+
"""
203+
Parameterized fixture for signed integer dtypes.
204+
205+
* int8
206+
* int16
207+
* int32
208+
* int64
209+
"""
210+
211+
return request.param
212+
213+
214+
@pytest.fixture(params=UNSIGNED_INT_DTYPES)
215+
def uint_dtype(request):
216+
"""
217+
Parameterized fixture for unsigned integer dtypes.
218+
219+
* uint8
220+
* uint16
221+
* uint32
222+
* uint64
223+
"""
224+
225+
return request.param
226+
227+
228+
@pytest.fixture(params=ALL_INT_DTYPES)
229+
def any_int_dtype(request):
230+
"""
231+
Parameterized fixture for any integer dtypes.
232+
233+
* int8
234+
* uint8
235+
* int16
236+
* uint16
237+
* int32
238+
* uint32
239+
* int64
240+
* uint64
241+
"""
242+
243+
return request.param

pandas/tests/frame/test_analytics.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from numpy.random import randn
1313
import numpy as np
1414

15-
from pandas.compat import lrange, product, PY35
15+
from pandas.compat import lrange, PY35
1616
from pandas import (compat, isna, notna, DataFrame, Series,
1717
MultiIndex, date_range, Timestamp, Categorical,
1818
_np_version_under1p12, _np_version_under1p15,
@@ -2260,54 +2260,49 @@ class TestNLargestNSmallest(object):
22602260

22612261
# ----------------------------------------------------------------------
22622262
# Top / bottom
2263-
@pytest.mark.parametrize(
2264-
'method, n, order',
2265-
product(['nsmallest', 'nlargest'], range(1, 11),
2266-
[['a'],
2267-
['c'],
2268-
['a', 'b'],
2269-
['a', 'c'],
2270-
['b', 'a'],
2271-
['b', 'c'],
2272-
['a', 'b', 'c'],
2273-
['c', 'a', 'b'],
2274-
['c', 'b', 'a'],
2275-
['b', 'c', 'a'],
2276-
['b', 'a', 'c'],
2277-
2278-
# dups!
2279-
['b', 'c', 'c'],
2280-
2281-
]))
2282-
def test_n(self, df_strings, method, n, order):
2263+
@pytest.mark.parametrize('order', [
2264+
['a'],
2265+
['c'],
2266+
['a', 'b'],
2267+
['a', 'c'],
2268+
['b', 'a'],
2269+
['b', 'c'],
2270+
['a', 'b', 'c'],
2271+
['c', 'a', 'b'],
2272+
['c', 'b', 'a'],
2273+
['b', 'c', 'a'],
2274+
['b', 'a', 'c'],
2275+
2276+
# dups!
2277+
['b', 'c', 'c']])
2278+
@pytest.mark.parametrize('n', range(1, 11))
2279+
def test_n(self, df_strings, nselect_method, n, order):
22832280
# GH10393
22842281
df = df_strings
22852282
if 'b' in order:
22862283

22872284
error_msg = self.dtype_error_msg_template.format(
2288-
column='b', method=method, dtype='object')
2285+
column='b', method=nselect_method, dtype='object')
22892286
with tm.assert_raises_regex(TypeError, error_msg):
2290-
getattr(df, method)(n, order)
2287+
getattr(df, nselect_method)(n, order)
22912288
else:
2292-
ascending = method == 'nsmallest'
2293-
result = getattr(df, method)(n, order)
2289+
ascending = nselect_method == 'nsmallest'
2290+
result = getattr(df, nselect_method)(n, order)
22942291
expected = df.sort_values(order, ascending=ascending).head(n)
22952292
tm.assert_frame_equal(result, expected)
22962293

2297-
@pytest.mark.parametrize(
2298-
'method, columns',
2299-
product(['nsmallest', 'nlargest'],
2300-
product(['group'], ['category_string', 'string'])
2301-
))
2302-
def test_n_error(self, df_main_dtypes, method, columns):
2294+
@pytest.mark.parametrize('columns', [
2295+
('group', 'category_string'), ('group', 'string')])
2296+
def test_n_error(self, df_main_dtypes, nselect_method, columns):
23032297
df = df_main_dtypes
2298+
col = columns[1]
23042299
error_msg = self.dtype_error_msg_template.format(
2305-
column=columns[1], method=method, dtype=df[columns[1]].dtype)
2300+
column=col, method=nselect_method, dtype=df[col].dtype)
23062301
# escape some characters that may be in the repr
23072302
error_msg = (error_msg.replace('(', '\\(').replace(")", "\\)")
23082303
.replace("[", "\\[").replace("]", "\\]"))
23092304
with tm.assert_raises_regex(TypeError, error_msg):
2310-
getattr(df, method)(2, columns)
2305+
getattr(df, nselect_method)(2, columns)
23112306

23122307
def test_n_all_dtypes(self, df_main_dtypes):
23132308
df = df_main_dtypes
@@ -2328,15 +2323,14 @@ def test_n_identical_values(self):
23282323
expected = pd.DataFrame({'a': [1] * 3, 'b': [1, 2, 3]})
23292324
tm.assert_frame_equal(result, expected)
23302325

2331-
@pytest.mark.parametrize(
2332-
'n, order',
2333-
product([1, 2, 3, 4, 5],
2334-
[['a', 'b', 'c'],
2335-
['c', 'b', 'a'],
2336-
['a'],
2337-
['b'],
2338-
['a', 'b'],
2339-
['c', 'b']]))
2326+
@pytest.mark.parametrize('order', [
2327+
['a', 'b', 'c'],
2328+
['c', 'b', 'a'],
2329+
['a'],
2330+
['b'],
2331+
['a', 'b'],
2332+
['c', 'b']])
2333+
@pytest.mark.parametrize('n', range(1, 6))
23402334
def test_n_duplicate_index(self, df_duplicates, n, order):
23412335
# GH 13412
23422336

pandas/tests/series/test_analytics.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,11 +1944,16 @@ def test_mode_sortwarning(self):
19441944
tm.assert_series_equal(result, expected)
19451945

19461946

1947-
class TestNLargestNSmallest(object):
1947+
def assert_check_nselect_boundary(vals, dtype, method):
1948+
# helper function for 'test_boundary_{dtype}' tests
1949+
s = Series(vals, dtype=dtype)
1950+
result = getattr(s, method)(3)
1951+
expected_idxr = [0, 1, 2] if method == 'nsmallest' else [3, 2, 1]
1952+
expected = s.loc[expected_idxr]
1953+
tm.assert_series_equal(result, expected)
1954+
19481955

1949-
@pytest.fixture(params=['nlargest', 'nsmallest'])
1950-
def method(self, request):
1951-
return request.param
1956+
class TestNLargestNSmallest(object):
19521957

19531958
@pytest.mark.parametrize(
19541959
"r", [Series([3., 2, 1, 2, '5'], dtype='object'),
@@ -2032,39 +2037,31 @@ def test_n(self, n):
20322037
expected = s.sort_values().head(n)
20332038
assert_series_equal(result, expected)
20342039

2035-
def _check_nselect_boundary(self, vals, dtype, method):
2036-
# helper function for 'test_boundary_dtype' tests
2037-
s = Series(vals, dtype=dtype)
2038-
result = getattr(s, method)(3)
2039-
expected_idxr = [0, 1, 2] if method == 'nsmallest' else [3, 2, 1]
2040-
expected = s.loc[expected_idxr]
2041-
tm.assert_series_equal(result, expected)
2042-
2043-
@pytest.mark.parametrize('dtype', [
2044-
'int8', 'int16', 'int32', 'int64',
2045-
'uint8', 'uint16', 'uint32', 'uint64'])
2046-
def test_boundary_integer(self, method, dtype):
2040+
def test_boundary_integer(self, nselect_method, any_int_dtype):
20472041
# GH 21426
2048-
min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max
2042+
dtype_info = np.iinfo(any_int_dtype)
2043+
min_val, max_val = dtype_info.min, dtype_info.max
20492044
vals = [min_val, min_val + 1, max_val - 1, max_val]
2050-
self._check_nselect_boundary(vals, dtype, method)
2045+
assert_check_nselect_boundary(vals, any_int_dtype, nselect_method)
20512046

2052-
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
2053-
def test_boundary_float(self, method, dtype):
2047+
def test_boundary_float(self, nselect_method, float_dtype):
20542048
# GH 21426
2055-
min_val, max_val = np.finfo(dtype).min, np.finfo(dtype).max
2056-
min_2nd, max_2nd = np.nextafter([min_val, max_val], 0, dtype=dtype)
2049+
dtype_info = np.finfo(float_dtype)
2050+
min_val, max_val = dtype_info.min, dtype_info.max
2051+
min_2nd, max_2nd = np.nextafter(
2052+
[min_val, max_val], 0, dtype=float_dtype)
20572053
vals = [min_val, min_2nd, max_2nd, max_val]
2058-
self._check_nselect_boundary(vals, dtype, method)
2054+
assert_check_nselect_boundary(vals, float_dtype, nselect_method)
20592055

20602056
@pytest.mark.parametrize('dtype', ['datetime64[ns]', 'timedelta64[ns]'])
2061-
def test_boundary_datetimelike(self, method, dtype):
2057+
def test_boundary_datetimelike(self, nselect_method, dtype):
20622058
# GH 21426
20632059
# use int64 bounds and +1 to min_val since true minimum is NaT
20642060
# (include min_val/NaT at end to maintain same expected_idxr)
2065-
min_val, max_val = np.iinfo('int64').min, np.iinfo('int64').max
2061+
dtype_info = np.iinfo('int64')
2062+
min_val, max_val = dtype_info.min, dtype_info.max
20662063
vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val]
2067-
self._check_nselect_boundary(vals, dtype, method)
2064+
assert_check_nselect_boundary(vals, dtype, nselect_method)
20682065

20692066

20702067
class TestCategoricalSeriesAnalytics(object):

0 commit comments

Comments
 (0)