Skip to content

Commit 07222bf

Browse files
feat: Add __contains__ to Index, Series, DataFrame (#1899)
1 parent 1aa7950 commit 07222bf

File tree

8 files changed

+102
-1
lines changed

8 files changed

+102
-1
lines changed

bigframes/core/indexes/base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616

1717
from __future__ import annotations
1818

19+
import functools
1920
import typing
20-
from typing import Hashable, Literal, Optional, overload, Sequence, Union
21+
from typing import cast, Hashable, Literal, Optional, overload, Sequence, Union
2122

2223
import bigframes_vendored.constants as constants
2324
import bigframes_vendored.pandas.core.indexes.base as vendored_pandas_index
@@ -529,6 +530,29 @@ def isin(self, values) -> Index:
529530
)
530531
).fillna(value=False)
531532

533+
def __contains__(self, key) -> bool:
534+
hash(key) # to throw for unhashable values
535+
if self.nlevels == 0:
536+
return False
537+
538+
if (not isinstance(key, tuple)) or (self.nlevels == 1):
539+
key = (key,)
540+
541+
match_exprs = []
542+
for key_part, index_col, dtype in zip(
543+
key, self._block.index_columns, self._block.index.dtypes
544+
):
545+
key_type = bigframes.dtypes.is_compatible(key_part, dtype)
546+
if key_type is None:
547+
return False
548+
key_expr = ex.const(key_part, key_type)
549+
match_expr = ops.eq_null_match_op.as_expr(ex.deref(index_col), key_expr)
550+
match_exprs.append(match_expr)
551+
552+
match_expr_final = functools.reduce(ops.and_op.as_expr, match_exprs)
553+
block, match_col = self._block.project_expr(match_expr_final)
554+
return cast(bool, block.get_stat(match_col, agg_ops.AnyOp()))
555+
532556
def _apply_unary_expr(
533557
self,
534558
op: ex.Expression,

bigframes/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,9 @@ def __len__(self):
374374
def __iter__(self):
375375
return iter(self.columns)
376376

377+
def __contains__(self, key) -> bool:
378+
return key in self.columns
379+
377380
def astype(
378381
self,
379382
dtype: Union[

bigframes/series.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ def __iter__(self) -> typing.Iterator:
257257
map(lambda x: x.squeeze(axis=1), self._block.to_pandas_batches())
258258
)
259259

260+
def __contains__(self, key) -> bool:
261+
return key in self.index
262+
260263
def copy(self) -> Series:
261264
return Series(self._block)
262265

tests/system/small/test_dataframe.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4451,6 +4451,22 @@ def test_df___array__(scalars_df_index, scalars_pandas_df_index):
44514451
)
44524452

44534453

4454+
@pytest.mark.parametrize(
4455+
("key",),
4456+
[
4457+
("hello",),
4458+
(2,),
4459+
("int64_col",),
4460+
(None,),
4461+
],
4462+
)
4463+
def test_df_contains(scalars_df_index, scalars_pandas_df_index, key):
4464+
bf_result = key in scalars_df_index
4465+
pd_result = key in scalars_pandas_df_index
4466+
4467+
assert bf_result == pd_result
4468+
4469+
44544470
def test_df_getattr_attribute_error_when_pandas_has(scalars_df_index):
44554471
# swapaxes is implemented in pandas but not in bigframes
44564472
with pytest.raises(AttributeError):

tests/system/small/test_index.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,18 @@ def test_index_drop_duplicates(scalars_df_index, scalars_pandas_df_index, keep):
398398
)
399399

400400

401+
@pytest.mark.parametrize(
402+
("key",),
403+
[("hello",), (2,), (123123321,), (2.0,), (False,), ((2,),), (pd.NA,)],
404+
)
405+
def test_index_contains(scalars_df_index, scalars_pandas_df_index, key):
406+
col_name = "int64_col"
407+
bf_result = key in scalars_df_index.set_index(col_name).index
408+
pd_result = key in scalars_pandas_df_index.set_index(col_name).index
409+
410+
assert bf_result == pd_result
411+
412+
401413
def test_index_isin_list(scalars_df_index, scalars_pandas_df_index):
402414
col_name = "int64_col"
403415
bf_series = (

tests/system/small/test_multiindex.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,3 +1388,26 @@ def test_column_multi_index_w_na_stack(scalars_df_index, scalars_pandas_df_index
13881388
# Pandas produces pd.NA, where bq dataframes produces NaN
13891389
pd_result["c"] = pd_result["c"].replace(pandas.NA, np.nan)
13901390
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
1391+
1392+
1393+
@pytest.mark.parametrize(
1394+
("key",),
1395+
[
1396+
("hello",),
1397+
(2,),
1398+
(123123321,),
1399+
(2.0,),
1400+
(pandas.NA,),
1401+
(False,),
1402+
((2,),),
1403+
((2, False),),
1404+
((2.0, False),),
1405+
((2, True),),
1406+
],
1407+
)
1408+
def test_multi_index_contains(scalars_df_index, scalars_pandas_df_index, key):
1409+
col_name = ["int64_col", "bool_col"]
1410+
bf_result = key in scalars_df_index.set_index(col_name).index
1411+
pd_result = key in scalars_pandas_df_index.set_index(col_name).index
1412+
1413+
assert bf_result == pd_result

tests/system/small/test_null_index.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,7 @@ def test_null_index_index_property(scalars_df_null_index):
396396
def test_null_index_transpose(scalars_df_null_index):
397397
with pytest.raises(bigframes.exceptions.NullIndexError):
398398
_ = scalars_df_null_index.T
399+
400+
401+
def test_null_index_contains(scalars_df_null_index):
402+
assert 3 not in scalars_df_null_index

tests/system/small/test_series.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,22 @@ def test_series_get_column_default(scalars_dfs):
424424
assert result == "default_val"
425425

426426

427+
@pytest.mark.parametrize(
428+
("key",),
429+
[
430+
("hello",),
431+
(2,),
432+
("int64_col",),
433+
(None,),
434+
],
435+
)
436+
def test_series_contains(scalars_df_index, scalars_pandas_df_index, key):
437+
bf_result = key in scalars_df_index["int64_col"]
438+
pd_result = key in scalars_pandas_df_index["int64_col"]
439+
440+
assert bf_result == pd_result
441+
442+
427443
def test_series_equals_identical(scalars_df_index, scalars_pandas_df_index):
428444
bf_result = scalars_df_index.int64_col.equals(scalars_df_index.int64_col)
429445
pd_result = scalars_pandas_df_index.int64_col.equals(

0 commit comments

Comments
 (0)