diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index 6220e899ae..a6553d13dc 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -17,7 +17,6 @@ https://scikit-learn.org/stable/modules/classes.html#module-sklearn.model_selection.""" -import typing from typing import cast, List, Union from bigframes.ml import utils @@ -87,7 +86,7 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra merged_df = df.join(stratify.to_frame(), how="outer") train_dfs, test_dfs = [], [] - uniq = stratify.unique() + uniq = stratify.value_counts().index for value in uniq: cur = merged_df[merged_df["bigframes_stratify_col"] == value] train, test = train_test_split( @@ -107,26 +106,20 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra ) return [train_df, test_df] + joined_df = dfs[0] + for df in dfs[1:]: + joined_df = joined_df.join(df, how="outer") if stratify is None: - split_dfs = dfs[0]._split( + joined_df_train, joined_df_test = joined_df._split( fracs=(train_size, test_size), random_state=random_state ) else: - split_dfs = _stratify_split(dfs[0], stratify) - train_index = split_dfs[0].index - test_index = split_dfs[1].index - - split_dfs += typing.cast( - List[bpd.DataFrame], - [df.loc[index] for df in dfs[1:] for index in (train_index, test_index)], - ) - - # convert back to Series. - results: List[Union[bpd.DataFrame, bpd.Series]] = [] - for i, array in enumerate(arrays): - if isinstance(array, bpd.Series): - results += utils.convert_to_series(split_dfs[2 * i], split_dfs[2 * i + 1]) - else: - results += (split_dfs[2 * i], split_dfs[2 * i + 1]) + joined_df_train, joined_df_test = _stratify_split(joined_df, stratify) + + results = [] + for array in arrays: + columns = array.name if isinstance(array, bpd.Series) else array.columns + results.append(joined_df_train[columns]) + results.append(joined_df_test[columns]) return results diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 05ff80dc33..b382a5593c 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -606,6 +606,14 @@ def penguins_df_default_index( return session.read_gbq(penguins_table_id) +@pytest.fixture(scope="session") +def penguins_df_null_index( + penguins_table_id: str, unordered_session: bigframes.Session +) -> bigframes.dataframe.DataFrame: + """DataFrame pointing at test data.""" + return unordered_session.read_gbq(penguins_table_id) + + @pytest.fixture(scope="session") def time_series_df_default_index( time_series_table_id: str, session: bigframes.Session diff --git a/tests/system/small/ml/test_model_selection.py b/tests/system/small/ml/test_model_selection.py index ea9220feb4..47529565b7 100644 --- a/tests/system/small/ml/test_model_selection.py +++ b/tests/system/small/ml/test_model_selection.py @@ -19,15 +19,20 @@ import bigframes.pandas as bpd -def test_train_test_split_default_correct_shape(penguins_df_default_index): - X = penguins_df_default_index[ +@pytest.mark.parametrize( + "df_fixture", + ("penguins_df_default_index", "penguins_df_null_index"), +) +def test_train_test_split_default_correct_shape(df_fixture, request): + df = request.getfixturevalue(df_fixture) + X = df[ [ "species", "island", "culmen_length_mm", ] ] - y = penguins_df_default_index[["body_mass_g"]] + y = df[["body_mass_g"]] X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y) # even though the default seed is random, it should always result in this shape @@ -236,17 +241,18 @@ def test_train_test_split_value_error(penguins_df_default_index, train_size, tes ) -def test_train_test_split_stratify(penguins_df_default_index): - X = penguins_df_default_index[ - [ - "species", - "island", - "culmen_length_mm", - ] - ] - y = penguins_df_default_index[["species"]] +@pytest.mark.parametrize( + "df_fixture", + ("penguins_df_default_index", "penguins_df_null_index"), +) +def test_train_test_split_stratify(df_fixture, request): + df = request.getfixturevalue(df_fixture) + X = df[["species", "island", "culmen_length_mm",]].rename( + columns={"species": "x_species"} + ) # Keep "species" col just for easy checking. Rename to avoid conflicts. + y = df[["species"]] X_train, X_test, y_train, y_test = model_selection.train_test_split( - X, y, stratify=penguins_df_default_index["species"] + X, y, stratify=df["species"] ) # Original distribution is [152, 124, 68]. All the categories follow 75/25 split @@ -277,12 +283,12 @@ def test_train_test_split_stratify(penguins_df_default_index): name="count", ) pd.testing.assert_series_equal( - X_train["species"].value_counts().to_pandas(), + X_train["x_species"].rename("species").value_counts().to_pandas(), train_counts, check_index_type=False, ) pd.testing.assert_series_equal( - X_test["species"].value_counts().to_pandas(), + X_test["x_species"].rename("species").value_counts().to_pandas(), test_counts, check_index_type=False, )