diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py index bce2bf0..f1424fb 100644 --- a/db_dtypes/__init__.py +++ b/db_dtypes/__init__.py @@ -17,6 +17,7 @@ import datetime import re +from typing import Union import numpy import packaging.version @@ -29,6 +30,7 @@ import pandas.core.dtypes.generic import pandas.core.nanops import pyarrow +import pyarrow.compute from db_dtypes.version import __version__ from db_dtypes import core @@ -36,6 +38,8 @@ date_dtype_name = "dbdate" time_dtype_name = "dbtime" +_EPOCH = datetime.datetime(1970, 1, 1) +_NPEPOCH = numpy.datetime64(_EPOCH) pandas_release = packaging.version.parse(pandas.__version__).release @@ -52,6 +56,33 @@ class TimeDtype(core.BaseDatetimeDtype): def construct_array_type(self): return TimeArray + @staticmethod + def __from_arrow__( + array: Union[pyarrow.Array, pyarrow.ChunkedArray] + ) -> "TimeArray": + """Convert to dbtime data from an Arrow array. + + See: + https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow + """ + # We can't call combine_chunks on an empty array, so short-circuit the + # rest of the function logic for this special case. + if len(array) == 0: + return TimeArray(numpy.array([], dtype="datetime64[ns]")) + + # We can't cast to timestamp("ns"), but time64("ns") has the same + # memory layout: 64-bit integers representing the number of nanoseconds + # since the datetime epoch (midnight 1970-01-01). + array = pyarrow.compute.cast(array, pyarrow.time64("ns")) + + # ChunkedArray has no "view" method, so combine into an Array. + if isinstance(array, pyarrow.ChunkedArray): + array = array.combine_chunks() + + array = array.view(pyarrow.timestamp("ns")) + np_array = array.to_numpy(zero_copy_only=False) + return TimeArray(np_array) + class TimeArray(core.BaseDatetimeArray): """ @@ -61,8 +92,6 @@ class TimeArray(core.BaseDatetimeArray): # Data are stored as datetime64 values with a date of Jan 1, 1970 dtype = TimeDtype() - _epoch = datetime.datetime(1970, 1, 1) - _npepoch = numpy.datetime64(_epoch) @classmethod def _datetime( @@ -75,8 +104,21 @@ def _datetime( r"(?:\.(?P\d*))?)?)?\s*$" ).match, ): - if isinstance(scalar, datetime.time): - return datetime.datetime.combine(cls._epoch, scalar) + # Convert pyarrow values to datetime.time. + if isinstance(scalar, (pyarrow.Time32Scalar, pyarrow.Time64Scalar)): + scalar = ( + scalar.cast(pyarrow.time64("ns")) + .cast(pyarrow.int64()) + .cast(pyarrow.timestamp("ns")) + .as_py() + ) + + if scalar is None: + return None + elif isinstance(scalar, datetime.time): + return datetime.datetime.combine(_EPOCH, scalar) + elif isinstance(scalar, pandas.Timestamp): + return scalar.to_datetime64() elif isinstance(scalar, str): # iso string parsed = match_fn(scalar) @@ -113,7 +155,7 @@ def _box_func(self, x): __return_deltas = {"timedelta", "timedelta64", "timedelta64[ns]", " "DateArray": + """Convert to dbdate data from an Arrow array. + + See: + https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow + """ + array = pyarrow.compute.cast(array, pyarrow.timestamp("ns")) + np_array = array.to_numpy() + return DateArray(np_array) + class DateArray(core.BaseDatetimeArray): """ @@ -161,7 +226,13 @@ def _datetime( scalar, match_fn=re.compile(r"\s*(?P\d+)-(?P\d+)-(?P\d+)\s*$").match, ): - if isinstance(scalar, datetime.date): + # Convert pyarrow values to datetime.date. + if isinstance(scalar, (pyarrow.Date32Scalar, pyarrow.Date64Scalar)): + scalar = scalar.as_py() + + if scalar is None: + return None + elif isinstance(scalar, datetime.date): return datetime.datetime(scalar.year, scalar.month, scalar.day) elif isinstance(scalar, str): match = match_fn(scalar) @@ -197,8 +268,14 @@ def astype(self, dtype, copy=True): return super().astype(dtype, copy=copy) def __arrow_array__(self, type=None): - return pyarrow.array( - self._ndarray, type=type if type is not None else pyarrow.date32(), + """Convert to an Arrow array from dbdate data. + + See: + https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow + """ + array = pyarrow.array(self._ndarray, type=pyarrow.timestamp("ns")) + return pyarrow.compute.cast( + array, type if type is not None else pyarrow.date32(), ) def __add__(self, other): @@ -206,7 +283,7 @@ def __add__(self, other): return self.astype("object") + other if isinstance(other, TimeArray): - return (other._ndarray - other._npepoch) + self._ndarray + return (other._ndarray - _NPEPOCH) + self._ndarray return super().__add__(other) diff --git a/db_dtypes/core.py b/db_dtypes/core.py index fbc784e..c8f3ad4 100644 --- a/db_dtypes/core.py +++ b/db_dtypes/core.py @@ -17,6 +17,7 @@ import numpy import pandas from pandas._libs import NaT +import pandas.api.extensions import pandas.compat.numpy.function import pandas.core.algorithms import pandas.core.arrays @@ -32,7 +33,7 @@ pandas_release = pandas_backports.pandas_release -class BaseDatetimeDtype(pandas.core.dtypes.base.ExtensionDtype): +class BaseDatetimeDtype(pandas.api.extensions.ExtensionDtype): na_value = NaT kind = "o" names = None @@ -60,10 +61,7 @@ def __init__(self, values, dtype=None, copy: bool = False): @classmethod def __ndarray(cls, scalars): - return numpy.array( - [None if scalar is None else cls._datetime(scalar) for scalar in scalars], - "M8[ns]", - ) + return numpy.array([cls._datetime(scalar) for scalar in scalars], "M8[ns]",) @classmethod def _from_sequence(cls, scalars, *, dtype=None, copy=False): diff --git a/tests/unit/test_arrow.py b/tests/unit/test_arrow.py index d3745ea..5f45a90 100644 --- a/tests/unit/test_arrow.py +++ b/tests/unit/test_arrow.py @@ -13,160 +13,314 @@ # limitations under the License. import datetime as dt +from typing import Optional import pandas +import pandas.api.extensions +import pandas.testing import pyarrow import pytest -# To register the types. -import db_dtypes # noqa +import db_dtypes -@pytest.mark.parametrize( - ("series", "expected"), +SECOND_NANOS = 1_000_000_000 +MINUTE_NANOS = 60 * SECOND_NANOS +HOUR_NANOS = 60 * MINUTE_NANOS + + +def types_mapper( + pyarrow_type: pyarrow.DataType, +) -> Optional[pandas.api.extensions.ExtensionDtype]: + type_str = str(pyarrow_type) + + if type_str.startswith("date32") or type_str.startswith("date64"): + return db_dtypes.DateDtype + elif type_str.startswith("time32") or type_str.startswith("time64"): + return db_dtypes.TimeDtype + else: + # Use default type mapping. + return None + + +SERIES_ARRAYS_DEFAULT_TYPES = [ + (pandas.Series([], dtype="dbdate"), pyarrow.array([], type=pyarrow.date32())), ( - (pandas.Series([], dtype="dbdate"), pyarrow.array([], type=pyarrow.date32())), - ( - pandas.Series([None, None, None], dtype="dbdate"), - pyarrow.array([None, None, None], type=pyarrow.date32()), - ), - ( - pandas.Series( - [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], dtype="dbdate" - ), - pyarrow.array( - [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], - type=pyarrow.date32(), - ), + pandas.Series([None, None, None], dtype="dbdate"), + pyarrow.array([None, None, None], type=pyarrow.date32()), + ), + ( + pandas.Series( + [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], dtype="dbdate" ), - ( - pandas.Series( - [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], - dtype="dbdate", - ), - pyarrow.array( - [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], - type=pyarrow.date32(), - ), + pyarrow.array( + [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], type=pyarrow.date32(), ), - ( - pandas.Series([], dtype="dbtime"), - pyarrow.array([], type=pyarrow.time64("ns")), + ), + ( + pandas.Series( + [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], + dtype="dbdate", ), - ( - pandas.Series([None, None, None], dtype="dbtime"), - pyarrow.array([None, None, None], type=pyarrow.time64("ns")), + pyarrow.array( + [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], + type=pyarrow.date32(), ), - ( - pandas.Series( - [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], - dtype="dbtime", - ), - pyarrow.array( - [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], - type=pyarrow.time64("ns"), - ), + ), + (pandas.Series([], dtype="dbtime"), pyarrow.array([], type=pyarrow.time64("ns")),), + ( + pandas.Series([None, None, None], dtype="dbtime"), + pyarrow.array([None, None, None], type=pyarrow.time64("ns")), + ), + ( + pandas.Series( + [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], dtype="dbtime", ), - ( - pandas.Series( - [ - dt.time(0, 0, 0, 0), - dt.time(12, 30, 15, 125_000), - dt.time(23, 59, 59, 999_999), - ], - dtype="dbtime", - ), - pyarrow.array( - [ - dt.time(0, 0, 0, 0), - dt.time(12, 30, 15, 125_000), - dt.time(23, 59, 59, 999_999), - ], - type=pyarrow.time64("ns"), - ), + pyarrow.array( + [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], + type=pyarrow.time64("ns"), ), ), -) + ( + pandas.Series( + [ + dt.time(0, 0, 0, 0), + dt.time(12, 30, 15, 125_000), + dt.time(23, 59, 59, 999_999), + ], + dtype="dbtime", + ), + pyarrow.array( + [ + dt.time(0, 0, 0, 0), + dt.time(12, 30, 15, 125_000), + dt.time(23, 59, 59, 999_999), + ], + type=pyarrow.time64("ns"), + ), + ), +] +SERIES_ARRAYS_CUSTOM_ARROW_TYPES = [ + (pandas.Series([], dtype="dbdate"), pyarrow.array([], type=pyarrow.date64())), + ( + pandas.Series([None, None, None], dtype="dbdate"), + pyarrow.array([None, None, None], type=pyarrow.date64()), + ), + ( + pandas.Series( + [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], dtype="dbdate" + ), + pyarrow.array( + [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], type=pyarrow.date64(), + ), + ), + ( + pandas.Series( + [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], + dtype="dbdate", + ), + pyarrow.array( + [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], + type=pyarrow.date64(), + ), + ), + (pandas.Series([], dtype="dbtime"), pyarrow.array([], type=pyarrow.time32("ms")),), + ( + pandas.Series([None, None, None], dtype="dbtime"), + pyarrow.array([None, None, None], type=pyarrow.time32("ms")), + ), + ( + pandas.Series( + [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_000)], dtype="dbtime", + ), + pyarrow.array( + [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_000)], + type=pyarrow.time32("ms"), + ), + ), + ( + pandas.Series( + [ + dt.time(0, 0, 0, 0), + dt.time(12, 30, 15, 125_000), + dt.time(23, 59, 59, 999_000), + ], + dtype="dbtime", + ), + pyarrow.array( + [ + dt.time(0, 0, 0, 0), + dt.time(12, 30, 15, 125_000), + dt.time(23, 59, 59, 999_000), + ], + type=pyarrow.time32("ms"), + ), + ), + ( + pandas.Series( + [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], dtype="dbtime", + ), + pyarrow.array( + [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], + type=pyarrow.time64("us"), + ), + ), + ( + pandas.Series( + [ + dt.time(0, 0, 0, 0), + dt.time(12, 30, 15, 125_000), + dt.time(23, 59, 59, 999_999), + ], + dtype="dbtime", + ), + pyarrow.array( + [ + dt.time(0, 0, 0, 0), + dt.time(12, 30, 15, 125_000), + dt.time(23, 59, 59, 999_999), + ], + type=pyarrow.time64("us"), + ), + ), + ( + pandas.Series( + [ + # Only microseconds are supported when reading data. See: + # https://github.com/googleapis/python-db-dtypes-pandas/issues/19 + # Still, round-trip with pyarrow nanosecond precision scalars + # is supported. + pyarrow.scalar(0, pyarrow.time64("ns")), + pyarrow.scalar( + 12 * HOUR_NANOS + + 30 * MINUTE_NANOS + + 15 * SECOND_NANOS + + 123_456_789, + pyarrow.time64("ns"), + ), + pyarrow.scalar( + 23 * HOUR_NANOS + + 59 * MINUTE_NANOS + + 59 * SECOND_NANOS + + 999_999_999, + pyarrow.time64("ns"), + ), + ], + dtype="dbtime", + ), + pyarrow.array( + [ + 0, + 12 * HOUR_NANOS + 30 * MINUTE_NANOS + 15 * SECOND_NANOS + 123_456_789, + 23 * HOUR_NANOS + 59 * MINUTE_NANOS + 59 * SECOND_NANOS + 999_999_999, + ], + type=pyarrow.time64("ns"), + ), + ), +] + + +@pytest.mark.parametrize(("series", "expected"), SERIES_ARRAYS_DEFAULT_TYPES) def test_to_arrow(series, expected): array = pyarrow.array(series) assert array.equals(expected) +@pytest.mark.parametrize(("series", "expected"), SERIES_ARRAYS_CUSTOM_ARROW_TYPES) +def test_to_arrow_w_arrow_type(series, expected): + array = pyarrow.array(series, type=expected.type) + assert array.equals(expected) + + @pytest.mark.parametrize( - ("series", "expected"), - ( - (pandas.Series([], dtype="dbdate"), pyarrow.array([], type=pyarrow.date64())), - ( - pandas.Series([None, None, None], dtype="dbdate"), - pyarrow.array([None, None, None], type=pyarrow.date64()), - ), - ( - pandas.Series( - [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], dtype="dbdate" - ), - pyarrow.array( - [dt.date(2021, 9, 27), None, dt.date(2011, 9, 27)], - type=pyarrow.date64(), - ), - ), - ( - pandas.Series( - [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], - dtype="dbdate", - ), - pyarrow.array( - [dt.date(1677, 9, 22), dt.date(1970, 1, 1), dt.date(2262, 4, 11)], - type=pyarrow.date64(), - ), - ), - ( - pandas.Series([], dtype="dbtime"), - pyarrow.array([], type=pyarrow.time32("ms")), - ), - ( - pandas.Series([None, None, None], dtype="dbtime"), - pyarrow.array([None, None, None], type=pyarrow.time32("ms")), - ), - ( - pandas.Series( - [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_000)], - dtype="dbtime", - ), + ["expected", "pyarrow_array"], + SERIES_ARRAYS_DEFAULT_TYPES + SERIES_ARRAYS_CUSTOM_ARROW_TYPES, +) +def test_series_from_arrow(pyarrow_array: pyarrow.Array, expected: pandas.Series): + # Convert to RecordBatch because types_mapper argument is ignored when + # using a pyarrow.Array. https://issues.apache.org/jira/browse/ARROW-9664 + record_batch = pyarrow.RecordBatch.from_arrays([pyarrow_array], ["test_col"]) + dataframe = record_batch.to_pandas(date_as_object=False, types_mapper=types_mapper) + series = dataframe["test_col"] + pandas.testing.assert_series_equal(series, expected, check_names=False) + + +@pytest.mark.parametrize( + ["expected", "pyarrow_array"], + SERIES_ARRAYS_DEFAULT_TYPES + SERIES_ARRAYS_CUSTOM_ARROW_TYPES, +) +def test_series_from_arrow_scalars( + pyarrow_array: pyarrow.Array, expected: pandas.Series +): + scalars = [] + for scalar in pyarrow_array: + scalars.append(scalar) + assert isinstance(scalar, pyarrow.Scalar) + series = pandas.Series(scalars, dtype=expected.dtype) + pandas.testing.assert_series_equal(series, expected) + + +def test_dbtime_series_from_arrow_array(): + """Test to explicitly check Array -> Series conversion.""" + array = pyarrow.array([dt.time(15, 21, 0, 123_456)], type=pyarrow.time64("us")) + assert isinstance(array, pyarrow.Array) + assert not isinstance(array, pyarrow.ChunkedArray) + series = pandas.Series(db_dtypes.TimeDtype.__from_arrow__(array)) + expected = pandas.Series([dt.time(15, 21, 0, 123_456)], dtype="dbtime") + pandas.testing.assert_series_equal(series, expected) + + +def test_dbtime_series_from_arrow_chunkedarray(): + """Test to explicitly check ChunkedArray -> Series conversion.""" + array1 = pyarrow.array([dt.time(15, 21, 0, 123_456)], type=pyarrow.time64("us")) + array2 = pyarrow.array([dt.time(0, 0, 0, 0)], type=pyarrow.time64("us")) + array = pyarrow.chunked_array([array1, array2]) + assert isinstance(array, pyarrow.ChunkedArray) + series = pandas.Series(db_dtypes.TimeDtype.__from_arrow__(array)) + expected = pandas.Series( + [dt.time(15, 21, 0, 123_456), dt.time(0, 0, 0, 0)], dtype="dbtime" + ) + pandas.testing.assert_series_equal(series, expected) + + +def test_dataframe_from_arrow(): + record_batch = pyarrow.RecordBatch.from_arrays( + [ pyarrow.array( - [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_000)], - type=pyarrow.time32("ms"), - ), - ), - ( - pandas.Series( - [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], - dtype="dbtime", + [dt.date(2021, 11, 4), dt.date(2038, 1, 20), None, dt.date(1970, 1, 1)], + type=pyarrow.date32(), ), pyarrow.array( - [dt.time(0, 0, 0, 0), None, dt.time(23, 59, 59, 999_999)], - type=pyarrow.time64("us"), - ), - ), - ( - pandas.Series( [ - dt.time(0, 0, 0, 0), - dt.time(12, 30, 15, 125_000), + dt.time(10, 7, 8, 995_325), dt.time(23, 59, 59, 999_999), + None, + dt.time(0, 0, 0, 0), ], - dtype="dbtime", + type=pyarrow.time64("us"), ), - pyarrow.array( + pyarrow.array([1, 2, 3, 4]), + ], + ["date_col", "time_col", "int_col"], + ) + dataframe = record_batch.to_pandas(date_as_object=False, types_mapper=types_mapper) + expected = pandas.DataFrame( + { + "date_col": pandas.Series( + [dt.date(2021, 11, 4), dt.date(2038, 1, 20), None, dt.date(1970, 1, 1)], + dtype="dbdate", + ), + "time_col": pandas.Series( [ - dt.time(0, 0, 0, 0), - dt.time(12, 30, 15, 125_000), + dt.time(10, 7, 8, 995_325), dt.time(23, 59, 59, 999_999), + None, + dt.time(0, 0, 0, 0), ], - type=pyarrow.time64("us"), + dtype="dbtime", ), - ), - ), -) -def test_to_arrow_w_arrow_type(series, expected): - array = pyarrow.array(series, type=expected.type) - assert array.equals(expected) + "int_col": [1, 2, 3, 4], + }, + columns=["date_col", "time_col", "int_col"], + ) + pandas.testing.assert_frame_equal(dataframe, expected)