Skip to content

Commit 3e20eab

Browse files
Keiron Pizzeyjreback
Keiron Pizzey
authored andcommitted
ENH - Modify Dataframe.select_dtypes to accept scalar values (#16860)
1 parent a43c157 commit 3e20eab

File tree

5 files changed

+130
-33
lines changed

5 files changed

+130
-33
lines changed

doc/source/basics.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,7 +2229,3 @@ All numpy dtypes are subclasses of ``numpy.generic``:
22292229

22302230
Pandas also defines the types ``category``, and ``datetime64[ns, tz]``, which are not integrated into the normal
22312231
numpy hierarchy and wont show up with the above function.
2232-
2233-
.. note::
2234-
2235-
The ``include`` and ``exclude`` parameters must be non-string sequences.

doc/source/style.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@
935935
"\n",
936936
"<span style=\"color: red\">*Experimental: This is a new feature and still under development. We'll be adding features and possibly making breaking changes in future releases. We'd love to hear your feedback.*</span>\n",
937937
"\n",
938-
"Some support is available for exporting styled `DataFrames` to Excel worksheets using the `OpenPyXL` engine. CSS2.2 properties handled include:\n",
938+
"Some support is available for exporting styled `DataFrames` to Excel worksheets using the `OpenPyXL` engine. CSS2.2 properties handled include:\n",
939939
"\n",
940940
"- `background-color`\n",
941941
"- `border-style`, `border-width`, `border-color` and their {`top`, `right`, `bottom`, `left` variants}\n",

doc/source/whatsnew/v0.21.0.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Other Enhancements
3939
- :func:`read_feather` has gained the ``nthreads`` parameter for multi-threaded operations (:issue:`16359`)
4040
- :func:`DataFrame.clip()` and :func:`Series.clip()` have gained an ``inplace`` argument. (:issue:`15388`)
4141
- :func:`crosstab` has gained a ``margins_name`` parameter to define the name of the row / column that will contain the totals when ``margins=True``. (:issue:`15972`)
42+
- :func:`Dataframe.select_dtypes` now accepts scalar values for include/exclude as well as list-like. (:issue:`16855`)
4243

4344
.. _whatsnew_0210.api_breaking:
4445

pandas/core/frame.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,18 +2285,16 @@ def select_dtypes(self, include=None, exclude=None):
22852285
22862286
Parameters
22872287
----------
2288-
include, exclude : list-like
2289-
A list of dtypes or strings to be included/excluded. You must pass
2290-
in a non-empty sequence for at least one of these.
2288+
include, exclude : scalar or list-like
2289+
A selection of dtypes or strings to be included/excluded. At least
2290+
one of these parameters must be supplied.
22912291
22922292
Raises
22932293
------
22942294
ValueError
22952295
* If both of ``include`` and ``exclude`` are empty
22962296
* If ``include`` and ``exclude`` have overlapping elements
22972297
* If any kind of string dtype is passed in.
2298-
TypeError
2299-
* If either of ``include`` or ``exclude`` is not a sequence
23002298
23012299
Returns
23022300
-------
@@ -2331,6 +2329,14 @@ def select_dtypes(self, include=None, exclude=None):
23312329
3 0.0764 False 2
23322330
4 -0.9703 True 1
23332331
5 -1.2094 False 2
2332+
>>> df.select_dtypes(include='bool')
2333+
c
2334+
0 True
2335+
1 False
2336+
2 True
2337+
3 False
2338+
4 True
2339+
5 False
23342340
>>> df.select_dtypes(include=['float64'])
23352341
c
23362342
0 1
@@ -2348,10 +2354,12 @@ def select_dtypes(self, include=None, exclude=None):
23482354
4 True
23492355
5 False
23502356
"""
2351-
include, exclude = include or (), exclude or ()
2352-
if not (is_list_like(include) and is_list_like(exclude)):
2353-
raise TypeError('include and exclude must both be non-string'
2354-
' sequences')
2357+
2358+
if not is_list_like(include):
2359+
include = (include,) if include is not None else ()
2360+
if not is_list_like(exclude):
2361+
exclude = (exclude,) if exclude is not None else ()
2362+
23552363
selection = tuple(map(frozenset, (include, exclude)))
23562364

23572365
if not any(selection):

pandas/tests/frame/test_dtypes.py

Lines changed: 111 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_dtypes_are_correct_after_column_slice(self):
104104
('b', np.float_),
105105
('c', np.float_)])))
106106

107-
def test_select_dtypes_include(self):
107+
def test_select_dtypes_include_using_list_like(self):
108108
df = DataFrame({'a': list('abc'),
109109
'b': list(range(1, 4)),
110110
'c': np.arange(3, 6).astype('u1'),
@@ -145,14 +145,10 @@ def test_select_dtypes_include(self):
145145
ei = df[['h', 'i']]
146146
assert_frame_equal(ri, ei)
147147

148-
ri = df.select_dtypes(include=['timedelta'])
149-
ei = df[['k']]
150-
assert_frame_equal(ri, ei)
151-
152148
pytest.raises(NotImplementedError,
153149
lambda: df.select_dtypes(include=['period']))
154150

155-
def test_select_dtypes_exclude(self):
151+
def test_select_dtypes_exclude_using_list_like(self):
156152
df = DataFrame({'a': list('abc'),
157153
'b': list(range(1, 4)),
158154
'c': np.arange(3, 6).astype('u1'),
@@ -162,7 +158,7 @@ def test_select_dtypes_exclude(self):
162158
ee = df[['a', 'e']]
163159
assert_frame_equal(re, ee)
164160

165-
def test_select_dtypes_exclude_include(self):
161+
def test_select_dtypes_exclude_include_using_list_like(self):
166162
df = DataFrame({'a': list('abc'),
167163
'b': list(range(1, 4)),
168164
'c': np.arange(3, 6).astype('u1'),
@@ -181,6 +177,114 @@ def test_select_dtypes_exclude_include(self):
181177
e = df[['b', 'e']]
182178
assert_frame_equal(r, e)
183179

180+
def test_select_dtypes_include_using_scalars(self):
181+
df = DataFrame({'a': list('abc'),
182+
'b': list(range(1, 4)),
183+
'c': np.arange(3, 6).astype('u1'),
184+
'd': np.arange(4.0, 7.0, dtype='float64'),
185+
'e': [True, False, True],
186+
'f': pd.Categorical(list('abc')),
187+
'g': pd.date_range('20130101', periods=3),
188+
'h': pd.date_range('20130101', periods=3,
189+
tz='US/Eastern'),
190+
'i': pd.date_range('20130101', periods=3,
191+
tz='CET'),
192+
'j': pd.period_range('2013-01', periods=3,
193+
freq='M'),
194+
'k': pd.timedelta_range('1 day', periods=3)})
195+
196+
ri = df.select_dtypes(include=np.number)
197+
ei = df[['b', 'c', 'd', 'k']]
198+
assert_frame_equal(ri, ei)
199+
200+
ri = df.select_dtypes(include='datetime')
201+
ei = df[['g']]
202+
assert_frame_equal(ri, ei)
203+
204+
ri = df.select_dtypes(include='datetime64')
205+
ei = df[['g']]
206+
assert_frame_equal(ri, ei)
207+
208+
ri = df.select_dtypes(include='category')
209+
ei = df[['f']]
210+
assert_frame_equal(ri, ei)
211+
212+
pytest.raises(NotImplementedError,
213+
lambda: df.select_dtypes(include='period'))
214+
215+
def test_select_dtypes_exclude_using_scalars(self):
216+
df = DataFrame({'a': list('abc'),
217+
'b': list(range(1, 4)),
218+
'c': np.arange(3, 6).astype('u1'),
219+
'd': np.arange(4.0, 7.0, dtype='float64'),
220+
'e': [True, False, True],
221+
'f': pd.Categorical(list('abc')),
222+
'g': pd.date_range('20130101', periods=3),
223+
'h': pd.date_range('20130101', periods=3,
224+
tz='US/Eastern'),
225+
'i': pd.date_range('20130101', periods=3,
226+
tz='CET'),
227+
'j': pd.period_range('2013-01', periods=3,
228+
freq='M'),
229+
'k': pd.timedelta_range('1 day', periods=3)})
230+
231+
ri = df.select_dtypes(exclude=np.number)
232+
ei = df[['a', 'e', 'f', 'g', 'h', 'i', 'j']]
233+
assert_frame_equal(ri, ei)
234+
235+
ri = df.select_dtypes(exclude='category')
236+
ei = df[['a', 'b', 'c', 'd', 'e', 'g', 'h', 'i', 'j', 'k']]
237+
assert_frame_equal(ri, ei)
238+
239+
pytest.raises(NotImplementedError,
240+
lambda: df.select_dtypes(exclude='period'))
241+
242+
def test_select_dtypes_include_exclude_using_scalars(self):
243+
df = DataFrame({'a': list('abc'),
244+
'b': list(range(1, 4)),
245+
'c': np.arange(3, 6).astype('u1'),
246+
'd': np.arange(4.0, 7.0, dtype='float64'),
247+
'e': [True, False, True],
248+
'f': pd.Categorical(list('abc')),
249+
'g': pd.date_range('20130101', periods=3),
250+
'h': pd.date_range('20130101', periods=3,
251+
tz='US/Eastern'),
252+
'i': pd.date_range('20130101', periods=3,
253+
tz='CET'),
254+
'j': pd.period_range('2013-01', periods=3,
255+
freq='M'),
256+
'k': pd.timedelta_range('1 day', periods=3)})
257+
258+
ri = df.select_dtypes(include=np.number, exclude='floating')
259+
ei = df[['b', 'c', 'k']]
260+
assert_frame_equal(ri, ei)
261+
262+
def test_select_dtypes_include_exclude_mixed_scalars_lists(self):
263+
df = DataFrame({'a': list('abc'),
264+
'b': list(range(1, 4)),
265+
'c': np.arange(3, 6).astype('u1'),
266+
'd': np.arange(4.0, 7.0, dtype='float64'),
267+
'e': [True, False, True],
268+
'f': pd.Categorical(list('abc')),
269+
'g': pd.date_range('20130101', periods=3),
270+
'h': pd.date_range('20130101', periods=3,
271+
tz='US/Eastern'),
272+
'i': pd.date_range('20130101', periods=3,
273+
tz='CET'),
274+
'j': pd.period_range('2013-01', periods=3,
275+
freq='M'),
276+
'k': pd.timedelta_range('1 day', periods=3)})
277+
278+
ri = df.select_dtypes(include=np.number,
279+
exclude=['floating', 'timedelta'])
280+
ei = df[['b', 'c']]
281+
assert_frame_equal(ri, ei)
282+
283+
ri = df.select_dtypes(include=[np.number, 'category'],
284+
exclude='floating')
285+
ei = df[['b', 'c', 'f', 'k']]
286+
assert_frame_equal(ri, ei)
287+
184288
def test_select_dtypes_not_an_attr_but_still_valid_dtype(self):
185289
df = DataFrame({'a': list('abc'),
186290
'b': list(range(1, 4)),
@@ -205,18 +309,6 @@ def test_select_dtypes_empty(self):
205309
'must be nonempty'):
206310
df.select_dtypes()
207311

208-
def test_select_dtypes_raises_on_string(self):
209-
df = DataFrame({'a': list('abc'), 'b': list(range(1, 4))})
210-
with tm.assert_raises_regex(TypeError, 'include and exclude '
211-
'.+ non-'):
212-
df.select_dtypes(include='object')
213-
with tm.assert_raises_regex(TypeError, 'include and exclude '
214-
'.+ non-'):
215-
df.select_dtypes(exclude='object')
216-
with tm.assert_raises_regex(TypeError, 'include and exclude '
217-
'.+ non-'):
218-
df.select_dtypes(include=int, exclude='object')
219-
220312
def test_select_dtypes_bad_datetime64(self):
221313
df = DataFrame({'a': list('abc'),
222314
'b': list(range(1, 4)),

0 commit comments

Comments
 (0)