diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index abab9fd268..77a1af723f 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -2174,6 +2174,20 @@ def add_suffix(self, suffix: str, axis: int | str | None = None) -> DataFrame:
         axis = 1 if axis is None else axis
         return DataFrame(self._get_block().add_suffix(suffix, axis))
 
+    def take(
+        self, indices: typing.Sequence[int], axis: int | str | None = 0, **kwargs
+    ) -> DataFrame:
+        if not utils.is_list_like(indices):
+            raise ValueError("indices should be a list-like object.")
+        # Positional selection is delegated to iloc: rows for axis 0/"index",
+        # columns for axis 1/"columns". ``kwargs`` is accepted but unused.
+        if axis == 0 or axis == "index":
+            return self.iloc[indices]
+        elif axis == 1 or axis == "columns":
+            return self.iloc[:, indices]
+        else:
+            raise ValueError(f"No axis named {axis} for object type DataFrame")
+
     def filter(
         self,
         items: typing.Optional[typing.Iterable] = None,
diff --git a/bigframes/series.py b/bigframes/series.py
index 34ac3c3de9..ae55faeb74 100644
--- a/bigframes/series.py
+++ b/bigframes/series.py
@@ -1631,6 +1631,15 @@ def add_prefix(self, prefix: str, axis: int | str | None = None) -> Series:
     def add_suffix(self, suffix: str, axis: int | str | None = None) -> Series:
         return Series(self._get_block().add_suffix(suffix))
 
+    def take(
+        self, indices: typing.Sequence[int], axis: int | str | None = 0, **kwargs
+    ) -> Series:
+        if not utils.is_list_like(indices):
+            raise ValueError("indices should be a list-like object.")
+        # ``axis`` and ``kwargs`` are accepted only for signature parity with
+        # DataFrame.take; neither is read here, so both are ignored.
+        return typing.cast(Series, self.iloc[indices])
+
     def filter(
         self,
         items: typing.Optional[typing.Iterable] = None,
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index c2e4a1c8ad..8cc3be1577 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -807,6 +807,24 @@ def test_get_df_column_name_duplicate(scalars_dfs):
     pd.testing.assert_index_equal(bf_result.columns, pd_result.columns)
 
 
+@pytest.mark.parametrize(
+    ("indices", "axis"),
+    [
+        ([1, 3, 5], 0),
+        ([2, 4, 6], 1),
+        ([1, -3, -5, -6], "index"),
+        ([-2, -4, -6], "columns"),
+    ],
+)
+def test_take_df(scalars_dfs, indices, axis):
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    bf_result = scalars_df.take(indices, axis=axis).to_pandas()
+    pd_result = scalars_pandas_df.take(indices, axis=axis)
+
+    assert_pandas_df_equal(bf_result, pd_result)
+
+
 def test_filter_df(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
 
diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py
index ef544b0a0b..5f4e40c96d 100644
--- a/tests/system/small/test_series.py
+++ b/tests/system/small/test_series.py
@@ -1543,6 +1543,24 @@ def test_indexing_using_selected_series(scalars_dfs):
     )
 
 
+@pytest.mark.parametrize(
+    ("indices"),
+    [
+        ([1, 3, 5]),
+        ([5, -3, -5, -6]),
+        ([-2, -4, -6]),
+    ],
+)
+def test_take(scalars_dfs, indices):
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    # Use a single column so Series.take, not DataFrame.take, is under test.
+    bf_result = scalars_df["string_col"].take(indices).to_pandas()
+    pd_result = scalars_pandas_df["string_col"].take(indices)
+
+    assert_pandas_df_equal(bf_result.to_frame(), pd_result.to_frame())
+
+
 def test_nested_filter(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
     string_col = scalars_df["string_col"]
diff --git a/third_party/bigframes_vendored/pandas/core/generic.py b/third_party/bigframes_vendored/pandas/core/generic.py
index ee35bfa429..8dd43fd8da 100644
--- a/third_party/bigframes_vendored/pandas/core/generic.py
+++ b/third_party/bigframes_vendored/pandas/core/generic.py
@@ -910,6 +910,29 @@ def notna(self) -> NDFrame:
 
     notnull = notna
 
+    def take(self, indices, axis=0, **kwargs) -> NDFrame:
+        """Return the elements in the given positional indices along an axis.
+
+        This means that we are not indexing according to actual values in the index
+        attribute of the object. We are indexing according to the actual position of
+        the element in the object.
+
+        Args:
+            indices (list-like):
+                An array of ints indicating which positions to take.
+            axis ({0 or 'index', 1 or 'columns', None}, default 0):
+                The axis on which to select elements. 0 means that we are selecting rows,
+                1 means that we are selecting columns. For Series this parameter is
+                unused and defaults to 0.
+            **kwargs:
+                For compatibility with numpy.take(). Has no effect on the output.
+
+        Returns:
+            bigframes.pandas.DataFrame or bigframes.pandas.Series:
+                Same type as input object.
+        """
+        raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
+
     def filter(
         self,
         items=None,