Skip to content

Commit c212e3c

Browse files
Refactor pyarrow import handling to raise informative ImportError if not installed
1 parent 7899f02 commit c212e3c

File tree

2 files changed

+104 -27 lines changed

pandas/core/dtypes/dtypes.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
TYPE_CHECKING,
1717
Any,
1818
cast,
19-
List,
2019
)
2120
import warnings
2221
import zoneinfo
@@ -676,7 +675,7 @@ def _is_boolean(self) -> bool:
676675

677676
return is_bool_dtype(self.categories)
678677

679-
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> DtypeObj | None:
678+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
680679
# check if we have all categorical dtype with identical categories
681680
if all(isinstance(x, CategoricalDtype) for x in dtypes):
682681
first = dtypes[0]
@@ -967,7 +966,7 @@ def __setstate__(self, state) -> None:
967966
self._tz = state["tz"]
968967
self._unit = state["unit"]
969968

970-
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> DtypeObj | None:
969+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
971970
if all(isinstance(t, DatetimeTZDtype) and t.tz == self.tz for t in dtypes):
972971
np_dtype = np.max([cast(DatetimeTZDtype, t).base for t in [self, *dtypes]])
973972
unit = np.datetime_data(np_dtype)[0]
@@ -1480,7 +1479,7 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> IntervalArray:
14801479
)
14811480
return IntervalArray._concat_same_type(results)
14821481

1483-
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> DtypeObj | None:
1482+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
14841483
if not all(isinstance(x, IntervalDtype) for x in dtypes):
14851484
return None
14861485

@@ -1678,7 +1677,7 @@ def from_numpy_dtype(cls, dtype: np.dtype) -> BaseMaskedDtype:
16781677
else:
16791678
raise NotImplementedError(dtype)
16801679

1681-
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> DtypeObj | None:
1680+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
16821681
# We unwrap any masked dtypes, find the common dtype we would use
16831682
# for that, then re-mask the result.
16841683
from pandas.core.dtypes.cast import find_common_type
@@ -2105,7 +2104,7 @@ def _subtype_with_str(self):
21052104
return type(self.fill_value)
21062105
return self.subtype
21072106

2108-
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> DtypeObj | None:
2107+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
21092108
# TODO for now only handle SparseDtypes and numpy dtypes => extend
21102109
# with other compatible extension dtypes
21112110
from pandas.core.dtypes.cast import np_find_common_type
@@ -2420,7 +2419,7 @@ def _is_boolean(self) -> bool:
24202419
"""
24212420
return pa.types.is_boolean(self.pyarrow_dtype)
24222421

2423-
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> DtypeObj | None:
2422+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
24242423
# We unwrap any masked dtypes, find the common dtype we would use
24252424
# for that, then re-mask the result.
24262425
# Mirrors BaseMaskedDtype

pandas/core/dtypes/factory.py

Lines changed: 98 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -76,8 +76,13 @@ def string(
7676
if mode not in valid_modes:
7777
raise ValueError(f"mode must be one of {valid_modes}, got {mode}")
7878
if backend == "pyarrow":
79-
import pyarrow as pa
80-
79+
try:
80+
import pyarrow as pa
81+
except ImportError as err:
82+
raise ImportError(
83+
"pyarrow is required for the 'pyarrow' backend. "
84+
"Please install pyarrow to use this feature."
85+
) from err
8186
if mode == "string":
8287
pa_type = pa.large_string() if large else pa.string()
8388
else: # mode == "binary"
@@ -129,8 +134,13 @@ def datetime(
129134
return DatetimeTZDtype(unit=unit, tz=tz)
130135
return np.dtype(f"datetime64[{unit}]")
131136
else: # pyarrow
132-
import pyarrow as pa
133-
137+
try:
138+
import pyarrow as pa
139+
except ImportError as err:
140+
raise ImportError(
141+
"pyarrow is required for the 'pyarrow' backend. "
142+
"Please install pyarrow to use this feature."
143+
) from err
134144
return ArrowDtype(pa.timestamp(unit, tz=tz))
135145

136146

@@ -180,7 +190,13 @@ def integer(
180190
else: # bits == 64
181191
return Int64Dtype()
182192
elif backend == "pyarrow":
183-
import pyarrow as pa
193+
try:
194+
import pyarrow as pa
195+
except ImportError as err:
196+
raise ImportError(
197+
"pyarrow is required for the 'pyarrow' backend. "
198+
"Please install pyarrow to use this feature."
199+
) from err
184200

185201
if bits == 8:
186202
return ArrowDtype(pa.int8())
@@ -234,7 +250,13 @@ def floating(
234250
else: # bits == 64
235251
return Float64Dtype()
236252
elif backend == "pyarrow":
237-
import pyarrow as pa
253+
try:
254+
import pyarrow as pa
255+
except ImportError as err:
256+
raise ImportError(
257+
"pyarrow is required for the 'pyarrow' backend. "
258+
"Please install pyarrow to use this feature."
259+
) from err
238260

239261
if bits == 32:
240262
return ArrowDtype(pa.float32())
@@ -275,8 +297,13 @@ def decimal(
275297
decimal256[40, 5][pyarrow]
276298
"""
277299
if backend == "pyarrow":
278-
import pyarrow as pa
279-
300+
try:
301+
import pyarrow as pa
302+
except ImportError as err:
303+
raise ImportError(
304+
"pyarrow is required for the 'pyarrow' backend. "
305+
"Please install pyarrow to use this feature."
306+
) from err
280307
if precision <= 38:
281308
return ArrowDtype(pa.decimal128(precision, scale))
282309
return ArrowDtype(pa.decimal256(precision, scale))
@@ -309,8 +336,13 @@ def boolean(
309336
if backend == "numpy":
310337
return BooleanDtype()
311338
else: # pyarrow
312-
import pyarrow as pa
313-
339+
try:
340+
import pyarrow as pa
341+
except ImportError as err:
342+
raise ImportError(
343+
"pyarrow is required for the 'pyarrow' backend. "
344+
"Please install pyarrow to use this feature."
345+
) from err
314346
return ArrowDtype(pa.bool_())
315347

316348

@@ -353,7 +385,13 @@ def list(
353385
if backend == "numpy":
354386
return np.dtype("object")
355387
else: # pyarrow
356-
import pyarrow as pa
388+
try:
389+
import pyarrow as pa
390+
except ImportError as err:
391+
raise ImportError(
392+
"pyarrow is required for the 'pyarrow' backend. "
393+
"Please install pyarrow to use this feature."
394+
) from err
357395

358396
if value_type is None:
359397
value_type = pa.int64()
@@ -407,7 +445,13 @@ def categorical(
407445
if backend == "numpy":
408446
return CategoricalDtype(categories=categories, ordered=ordered)
409447
else: # pyarrow
410-
import pyarrow as pa
448+
try:
449+
import pyarrow as pa
450+
except ImportError as err:
451+
raise ImportError(
452+
"pyarrow is required for the 'pyarrow' backend. "
453+
"Please install pyarrow to use this feature."
454+
) from err
411455

412456
index_type = pa.int32() if index_type is None else index_type
413457
value_type = pa.string() if value_type is None else value_type
@@ -450,7 +494,13 @@ def interval(
450494
if backend == "numpy":
451495
return IntervalDtype(subtype=subtype, closed=closed)
452496
else: # pyarrow
453-
import pyarrow as pa
497+
try:
498+
import pyarrow as pa
499+
except ImportError as err:
500+
raise ImportError(
501+
"pyarrow is required for the 'pyarrow' backend. "
502+
"Please install pyarrow to use this feature."
503+
) from err
454504

455505
if subtype is not None:
456506
return ArrowDtype(
@@ -506,8 +556,13 @@ def period(
506556
if backend == "numpy":
507557
return PeriodDtype(freq=freq)
508558
else: # pyarrow
509-
import pyarrow as pa
510-
559+
try:
560+
import pyarrow as pa
561+
except ImportError as err:
562+
raise ImportError(
563+
"pyarrow is required for the 'pyarrow' backend. "
564+
"Please install pyarrow to use this feature."
565+
) from err
511566
return ArrowDtype(pa.month_day_nano_interval())
512567

513568

@@ -607,7 +662,13 @@ def date(
607662

608663
if backend != "pyarrow":
609664
raise ValueError("Date types are only supported with PyArrow backend.")
610-
import pyarrow as pa
665+
try:
666+
import pyarrow as pa
667+
except ImportError as err:
668+
raise ImportError(
669+
"pyarrow is required for the 'pyarrow' backend. "
670+
"Please install pyarrow to use this feature."
671+
) from err
611672

612673
return ArrowDtype(pa.date32() if unit == "day" else pa.date64())
613674

@@ -648,8 +709,13 @@ def duration(
648709
if backend == "numpy":
649710
return np.dtype(f"timedelta64[{unit}]")
650711
else: # pyarrow
651-
import pyarrow as pa
652-
712+
try:
713+
import pyarrow as pa
714+
except ImportError as err:
715+
raise ImportError(
716+
"pyarrow is required for the 'pyarrow' backend. "
717+
"Please install pyarrow to use this feature."
718+
) from err
653719
return ArrowDtype(pa.duration(unit))
654720

655721

@@ -698,7 +764,13 @@ def map(
698764
"""
699765
if backend != "pyarrow":
700766
raise ValueError("Map types are only supported with PyArrow backend.")
701-
import pyarrow as pa
767+
try:
768+
import pyarrow as pa
769+
except ImportError as err:
770+
raise ImportError(
771+
"pyarrow is required for the 'pyarrow' backend. "
772+
"Please install pyarrow to use this feature."
773+
) from err
702774

703775
return ArrowDtype(pa.map_(index_type, value_type))
704776

@@ -748,7 +820,13 @@ def struct(
748820
dtype: struct<id: int32, name: string>[pyarrow]
749821
"""
750822
if backend == "pyarrow":
751-
import pyarrow as pa
823+
try:
824+
import pyarrow as pa
825+
except ImportError as err:
826+
raise ImportError(
827+
"pyarrow is required for the 'pyarrow' backend. "
828+
"Please install pyarrow to use this feature."
829+
) from err
752830

753831
pa_fields = [(name, getattr(typ, "pyarrow_dtype", typ)) for name, typ in fields]
754832
return ArrowDtype(pa.struct(pa_fields))

0 commit comments

Comments
 (0)