Skip to content

Commit facc7b0

Browse files
authored
feat: support conversion from pyarrow RecordBatch to pandas DataFrame (#39)
* feat: support conversion from pyarrow RecordBatch to pandas DataFrame * hack together working implementation TODO: add tests for constructing pandas Series with pyarrow scalars * fix unit test coverage, optimize arrow to numpy conversion * apply same optimizations to to_arrow conversion * remove redundant to_numpy now that to_arrow doesn't use it * be explicit about chunked array vs array * add docstrings to arrow conversion functions * add test case for round-trip to/from pyarrow nanosecond-precision time scalars * add time32("ms") test case without nulls for completeness
1 parent a31d55d commit facc7b0

File tree

3 files changed

+375
-146
lines changed

3 files changed

+375
-146
lines changed

db_dtypes/__init__.py

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import datetime
1919
import re
20+
from typing import Union
2021

2122
import numpy
2223
import packaging.version
@@ -29,13 +30,16 @@
2930
import pandas.core.dtypes.generic
3031
import pandas.core.nanops
3132
import pyarrow
33+
import pyarrow.compute
3234

3335
from db_dtypes.version import __version__
3436
from db_dtypes import core
3537

3638

3739
date_dtype_name = "dbdate"
3840
time_dtype_name = "dbtime"
41+
_EPOCH = datetime.datetime(1970, 1, 1)
42+
_NPEPOCH = numpy.datetime64(_EPOCH)
3943

4044
pandas_release = packaging.version.parse(pandas.__version__).release
4145

@@ -52,6 +56,33 @@ class TimeDtype(core.BaseDatetimeDtype):
5256
def construct_array_type(self):
5357
return TimeArray
5458

59+
@staticmethod
60+
def __from_arrow__(
61+
array: Union[pyarrow.Array, pyarrow.ChunkedArray]
62+
) -> "TimeArray":
63+
"""Convert to dbtime data from an Arrow array.
64+
65+
See:
66+
https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
67+
"""
68+
# We can't call combine_chunks on an empty array, so short-circuit the
69+
# rest of the function logic for this special case.
70+
if len(array) == 0:
71+
return TimeArray(numpy.array([], dtype="datetime64[ns]"))
72+
73+
# We can't cast to timestamp("ns"), but time64("ns") has the same
74+
# memory layout: 64-bit integers representing the number of nanoseconds
75+
# since the datetime epoch (midnight 1970-01-01).
76+
array = pyarrow.compute.cast(array, pyarrow.time64("ns"))
77+
78+
# ChunkedArray has no "view" method, so combine into an Array.
79+
if isinstance(array, pyarrow.ChunkedArray):
80+
array = array.combine_chunks()
81+
82+
array = array.view(pyarrow.timestamp("ns"))
83+
np_array = array.to_numpy(zero_copy_only=False)
84+
return TimeArray(np_array)
85+
5586

5687
class TimeArray(core.BaseDatetimeArray):
5788
"""
@@ -61,8 +92,6 @@ class TimeArray(core.BaseDatetimeArray):
6192
# Data are stored as datetime64 values with a date of Jan 1, 1970
6293

6394
dtype = TimeDtype()
64-
_epoch = datetime.datetime(1970, 1, 1)
65-
_npepoch = numpy.datetime64(_epoch)
6695

6796
@classmethod
6897
def _datetime(
@@ -75,8 +104,21 @@ def _datetime(
75104
r"(?:\.(?P<fraction>\d*))?)?)?\s*$"
76105
).match,
77106
):
78-
if isinstance(scalar, datetime.time):
79-
return datetime.datetime.combine(cls._epoch, scalar)
107+
# Convert pyarrow values to datetime.time.
108+
if isinstance(scalar, (pyarrow.Time32Scalar, pyarrow.Time64Scalar)):
109+
scalar = (
110+
scalar.cast(pyarrow.time64("ns"))
111+
.cast(pyarrow.int64())
112+
.cast(pyarrow.timestamp("ns"))
113+
.as_py()
114+
)
115+
116+
if scalar is None:
117+
return None
118+
elif isinstance(scalar, datetime.time):
119+
return datetime.datetime.combine(_EPOCH, scalar)
120+
elif isinstance(scalar, pandas.Timestamp):
121+
return scalar.to_datetime64()
80122
elif isinstance(scalar, str):
81123
# iso string
82124
parsed = match_fn(scalar)
@@ -113,7 +155,7 @@ def _box_func(self, x):
113155
__return_deltas = {"timedelta", "timedelta64", "timedelta64[ns]", "<m8", "<m8[ns]"}
114156

115157
def astype(self, dtype, copy=True):
116-
deltas = self._ndarray - self._npepoch
158+
deltas = self._ndarray - _NPEPOCH
117159
stype = str(dtype)
118160
if stype in self.__return_deltas:
119161
return deltas
@@ -122,15 +164,25 @@ def astype(self, dtype, copy=True):
122164
else:
123165
return super().astype(dtype, copy=copy)
124166

125-
if pandas_release < (1,):
167+
def __arrow_array__(self, type=None):
168+
"""Convert to an Arrow array from dbtime data.
126169
127-
def to_numpy(self, dtype="object"):
128-
return self.astype(dtype)
170+
See:
171+
https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
172+
"""
173+
array = pyarrow.array(self._ndarray, type=pyarrow.timestamp("ns"))
129174

130-
def __arrow_array__(self, type=None):
131-
return pyarrow.array(
132-
self.to_numpy(dtype="object"),
133-
type=type if type is not None else pyarrow.time64("ns"),
175+
# ChunkedArray has no "view" method, so combine into an Array.
176+
array = (
177+
array.combine_chunks() if isinstance(array, pyarrow.ChunkedArray) else array
178+
)
179+
180+
# We can't cast to time64("ns"), but timestamp("ns") has the same
181+
# memory layout: 64-bit integers representing the number of nanoseconds
182+
# since the datetime epoch (midnight 1970-01-01).
183+
array = array.view(pyarrow.time64("ns"))
184+
return pyarrow.compute.cast(
185+
array, type if type is not None else pyarrow.time64("ns"),
134186
)
135187

136188

@@ -146,6 +198,19 @@ class DateDtype(core.BaseDatetimeDtype):
146198
def construct_array_type(self):
147199
return DateArray
148200

201+
@staticmethod
202+
def __from_arrow__(
203+
array: Union[pyarrow.Array, pyarrow.ChunkedArray]
204+
) -> "DateArray":
205+
"""Convert to dbdate data from an Arrow array.
206+
207+
See:
208+
https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
209+
"""
210+
array = pyarrow.compute.cast(array, pyarrow.timestamp("ns"))
211+
np_array = array.to_numpy()
212+
return DateArray(np_array)
213+
149214

150215
class DateArray(core.BaseDatetimeArray):
151216
"""
@@ -161,7 +226,13 @@ def _datetime(
161226
scalar,
162227
match_fn=re.compile(r"\s*(?P<year>\d+)-(?P<month>\d+)-(?P<day>\d+)\s*$").match,
163228
):
164-
if isinstance(scalar, datetime.date):
229+
# Convert pyarrow values to datetime.date.
230+
if isinstance(scalar, (pyarrow.Date32Scalar, pyarrow.Date64Scalar)):
231+
scalar = scalar.as_py()
232+
233+
if scalar is None:
234+
return None
235+
elif isinstance(scalar, datetime.date):
165236
return datetime.datetime(scalar.year, scalar.month, scalar.day)
166237
elif isinstance(scalar, str):
167238
match = match_fn(scalar)
@@ -197,16 +268,22 @@ def astype(self, dtype, copy=True):
197268
return super().astype(dtype, copy=copy)
198269

199270
def __arrow_array__(self, type=None):
200-
return pyarrow.array(
201-
self._ndarray, type=type if type is not None else pyarrow.date32(),
271+
"""Convert to an Arrow array from dbdate data.
272+
273+
See:
274+
https://pandas.pydata.org/pandas-docs/stable/development/extending.html#compatibility-with-apache-arrow
275+
"""
276+
array = pyarrow.array(self._ndarray, type=pyarrow.timestamp("ns"))
277+
return pyarrow.compute.cast(
278+
array, type if type is not None else pyarrow.date32(),
202279
)
203280

204281
def __add__(self, other):
205282
if isinstance(other, pandas.DateOffset):
206283
return self.astype("object") + other
207284

208285
if isinstance(other, TimeArray):
209-
return (other._ndarray - other._npepoch) + self._ndarray
286+
return (other._ndarray - _NPEPOCH) + self._ndarray
210287

211288
return super().__add__(other)
212289

db_dtypes/core.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy
1818
import pandas
1919
from pandas._libs import NaT
20+
import pandas.api.extensions
2021
import pandas.compat.numpy.function
2122
import pandas.core.algorithms
2223
import pandas.core.arrays
@@ -32,7 +33,7 @@
3233
pandas_release = pandas_backports.pandas_release
3334

3435

35-
class BaseDatetimeDtype(pandas.core.dtypes.base.ExtensionDtype):
36+
class BaseDatetimeDtype(pandas.api.extensions.ExtensionDtype):
3637
na_value = NaT
3738
kind = "o"
3839
names = None
@@ -60,10 +61,7 @@ def __init__(self, values, dtype=None, copy: bool = False):
6061

6162
@classmethod
6263
def __ndarray(cls, scalars):
63-
return numpy.array(
64-
[None if scalar is None else cls._datetime(scalar) for scalar in scalars],
65-
"M8[ns]",
66-
)
64+
return numpy.array([cls._datetime(scalar) for scalar in scalars], "M8[ns]",)
6765

6866
@classmethod
6967
def _from_sequence(cls, scalars, *, dtype=None, copy=False):

0 commit comments

Comments
 (0)