def get_train_test_stratified_by_time_and_col(
    df,
    stratify_split_col_name,
    time_sort_col_name=None,
    test_size=0.2,
    ascending=True,
    get_shuffled=True,
    random_state=0,
):
    """
    Split a DataFrame into train and test sets, stratified by a column.

    Each group of ``stratify_split_col_name`` is split independently: the
    first ``floor(group_size * (1 - test_size))`` rows go to the train set
    and the remaining rows go to the test set. When ``time_sort_col_name``
    is given, rows are sorted by it first, so "first" means earliest
    (``ascending=True``) or latest (``ascending=False``).

    Parameters
    ----------
    df : DataFrame
        Data to split. It is not modified.
    stratify_split_col_name : label
        Column whose groups are each split into train and test.
    time_sort_col_name : label, optional
        Column used to order rows before splitting. A falsy value (the
        default ``None``) keeps the existing row order.
    test_size : float, default 0.2
        Fraction of every group assigned to the test set; must lie in
        ``[0, 1]``.
    ascending : bool, default True
        Sort direction applied to ``time_sort_col_name``.
    get_shuffled : bool, default True
        If True, the rows of both returned frames are shuffled.
    random_state : int, default 0
        Seed used for shuffling when ``get_shuffled`` is True.

    Returns
    -------
    tuple of (DataFrame, DataFrame)
        ``(df_train, df_test)``. Both frames are built from
        ``df.reset_index()``, so the original index is carried along as
        ordinary column(s) and the row labels are the original positions.

    Raises
    ------
    ValueError
        If ``test_size`` is outside ``[0, 1]``.

    Notes
    -----
    The sizes of both splits are printed to stdout.
    ``DataFrame.groupby`` skips rows whose key is missing by default, so
    such rows end up in neither split.
    """
    import math

    if not 0 <= test_size <= 1:
        raise ValueError(
            "test_size must be between 0 and 1, got {}".format(test_size)
        )

    # Work on a copy: an in-place reset would add columns to the caller's
    # frame on every call.
    df = df.reset_index()
    if time_sort_col_name:
        df = df.sort_values(time_sort_col_name, ascending=ascending)

    train_size = 1 - test_size
    train_labels, test_labels = [], []
    for _, group in df.groupby(stratify_split_col_name):
        count = math.floor(group.shape[0] * train_size)
        train_labels.extend(group.index[:count])
        test_labels.extend(group.index[count:])

    print('Train length:', str(len(train_labels)))
    print('Test length:', str(len(test_labels)))

    # The collected values are index *labels*. After sorting they no longer
    # match row positions, so selection must be label-based (.loc).
    df_train, df_test = df.loc[train_labels], df.loc[test_labels]
    if get_shuffled:
        return (
            df_train.sample(frac=1, random_state=random_state),
            df_test.sample(frac=1, random_state=random_state),
        )
    return df_train, df_test