Skip to content

Commit 4712e4b

Browse files
authored
Merge pull request #1 from echozzy629/ziyi_issue_31175
add a new feature sample() into groupby
2 parents 2e1f5b0 + 7c24ed8 commit 4712e4b

File tree

1 file changed

+147
-1
lines changed

1 file changed

+147
-1
lines changed

pandas/core/groupby/groupby.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1436,6 +1436,152 @@ def describe(self, **kwargs):
14361436
return result.T
14371437
return result.unstack()
14381438

1439+
def sample(groupby_result, size=None, frac=None, replace=False, weights=None):
    """
    Return a random sample of groups as a dictionary.

    The sampling unit is a whole group: keys of ``groupby_result.groups``
    are drawn at random and the matching entries of that mapping are
    returned.

    Parameters
    ----------
    groupby_result : object with a ``groups`` mapping
        The grouped object to sample from. Only its ``groups`` attribute
        is read.
    size : int, optional
        Number of groups to draw. Cannot be used with `frac`.
        Default = 1 if `frac` = None. Must be an integral value between
        0 and the number of groups.
    frac : float, optional
        Fraction of groups to draw, between 0 and 1. Cannot be used with
        `size`. The resulting count is rounded to the nearest integer.
    replace : boolean, optional
        Sample with or without replacement. Default = False.
        Because the result is keyed by group, a group drawn more than
        once appears only once in the returned dictionary.
    weights : list of float, optional
        One weight per group, in the iteration order of
        ``groupby_result.groups``. Default 'None' (or an empty sequence)
        results in equal probability weighting.
        If weights do not sum to 1, they will be normalized to sum to 1.
        Missing values (NaN) in the weights will be treated as zero.
        inf and -inf values not allowed.

    Returns
    -------
    dict
        The entries of ``groupby_result.groups`` whose keys were drawn.

    Raises
    ------
    ValueError
        If both `size` and `frac` are given, if `size` is not integral,
        is negative or exceeds the number of groups, if `frac` is outside
        [0, 1], or if `weights` has the wrong length, contains negative
        or infinite values, or sums to zero.

    Warns
    -----
    UserWarning
        If the requested sample is empty or covers every group.

    Examples
    --------
    Generate an example ``DataFrame``:

    >>> df = pd.DataFrame(
    ...     [['Male', 1], ['Female', 3], ['Female', 2], ['Other', 1]],
    ...     columns=['gender', 'feature'],
    ... )
    >>> grouped_df = df.groupby('gender')

    Draw 2 random groups:

    >>> sample(grouped_df, size=2)  # doctest: +SKIP
    {'Female': Int64Index([1, 2], dtype='int64'), 'Male': Int64Index([0], dtype='int64')}

    Draw 2 random groups with given weights:

    >>> sample(grouped_df, size=2, weights=[0.1, 0.1, 0.2])  # doctest: +SKIP
    {'Male': Int64Index([0], dtype='int64'), 'Other': Int64Index([3], dtype='int64')}

    Draw a random 40% of the groups with replacement:

    >>> sample(grouped_df, frac=0.4, replace=True)  # doctest: +SKIP
    {'Male': Int64Index([0], dtype='int64')}
    """
    # Imported locally so the block does not depend on a file-level import.
    import warnings

    groups_dictionary = groupby_result.groups
    num_of_keys = len(groups_dictionary)

    # --- resolve the number of groups to draw -------------------------------
    if size is not None and frac is not None:
        raise ValueError("Please enter a value for `frac` OR `size`, not both")

    if size is not None:
        if size % 1 != 0:
            raise ValueError("Only integers accepted as size value")
        if size < 0:
            raise ValueError(
                "A negative number of sample size requested. "
                "Please provide a positive value."
            )
        if size > num_of_keys:
            raise ValueError(
                "The size of requested sample is overflow. "
                "Please provide the value of size in range."
            )
        # Integral floats such as 2.0 are accepted above; numpy needs an int.
        final_size = int(size)
    elif frac is not None:
        if frac < 0 or frac > 1:
            raise ValueError("Only float between 0 an 1 accepted as frac value")
        final_size = int(round(frac * num_of_keys))
    else:
        final_size = 1

    # --- edge cases: legitimate requests, so warn instead of raising --------
    if size == 0 or frac == 0:
        warnings.warn(
            "Random sample is empty: the input sample size is 0",
            UserWarning,
            stacklevel=2,
        )
    elif size == num_of_keys or frac == 1:
        warnings.warn(
            "Random sample equals to the given groupbt: the inplut size "
            "is the same as the size of the input group",
            UserWarning,
            stacklevel=2,
        )

    # --- validate and normalise the weights ---------------------------------
    probabilities = None
    # An empty sequence is treated like None (equal weighting), matching the
    # previous truthiness test without being ambiguous for numpy arrays.
    if weights is not None and len(weights) > 0:
        if len(weights) != num_of_keys:
            raise ValueError(
                "Weights and axis to be sampled must be the same length"
            )
        # np.array copies, so the caller's weights are never mutated.
        probabilities = np.array(weights, dtype=float)
        if np.isinf(probabilities).any():
            raise ValueError("Weight vector may not include `inf` values")
        # NaN compares unequal to everything, so it must be found with isnan.
        probabilities[np.isnan(probabilities)] = 0.0
        if (probabilities < 0).any():
            raise ValueError("Weight vector may no include nagative value")
        total = probabilities.sum()
        if total == 0:
            raise ValueError("Invalid weights: weights sum to zero")
        probabilities = probabilities / total

    # --- draw group positions and map them back to group keys ---------------
    dictionary_keys = list(groups_dictionary)
    drawn_positions = np.random.choice(
        num_of_keys, size=final_size, replace=replace, p=probabilities
    )
    sample_keys = {dictionary_keys[i] for i in drawn_positions}
    return {
        key: value
        for key, value in groups_dictionary.items()
        if key in sample_keys
    }
1576+
1577+
1578+
1579+
1580+
1581+
1582+
1583+
1584+
14391585
def resample(self, rule, *args, **kwargs):
14401586
"""
14411587
Provide resampling when using a TimeGrouper.
@@ -2322,7 +2468,7 @@ def shift(self, periods=1, freq=None, axis=0, fill_value=None):
23222468
See Also
23232469
--------
23242470
Index.shift : Shift values of Index.
2325-
tshift : Shift the time index, using the index’s frequency
2471+
tshift : Shift the time index, using the index’s frequency
23262472
if available.
23272473
"""
23282474
if freq is not None or axis != 0 or not isna(fill_value):

0 commit comments

Comments
 (0)