diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 074b17762f..178d698f8d 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -26,7 +26,7 @@ import itertools import random import typing -from typing import Iterable, List, Mapping, Optional, Sequence, Tuple +from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple import warnings import google.cloud.bigquery as bigquery @@ -555,7 +555,7 @@ def _downsample( block = self._split( fracs=(fraction,), random_state=random_state, - preserve_order=True, + sort=False, )[0] return block else: @@ -571,7 +571,7 @@ def _split( fracs: Iterable[float] = (), *, random_state: Optional[int] = None, - preserve_order: Optional[bool] = False, + sort: Optional[bool | Literal["random"]] = "random", ) -> List[Block]: """Internal function to support splitting Block to multiple parts along index axis. @@ -623,7 +623,18 @@ def _split( typing.cast(Block, block.slice(start=lower, stop=upper)) for lower, upper in intervals ] - if preserve_order: + + if sort is True: + sliced_blocks = [ + sliced_block.order_by( + [ + ordering.OrderingColumnReference(idx_col) + for idx_col in sliced_block.index_columns + ] + ) + for sliced_block in sliced_blocks + ] + elif sort is False: sliced_blocks = [ sliced_block.order_by([ordering.OrderingColumnReference(ordering_col)]) for sliced_block in sliced_blocks diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index ee7d78d984..4e447c547f 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2504,6 +2504,7 @@ def sample( frac: Optional[float] = None, *, random_state: Optional[int] = None, + sort: Optional[bool | Literal["random"]] = "random", ) -> DataFrame: if n is not None and frac is not None: raise ValueError("Only one of 'n' or 'frac' parameter can be specified.") @@ -2511,7 +2512,9 @@ def sample( ns = (n,) if n is not None else () fracs = (frac,) if frac is not None else () return DataFrame( - self._block._split(ns=ns, fracs=fracs, random_state=random_state)[0] + self._block._split( + ns=ns, fracs=fracs, random_state=random_state, sort=sort + )[0] ) def _split( diff --git a/bigframes/operations/_matplotlib/core.py b/bigframes/operations/_matplotlib/core.py index 5c9d771f61..7cbeb3df4f 100644 --- a/bigframes/operations/_matplotlib/core.py +++ b/bigframes/operations/_matplotlib/core.py @@ -47,11 +47,11 @@ def _compute_plot_data(self, data): # TODO: Cache the sampling data in the PlotAccessor. sampling_n = self.kwargs.pop("sampling_n", 100) sampling_random_state = self.kwargs.pop("sampling_random_state", 0) - return ( - data.sample(n=sampling_n, random_state=sampling_random_state) - .to_pandas() - .sort_index() - ) + return data.sample( + n=sampling_n, + random_state=sampling_random_state, + sort=False, + ).to_pandas() class LinePlot(SamplingPlot): diff --git a/bigframes/series.py b/bigframes/series.py index f1eabc18fe..5f6cfe9893 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -22,7 +22,7 @@ import os import textwrap import typing -from typing import Any, Mapping, Optional, Tuple, Union +from typing import Any, Literal, Mapping, Optional, Tuple, Union import bigframes_vendored.pandas.core.series as vendored_pandas_series import google.cloud.bigquery as bigquery @@ -1535,6 +1535,7 @@ def sample( frac: Optional[float] = None, *, random_state: Optional[int] = None, + sort: Optional[bool | Literal["random"]] = "random", ) -> Series: if n is not None and frac is not None: raise ValueError("Only one of 'n' or 'frac' parameter can be specified.") @@ -1542,7 +1543,9 @@ def sample( ns = (n,) if n is not None else () fracs = (frac,) if frac is not None else () return Series( - self._block._split(ns=ns, fracs=fracs, random_state=random_state)[0] + self._block._split( + ns=ns, fracs=fracs, random_state=random_state, sort=sort + )[0] ) def __array_ufunc__( diff --git a/tests/system/small/operations/test_plotting.py b/tests/system/small/operations/test_plotting.py index 876c8f7d04..47491cdada 100644 --- a/tests/system/small/operations/test_plotting.py +++ b/tests/system/small/operations/test_plotting.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import pandas as pd import pandas._testing as tm import pytest @@ -235,6 +236,18 @@ def test_sampling_plot_args_random_state(): tm.assert_almost_equal(ax_0.lines[0].get_data()[1], ax_2.lines[0].get_data()[1]) +def test_sampling_preserve_ordering(): + df = bpd.DataFrame([0.0, 1.0, 2.0, 3.0, 4.0], index=[1, 3, 4, 2, 0]) + pd_df = pd.DataFrame([0.0, 1.0, 2.0, 3.0, 4.0], index=[1, 3, 4, 2, 0]) + ax = df.plot.line() + pd_ax = pd_df.plot.line() + tm.assert_almost_equal(ax.get_xticks(), pd_ax.get_xticks()) + tm.assert_almost_equal(ax.get_yticks(), pd_ax.get_yticks()) + for line, pd_line in zip(ax.lines, pd_ax.lines): + # Compare y coordinates between the lines + tm.assert_almost_equal(line.get_data()[1], pd_line.get_data()[1]) + + @pytest.mark.parametrize( ("kind", "col_names", "kwargs"), [ @@ -251,7 +264,7 @@ def test_sampling_plot_args_random_state(): marks=pytest.mark.xfail(raises=ValueError), ), pytest.param( - "uknown", + "bar", ["int64_col", "int64_too"], {}, marks=pytest.mark.xfail(raises=NotImplementedError), diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 54df35c333..3b6cd8c05f 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -3049,6 +3049,28 @@ def test_sample_raises_value_error(scalars_dfs): scalars_df.sample(frac=0.5, n=4) +def test_sample_args_sort(scalars_dfs): + scalars_df, _ = scalars_dfs + index = [4, 3, 2, 5, 1, 0] + scalars_df = scalars_df.iloc[index] + + kwargs = {"frac": 1.0, "random_state": 333} + + df = scalars_df.sample(**kwargs).to_pandas() + assert df.index.values != index + assert df.index.values != sorted(index) + + df = scalars_df.sample(sort="random", **kwargs).to_pandas() + assert df.index.values != index + assert df.index.values != sorted(index) + + df = scalars_df.sample(sort=True, **kwargs).to_pandas() + assert df.index.values == sorted(index) + + df = scalars_df.sample(sort=False, **kwargs).to_pandas() + assert df.index.values == index + + @pytest.mark.parametrize( ("axis",), [ diff --git a/third_party/bigframes_vendored/pandas/core/generic.py b/third_party/bigframes_vendored/pandas/core/generic.py index 4f91c1b19a..d1cf55c95b 100644 --- a/third_party/bigframes_vendored/pandas/core/generic.py +++ b/third_party/bigframes_vendored/pandas/core/generic.py @@ -472,6 +472,7 @@ def sample( frac: Optional[float] = None, *, random_state: Optional[int] = None, + sort: Optional[bool | Literal["random"]] = "random", ): """Return a random sample of items from an axis of object. @@ -530,6 +531,12 @@ def sample( Fraction of axis items to return. Cannot be used with `n`. random_state (Optional[int], default None): Seed for random number generator. + sort (Optional[bool|Literal["random"]], default "random"): + + - 'random' (default): No specific ordering will be applied after + sampling. + - 'True' : Index columns will determine the sample's order. + - 'False': The sample will retain the original object's order. Returns: A new object of same type as caller containing `n` items randomly