Skip to content

Commit 81a1927

Browse files
samukwekusamuel.oranyeliericmjl
authored
[ENH] Method for adding functionality to pandas groupby (#1462)
* add accessor support for pandas DataFrameGroupBy * add accessor support for pandas DataFrameGroupBy * add accessor support for pandas DataFrameGroupBy * add pull method * add comments showing source for groupby accessor * add comments showing source for groupby accessor * update pandas_flavor * changelog * ensure pandas_flavor > 0.6 * Update environment-dev.yml * exclude pull - will introduce a better API * minor fixes --------- Co-authored-by: samuel.oranyeli <[email protected]> Co-authored-by: Eric Ma <[email protected]>
1 parent 0673f30 commit 81a1927

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
## [Unreleased]
44
- [ENH] Added `row_count` parameter for janitor.conditional_join - Issue #1269 @samukweku
5-
- [ENG] Reverse deprecation of `pivot_wider()` -- Issue #1464
5+
- [ENH] Reverse deprecation of `pivot_wider()` -- Issue #1464
6+
- [ENH] Add accessor and method for pandas DataFrameGroupBy objects. - Issue #587 @samukweku
7+
68
## [v0.31.0] - 2025-03-07
79

810
- [ENH] Added support for pd.Series.select - Issue #1394 @samukweku

janitor/functions/select.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,17 +327,18 @@ def select_rows(
327327
return _select(df, rows=list(args), invert=invert)
328328

329329

330+
@pf.register_groupby_method
330331
@pf.register_dataframe_method
331332
@pf.register_series_method
332333
@deprecated_alias(rows="index")
333334
def select(
334-
df: pd.DataFrame | pd.Series,
335+
df: pd.DataFrame | pd.Series | DataFrameGroupBy,
335336
*args: tuple,
336337
index: Any = None,
337338
columns: Any = None,
338339
axis: str = "columns",
339340
invert: bool = False,
340-
) -> pd.DataFrame | pd.Series:
341+
) -> pd.DataFrame | pd.Series | DataFrameGroupBy:
341342
"""Method-chainable selection of rows and/or columns.
342343
343344
It accepts a string, shell-like glob strings `(*string*)`,
@@ -371,6 +372,8 @@ def select(
371372
- `rows` keyword deprecated in favour of `index`.
372373
- 0.31.0
373374
- Add support for pd.Series.
375+
- 0.32.0
376+
- Add support for DataFrameGroupBy.
374377
375378
Examples:
376379
>>> import pandas as pd
@@ -436,6 +439,10 @@ def select(
436439
Returns:
437440
A pandas DataFrame or Series with the specified rows and/or columns selected.
438441
""" # noqa: E501
442+
if args and isinstance(df, DataFrameGroupBy):
443+
return get_columns(group=df, label=list(args))
444+
if isinstance(df, DataFrameGroupBy):
445+
return get_columns(group=df, label=[columns])
439446
if args:
440447
check("invert", invert, [bool])
441448
if (index is not None) or (columns is not None):
@@ -478,6 +485,12 @@ def get_index_labels(
478485
return index[_select_index(arg, df, axis)]
479486

480487

488+
@refactored_function(
489+
message=(
490+
"This function will be deprecated in a 1.x release. "
491+
"Please use `jn.select` instead."
492+
)
493+
)
481494
def get_columns(
482495
group: DataFrameGroupBy | SeriesGroupBy, label: Any
483496
) -> DataFrameGroupBy | SeriesGroupBy:
@@ -488,6 +501,11 @@ def get_columns(
488501
489502
!!! info "New in version 0.25.0"
490503
504+
!!!note
505+
506+
This function will be deprecated in a 1.x release.
507+
Please use `jn.select` instead.
508+
491509
Args:
492510
group: A Pandas GroupBy object.
493511
label: column(s) to select.

tests/functions/test_select_columns.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,20 @@ def test_select_groupby(dataframe):
479479
assert_frame_equal(expected, actual)
480480

481481

482+
def test_select_groupby_args(dataframe):
483+
"""Test output on a grouped object"""
484+
expected = dataframe.select_dtypes("number").groupby(dataframe["a"]).sum()
485+
actual = dataframe.groupby("a").select(is_numeric_dtype).sum()
486+
assert_frame_equal(expected, actual)
487+
488+
489+
def test_select_groupby_columns(dataframe):
490+
"""Test output on a grouped object"""
491+
expected = dataframe.select_dtypes("number").groupby(dataframe["a"]).sum()
492+
actual = dataframe.groupby("a").select(columns=is_numeric_dtype).sum()
493+
assert_frame_equal(expected, actual)
494+
495+
482496
def test_select_str_multiindex(multiindex):
483497
"""Test str selection on a MultiIndex - exact match"""
484498
expected = multiindex.select_columns("bar")

0 commit comments

Comments
 (0)