Skip to content

Commit 3076a13

Browse files
authored
chore: add overload to to_datetime and cleanup type errors in tests/system/small/test_pandas.py (#766)
* chore: add overload to to_datetime and cleanup type errors in tests/unit/test_pandas.py * update import
1 parent 7e8296d commit 3076a13

File tree

3 files changed

+36
-24
lines changed

3 files changed

+36
-24
lines changed

bigframes/pandas/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,32 @@ def read_gbq_function(function_name: str):
699699
read_gbq_function.__doc__ = inspect.getdoc(bigframes.session.Session.read_gbq_function)
700700

701701

702+
@typing.overload
703+
def to_datetime(
704+
arg: vendored_pandas_datetimes.local_scalars,
705+
*,
706+
utc: bool = False,
707+
format: Optional[str] = None,
708+
unit: Optional[str] = None,
709+
) -> Union[pandas.Timestamp, datetime]:
710+
...
711+
712+
713+
@typing.overload
714+
def to_datetime(
715+
arg: Union[
716+
vendored_pandas_datetimes.local_iterables,
717+
bigframes.series.Series,
718+
bigframes.dataframe.DataFrame,
719+
],
720+
*,
721+
utc: bool = False,
722+
format: Optional[str] = None,
723+
unit: Optional[str] = None,
724+
) -> bigframes.series.Series:
725+
...
726+
727+
702728
def to_datetime(
703729
arg: Union[
704730
vendored_pandas_datetimes.local_scalars,

tests/system/small/test_pandas.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -394,12 +394,8 @@ def test_cut(scalars_dfs):
394394

395395
# make sure the result is a supported dtype
396396
assert bf_result.dtype == bpd.Int64Dtype()
397-
398-
# TODO(b/340884971): fix type error
399-
bf_result = bf_result.to_pandas() # type: ignore
400397
pd_result = pd_result.astype("Int64")
401-
# TODO(b/340884971): fix type error
402-
pd.testing.assert_series_equal(bf_result, pd_result) # type: ignore
398+
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
403399

404400

405401
def test_cut_default_labels(scalars_dfs):
@@ -529,13 +525,9 @@ def test_qcut(scalars_dfs, q):
529525
scalars_pandas_df["float64_col"], q, labels=False, duplicates="drop"
530526
)
531527
bf_result = bpd.qcut(scalars_df["float64_col"], q, labels=False, duplicates="drop")
532-
533-
# TODO(b/340884971): fix type error
534-
bf_result = bf_result.to_pandas() # type: ignore
535528
pd_result = pd_result.astype("Int64")
536529

537-
# TODO(b/340884971): fix type error
538-
pd.testing.assert_series_equal(bf_result, pd_result) # type: ignore
530+
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
539531

540532

541533
@pytest.mark.parametrize(
@@ -572,9 +564,8 @@ def test_to_datetime_scalar(arg, utc, unit, format):
572564
],
573565
)
574566
def test_to_datetime_iterable(arg, utc, unit, format):
575-
# TODO(b/340884971): fix type error
576567
bf_result = (
577-
bpd.to_datetime(arg, utc=utc, unit=unit, format=format) # type: ignore
568+
bpd.to_datetime(arg, utc=utc, unit=unit, format=format)
578569
.to_pandas()
579570
.astype("datetime64[ns, UTC]" if utc else "datetime64[ns]")
580571
)
@@ -589,9 +580,8 @@ def test_to_datetime_iterable(arg, utc, unit, format):
589580
def test_to_datetime_series(scalars_dfs):
590581
scalars_df, scalars_pandas_df = scalars_dfs
591582
col = "int64_too"
592-
# TODO(b/340884971): fix type error
593583
bf_result = (
594-
bpd.to_datetime(scalars_df[col], unit="s").to_pandas().astype("datetime64[s]") # type: ignore
584+
bpd.to_datetime(scalars_df[col], unit="s").to_pandas().astype("datetime64[s]")
595585
)
596586
pd_result = pd.Series(pd.to_datetime(scalars_pandas_df[col], unit="s"))
597587
pd.testing.assert_series_equal(
@@ -614,8 +604,7 @@ def test_to_datetime_series(scalars_dfs):
614604
],
615605
)
616606
def test_to_datetime_unit_param(arg, unit):
617-
# TODO(b/340884971): fix type error
618-
bf_result = bpd.to_datetime(arg, unit=unit).to_pandas().astype("datetime64[ns]") # type: ignore
607+
bf_result = bpd.to_datetime(arg, unit=unit).to_pandas().astype("datetime64[ns]")
619608
pd_result = pd.Series(pd.to_datetime(arg, unit=unit)).dt.floor("us")
620609
pd.testing.assert_series_equal(
621610
bf_result, pd_result, check_index_type=False, check_names=False
@@ -632,9 +621,8 @@ def test_to_datetime_unit_param(arg, unit):
632621
],
633622
)
634623
def test_to_datetime_format_param(arg, utc, format):
635-
# TODO(b/340884971): fix type error
636624
bf_result = (
637-
bpd.to_datetime(arg, utc=utc, format=format) # type: ignore
625+
bpd.to_datetime(arg, utc=utc, format=format)
638626
.to_pandas()
639627
.astype("datetime64[ns, UTC]" if utc else "datetime64[ns]")
640628
)
@@ -686,9 +674,8 @@ def test_to_datetime_format_param(arg, utc, format):
686674
],
687675
)
688676
def test_to_datetime_string_inputs(arg, utc, output_in_utc, format):
689-
# TODO(b/340884971): fix type error
690677
bf_result = (
691-
bpd.to_datetime(arg, utc=utc, format=format) # type: ignore
678+
bpd.to_datetime(arg, utc=utc, format=format)
692679
.to_pandas()
693680
.astype("datetime64[ns, UTC]" if output_in_utc else "datetime64[ns]")
694681
)
@@ -730,9 +717,8 @@ def test_to_datetime_string_inputs(arg, utc, output_in_utc, format):
730717
],
731718
)
732719
def test_to_datetime_timestamp_inputs(arg, utc, output_in_utc):
733-
# TODO(b/340884971): fix type error
734720
bf_result = (
735-
bpd.to_datetime(arg, utc=utc) # type: ignore
721+
bpd.to_datetime(arg, utc=utc)
736722
.to_pandas()
737723
.astype("datetime64[ns, UTC]" if output_in_utc else "datetime64[ns]")
738724
)

third_party/bigframes_vendored/pandas/core/tools/datetimes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Contains code from https://github.com/pandas-dev/pandas/blob/main/pandas/core/tools/datetimes.py
22

33
from datetime import datetime
4-
from typing import Iterable, Mapping, Union
4+
from typing import List, Mapping, Tuple, Union
55

66
import pandas as pd
77

88
from bigframes import constants, series
99

1010
local_scalars = Union[int, float, str, datetime]
11-
local_iterables = Union[Iterable, pd.Series, pd.DataFrame, Mapping]
11+
local_iterables = Union[List, Tuple, pd.Series, pd.DataFrame, Mapping]
1212

1313

1414
def to_datetime(

0 commit comments

Comments
 (0)