Skip to content

add train_test_split function to DataFrame #6687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import collections
import warnings
import types
import random

from numpy import nan as NA
import numpy as np
Expand Down Expand Up @@ -3836,6 +3837,66 @@ def pretty_name(x):
return self._constructor(lmap(list, zip(*destat)),
index=destat_columns, columns=numdata.columns)

def train_test_split(self, test_rate=0.25, random_state=None):
"""Split pandas DataFrame into random train and test subsets
Parameters
----------
* df : pandas DataFrame

test_rate : float or None (default is None)
If float, should be between 0.0 and 1.0 and represent the
proportion of the dataset to include in the test split.
If train size is also None, test size is set to 0.25.

random_state : int or RandomState
Pseudo-random number generator state used for random sampling. use random.seed

Returns
-------
splitting : list of DataFrame, length=2
List containing train-test split of input Dataframe.

Examples
--------
>>> import numpy as np
>>> import pandas as pd
>>> a = range(10)
>>> b = range(10)
>>> df = pd.DataFrame({'a' : a, 'b' : b})
>>> df_train, df_test = df.train_test_split()
>>> a_train
a b
1 1 1
8 8 8

[2 rows x 2 columns]
>>> b_train
a b
0 0 0
2 2 2
3 3 3
4 4 4
5 5 5
6 6 6
7 7 7
9 9 9

[8 rows x 2 columns]
"""

if test_rate is None:
test_rate = 0.25

test_size = int(len(self) * test_rate)

if random_state:
random.seed(random_state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In pandas.utils.testing there's a class RNGContext that's a context manager for the random state. I'd use that to return the user to their original state once you're done with the splitting, rather than modifying the global state.

test_index = random.sample(self.index, test_size)
df_train = self.ix[test_index]
df_test = self.ix[[i for i in self.index if i not in test_index]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can do boolean indexing (as @jtratner mentions) self.ix[~test_index]

splitting = [df_train, df_test]
return splitting

#----------------------------------------------------------------------
# ndarray-like stats methods

Expand Down