
Commit b3f483f

ENH: Implement groupby.sample (#34069)
1 parent 83016f3 commit b3f483f

7 files changed: +246 -0 lines changed

doc/source/reference/groupby.rst

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ application to columns of a specific data type.
    DataFrameGroupBy.quantile
    DataFrameGroupBy.rank
    DataFrameGroupBy.resample
+   DataFrameGroupBy.sample
    DataFrameGroupBy.shift
    DataFrameGroupBy.size
    DataFrameGroupBy.skew

doc/source/whatsnew/v1.1.0.rst

Lines changed: 1 addition & 0 deletions
@@ -275,6 +275,7 @@ Other enhancements
   such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`)
 - :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` have gained ``engine`` and ``engine_kwargs`` arguments that support executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
 - :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
+- :class:`~pandas.core.groupby.generic.DataFrameGroupBy` and :class:`~pandas.core.groupby.generic.SeriesGroupBy` now implement the ``sample`` method for doing random sampling within groups (:issue:`31775`)
 - :meth:`DataFrame.to_numpy` now supports the ``na_value`` keyword to control the NA sentinel in the output array (:issue:`33820`)
 - The ``ExtensionArray`` class now has an :meth:`~pandas.arrays.ExtensionArray.equals`
   method, similarly to :meth:`Series.equals` (:issue:`27081`).
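
As a quick illustration of the whatsnew entry above, a minimal usage sketch (it assumes pandas >= 1.1.0, where this method first appears; the frame and column names are made up for the example):

import pandas as pd

# Toy frame with two groups in column "a"; names are illustrative only.
df = pd.DataFrame({"a": ["red", "red", "blue", "blue"], "b": range(4)})

# Draw one row at random from each group; random_state fixes the draw.
print(df.groupby("a").sample(n=1, random_state=0))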

pandas/core/generic.py

Lines changed: 4 additions & 0 deletions
@@ -4868,6 +4868,10 @@ def sample(
 
         See Also
         --------
+        DataFrameGroupBy.sample: Generates random samples from each group of a
+            DataFrame object.
+        SeriesGroupBy.sample: Generates random samples from each group of a
+            Series object.
         numpy.random.choice: Generates a random sample from a given 1-D numpy
             array.
 

pandas/core/groupby/base.py

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ def _gotitem(self, key, ndim, subset=None):
         "tail",
         "take",
         "transform",
+        "sample",
     ]
 )
 # Valid values of `name` for `groupby.transform(name)`

pandas/core/groupby/groupby.py

Lines changed: 113 additions & 0 deletions
@@ -23,6 +23,7 @@ class providing the base-class of operations.
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -2695,6 +2696,118 @@ def _reindex_output(
 
         return output.reset_index(drop=True)
 
+    def sample(
+        self,
+        n: Optional[int] = None,
+        frac: Optional[float] = None,
+        replace: bool = False,
+        weights: Optional[Union[Sequence, Series]] = None,
+        random_state=None,
+    ):
+        """
+        Return a random sample of items from each group.
+
+        You can use `random_state` for reproducibility.
+
+        .. versionadded:: 1.1.0
+
+        Parameters
+        ----------
+        n : int, optional
+            Number of items to return for each group. Cannot be used with
+            `frac` and must be no larger than the smallest group unless
+            `replace` is True. Default is one if `frac` is None.
+        frac : float, optional
+            Fraction of items to return. Cannot be used with `n`.
+        replace : bool, default False
+            Allow or disallow sampling of the same row more than once.
+        weights : list-like, optional
+            Default None results in equal probability weighting.
+            If passed a list-like then values must have the same length as
+            the underlying DataFrame or Series object and will be used as
+            sampling probabilities after normalization within each group.
+            Values must be non-negative with at least one positive element
+            within each group.
+        random_state : int, array-like, BitGenerator, np.random.RandomState, optional
+            If int, array-like, or BitGenerator (NumPy>=1.17), seed for
+            random number generator
+            If np.random.RandomState, use as numpy RandomState object.
+
+        Returns
+        -------
+        Series or DataFrame
+            A new object of same type as caller containing items randomly
+            sampled within each group from the caller object.
+
+        See Also
+        --------
+        DataFrame.sample: Generate random samples from a DataFrame object.
+        numpy.random.choice: Generate a random sample from a given 1-D numpy
+            array.
+
+        Examples
+        --------
+        >>> df = pd.DataFrame(
+        ...     {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)}
+        ... )
+        >>> df
+               a  b
+        0    red  0
+        1    red  1
+        2   blue  2
+        3   blue  3
+        4  black  4
+        5  black  5
+
+        Select one row at random for each distinct value in column a. The
+        `random_state` argument can be used to guarantee reproducibility:
+
+        >>> df.groupby("a").sample(n=1, random_state=1)
+               a  b
+        4  black  4
+        2   blue  2
+        1    red  1
+
+        Set `frac` to sample fixed proportions rather than counts:
+
+        >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2)
+        5    5
+        2    2
+        0    0
+        Name: b, dtype: int64
+
+        Control sample probabilities within groups by setting weights:
+
+        >>> df.groupby("a").sample(
+        ...     n=1,
+        ...     weights=[1, 1, 1, 0, 0, 1],
+        ...     random_state=1,
+        ... )
+               a  b
+        5  black  5
+        2   blue  2
+        0    red  0
+        """
+        from pandas.core.reshape.concat import concat
+
+        if weights is not None:
+            weights = Series(weights, index=self._selected_obj.index)
+            ws = [weights[idx] for idx in self.indices.values()]
+        else:
+            ws = [None] * self.ngroups
+
+        if random_state is not None:
+            random_state = com.random_state(random_state)
+
+        samples = [
+            obj.sample(
+                n=n, frac=frac, replace=replace, weights=w, random_state=random_state
+            )
+            for (_, obj), w in zip(self, ws)
+        ]
+
+        return concat(samples, axis=self.axis)
+
 
 @doc(GroupBy)
 def get_groupby(
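
The added method works by slicing any user-supplied ``weights`` per group via ``self.indices``, calling the ordinary ``sample`` on each group object, and concatenating the pieces. Below is a rough standalone sketch of that strategy using only the public pandas API; the helper name ``sample_by_group`` and its signature are illustrative, not part of this commit, and note that the commit normalizes ``random_state`` once via ``com.random_state`` so every group draws from the same generator, whereas this sketch reseeds per group when an int is passed.

import pandas as pd


def sample_by_group(df, key, n=None, frac=None, replace=False, weights=None, random_state=None):
    # Illustrative helper mirroring the per-group strategy in the commit above.
    if weights is not None:
        # Align the flat weights with the full index, as the commit does.
        weights = pd.Series(weights, index=df.index)
    pieces = []
    for _, group in df.groupby(key):
        # Slice the weights down to this group, then sample the group directly.
        w = weights[group.index] if weights is not None else None
        pieces.append(
            group.sample(n=n, frac=frac, replace=replace, weights=w, random_state=random_state)
        )
    return pd.concat(pieces)


df = pd.DataFrame({"a": ["x", "x", "y", "y"], "b": range(4)})
print(sample_by_group(df, "a", n=1, random_state=1))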

pandas/tests/groupby/test_sample.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+import pytest
+
+from pandas import DataFrame, Index, Series
+import pandas._testing as tm
+
+
+@pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)])
+def test_groupby_sample_balanced_groups_shape(n, frac):
+    values = [1] * 10 + [2] * 10
+    df = DataFrame({"a": values, "b": values})
+
+    result = df.groupby("a").sample(n=n, frac=frac)
+    values = [1] * 2 + [2] * 2
+    expected = DataFrame({"a": values, "b": values}, index=result.index)
+    tm.assert_frame_equal(result, expected)
+
+    result = df.groupby("a")["b"].sample(n=n, frac=frac)
+    expected = Series(values, name="b", index=result.index)
+    tm.assert_series_equal(result, expected)
+
+
+def test_groupby_sample_unbalanced_groups_shape():
+    values = [1] * 10 + [2] * 20
+    df = DataFrame({"a": values, "b": values})
+
+    result = df.groupby("a").sample(n=5)
+    values = [1] * 5 + [2] * 5
+    expected = DataFrame({"a": values, "b": values}, index=result.index)
+    tm.assert_frame_equal(result, expected)
+
+    result = df.groupby("a")["b"].sample(n=5)
+    expected = Series(values, name="b", index=result.index)
+    tm.assert_series_equal(result, expected)
+
+
+def test_groupby_sample_index_value_spans_groups():
+    values = [1] * 3 + [2] * 3
+    df = DataFrame({"a": values, "b": values}, index=[1, 2, 2, 2, 2, 2])
+
+    result = df.groupby("a").sample(n=2)
+    values = [1] * 2 + [2] * 2
+    expected = DataFrame({"a": values, "b": values}, index=result.index)
+    tm.assert_frame_equal(result, expected)
+
+    result = df.groupby("a")["b"].sample(n=2)
+    expected = Series(values, name="b", index=result.index)
+    tm.assert_series_equal(result, expected)
+
+
+def test_groupby_sample_n_and_frac_raises():
+    df = DataFrame({"a": [1, 2], "b": [1, 2]})
+    msg = "Please enter a value for `frac` OR `n`, not both"
+
+    with pytest.raises(ValueError, match=msg):
+        df.groupby("a").sample(n=1, frac=1.0)
+
+    with pytest.raises(ValueError, match=msg):
+        df.groupby("a")["b"].sample(n=1, frac=1.0)
+
+
+def test_groupby_sample_frac_gt_one_without_replacement_raises():
+    df = DataFrame({"a": [1, 2], "b": [1, 2]})
+    msg = "Replace has to be set to `True` when upsampling the population `frac` > 1."
+
+    with pytest.raises(ValueError, match=msg):
+        df.groupby("a").sample(frac=1.5, replace=False)
+
+    with pytest.raises(ValueError, match=msg):
+        df.groupby("a")["b"].sample(frac=1.5, replace=False)
+
+
+@pytest.mark.parametrize("n", [-1, 1.5])
+def test_groupby_sample_invalid_n_raises(n):
+    df = DataFrame({"a": [1, 2], "b": [1, 2]})
+
+    if n < 0:
+        msg = "Please provide positive value"
+    else:
+        msg = "Only integers accepted as `n` values"
+
+    with pytest.raises(ValueError, match=msg):
+        df.groupby("a").sample(n=n)
+
+    with pytest.raises(ValueError, match=msg):
+        df.groupby("a")["b"].sample(n=n)
+
+
+def test_groupby_sample_oversample():
+    values = [1] * 10 + [2] * 10
+    df = DataFrame({"a": values, "b": values})
+
+    result = df.groupby("a").sample(frac=2.0, replace=True)
+    values = [1] * 20 + [2] * 20
+    expected = DataFrame({"a": values, "b": values}, index=result.index)
+    tm.assert_frame_equal(result, expected)
+
+    result = df.groupby("a")["b"].sample(frac=2.0, replace=True)
+    expected = Series(values, name="b", index=result.index)
+    tm.assert_series_equal(result, expected)
+
+
+def test_groupby_sample_without_n_or_frac():
+    values = [1] * 10 + [2] * 10
+    df = DataFrame({"a": values, "b": values})
+
+    result = df.groupby("a").sample(n=None, frac=None)
+    expected = DataFrame({"a": [1, 2], "b": [1, 2]}, index=result.index)
+    tm.assert_frame_equal(result, expected)
+
+    result = df.groupby("a")["b"].sample(n=None, frac=None)
+    expected = Series([1, 2], name="b", index=result.index)
+    tm.assert_series_equal(result, expected)
+
+
+def test_groupby_sample_with_weights():
+    values = [1] * 2 + [2] * 2
+    df = DataFrame({"a": values, "b": values}, index=Index(["w", "x", "y", "z"]))
+
+    result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0])
+    expected = DataFrame({"a": values, "b": values}, index=Index(["w", "w", "y", "y"]))
+    tm.assert_frame_equal(result, expected)
+
+    result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0])
+    expected = Series(values, name="b", index=Index(["w", "w", "y", "y"]))
+    tm.assert_series_equal(result, expected)
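
The weights test above relies on the fact that rows given zero weight can never be drawn. A small hedged repro of that behaviour (assumes pandas >= 1.1.0):

import pandas as pd

# Same shape as test_groupby_sample_with_weights: each group holds one row
# with weight 1 and one with weight 0, so sampling two rows with replacement
# can only repeat the positively weighted row.
df = pd.DataFrame({"a": [1, 1, 2, 2], "b": [1, 1, 2, 2]}, index=["w", "x", "y", "z"])
out = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0])
print(out.index.tolist())  # ['w', 'w', 'y', 'y']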

pandas/tests/groupby/test_whitelist.py

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ def test_tab_completion(mframe):
         "rolling",
         "expanding",
         "pipe",
+        "sample",
     }
     assert results == expected
