|
10 | 10 | )
|
11 | 11 |
|
12 | 12 | import numpy as np
|
13 |
| -import pyarrow as pa |
14 | 13 |
|
15 | 14 | from pandas._libs import missing as libmissing
|
16 | 15 |
|
@@ -77,6 +76,8 @@ def string(
|
77 | 76 | if mode not in valid_modes:
|
78 | 77 | raise ValueError(f"mode must be one of {valid_modes}, got {mode}")
|
79 | 78 | if backend == "pyarrow":
|
| 79 | + import pyarrow as pa |
| 80 | + |
80 | 81 | if mode == "string":
|
81 | 82 | pa_type = pa.large_string() if large else pa.string()
|
82 | 83 | else: # mode == "binary"
|
@@ -128,6 +129,8 @@ def datetime(
|
128 | 129 | return DatetimeTZDtype(unit=unit, tz=tz)
|
129 | 130 | return np.dtype(f"datetime64[{unit}]")
|
130 | 131 | else: # pyarrow
|
| 132 | + import pyarrow as pa |
| 133 | + |
131 | 134 | return ArrowDtype(pa.timestamp(unit, tz=tz))
|
132 | 135 |
|
133 | 136 |
|
@@ -167,24 +170,25 @@ def integer(
|
167 | 170 |
|
168 | 171 | if backend == "numpy":
|
169 | 172 | return np.dtype(f"int{bits}")
|
170 |
| - |
171 |
| - if backend == "pandas": |
| 173 | + elif backend == "pandas": |
172 | 174 | if bits == 8:
|
173 | 175 | return Int8Dtype()
|
174 | 176 | elif bits == 16:
|
175 | 177 | return Int16Dtype()
|
176 | 178 | elif bits == 32:
|
177 | 179 | return Int32Dtype()
|
178 |
| - elif bits == 64: |
| 180 | + else: # bits == 64 |
179 | 181 | return Int64Dtype()
|
180 | 182 | elif backend == "pyarrow":
|
| 183 | + import pyarrow as pa |
| 184 | + |
181 | 185 | if bits == 8:
|
182 | 186 | return ArrowDtype(pa.int8())
|
183 | 187 | elif bits == 16:
|
184 | 188 | return ArrowDtype(pa.int16())
|
185 | 189 | elif bits == 32:
|
186 | 190 | return ArrowDtype(pa.int32())
|
187 |
| - elif bits == 64: |
| 191 | + else: # bits == 64 |
188 | 192 | return ArrowDtype(pa.int64())
|
189 | 193 | else:
|
190 | 194 | raise ValueError(f"Unsupported backend: {backend!r}")
|
@@ -224,16 +228,17 @@ def floating(
|
224 | 228 |
|
225 | 229 | if backend == "numpy":
|
226 | 230 | return np.dtype(f"float{bits}")
|
227 |
| - |
228 |
| - if backend == "pandas": |
| 231 | + elif backend == "pandas": |
229 | 232 | if bits == 32:
|
230 | 233 | return Float32Dtype()
|
231 |
| - elif bits == 64: |
| 234 | + else: # bits == 64 |
232 | 235 | return Float64Dtype()
|
233 | 236 | elif backend == "pyarrow":
|
| 237 | + import pyarrow as pa |
| 238 | + |
234 | 239 | if bits == 32:
|
235 | 240 | return ArrowDtype(pa.float32())
|
236 |
| - elif bits == 64: |
| 241 | + else: # bits == 64 |
237 | 242 | return ArrowDtype(pa.float64())
|
238 | 243 | else:
|
239 | 244 | raise ValueError(f"Unsupported backend: {backend!r}")
|
@@ -270,6 +275,8 @@ def decimal(
|
270 | 275 | decimal256[40, 5][pyarrow]
|
271 | 276 | """
|
272 | 277 | if backend == "pyarrow":
|
| 278 | + import pyarrow as pa |
| 279 | + |
273 | 280 | if precision <= 38:
|
274 | 281 | return ArrowDtype(pa.decimal128(precision, scale))
|
275 | 282 | return ArrowDtype(pa.decimal256(precision, scale))
|
@@ -302,6 +309,8 @@ def boolean(
|
302 | 309 | if backend == "numpy":
|
303 | 310 | return BooleanDtype()
|
304 | 311 | else: # pyarrow
|
| 312 | + import pyarrow as pa |
| 313 | + |
305 | 314 | return ArrowDtype(pa.bool_())
|
306 | 315 |
|
307 | 316 |
|
@@ -344,6 +353,8 @@ def list(
|
344 | 353 | if backend == "numpy":
|
345 | 354 | return np.dtype("object")
|
346 | 355 | else: # pyarrow
|
| 356 | + import pyarrow as pa |
| 357 | + |
347 | 358 | if value_type is None:
|
348 | 359 | value_type = pa.int64()
|
349 | 360 | pa_type = pa.large_list(value_type) if large else pa.list_(value_type)
|
@@ -396,6 +407,8 @@ def categorical(
|
396 | 407 | if backend == "numpy":
|
397 | 408 | return CategoricalDtype(categories=categories, ordered=ordered)
|
398 | 409 | else: # pyarrow
|
| 410 | + import pyarrow as pa |
| 411 | + |
399 | 412 | index_type = pa.int32() if index_type is None else index_type
|
400 | 413 | value_type = pa.string() if value_type is None else value_type
|
401 | 414 | return ArrowDtype(pa.dictionary(index_type, value_type))
|
@@ -437,6 +450,8 @@ def interval(
|
437 | 450 | if backend == "numpy":
|
438 | 451 | return IntervalDtype(subtype=subtype, closed=closed)
|
439 | 452 | else: # pyarrow
|
| 453 | + import pyarrow as pa |
| 454 | + |
440 | 455 | if subtype is not None:
|
441 | 456 | return ArrowDtype(
|
442 | 457 | pa.struct(
|
@@ -491,6 +506,8 @@ def period(
|
491 | 506 | if backend == "numpy":
|
492 | 507 | return PeriodDtype(freq=freq)
|
493 | 508 | else: # pyarrow
|
| 509 | + import pyarrow as pa |
| 510 | + |
494 | 511 | return ArrowDtype(pa.month_day_nano_interval())
|
495 | 512 |
|
496 | 513 |
|
@@ -590,6 +607,8 @@ def date(
|
590 | 607 |
|
591 | 608 | if backend != "pyarrow":
|
592 | 609 | raise ValueError("Date types are only supported with PyArrow backend.")
|
| 610 | + import pyarrow as pa |
| 611 | + |
593 | 612 | return ArrowDtype(pa.date32() if unit == "day" else pa.date64())
|
594 | 613 |
|
595 | 614 |
|
@@ -629,6 +648,8 @@ def duration(
|
629 | 648 | if backend == "numpy":
|
630 | 649 | return np.dtype(f"timedelta64[{unit}]")
|
631 | 650 | else: # pyarrow
|
| 651 | + import pyarrow as pa |
| 652 | + |
632 | 653 | return ArrowDtype(pa.duration(unit))
|
633 | 654 |
|
634 | 655 |
|
@@ -677,6 +698,8 @@ def map(
|
677 | 698 | """
|
678 | 699 | if backend != "pyarrow":
|
679 | 700 | raise ValueError("Map types are only supported with PyArrow backend.")
|
| 701 | + import pyarrow as pa |
| 702 | + |
680 | 703 | return ArrowDtype(pa.map_(index_type, value_type))
|
681 | 704 |
|
682 | 705 |
|
@@ -724,14 +747,10 @@ def struct(
|
724 | 747 | 1 (2, Bob)
|
725 | 748 | dtype: struct<id: int32, name: string>[pyarrow]
|
726 | 749 | """
|
727 |
| - if backend != "pyarrow": |
728 |
| - raise ValueError("Struct types are only supported with PyArrow backend.") |
729 |
| - # Validate that fields is a list of (str, type) tuples |
730 |
| - for field in fields: |
731 |
| - if ( |
732 |
| - not isinstance(field, tuple) |
733 |
| - or len(field) != 2 |
734 |
| - or not isinstance(field[0], str) |
735 |
| - ): |
736 |
| - raise ValueError("Each field must be a tuple of (str, type), got {field}") |
737 |
| - return ArrowDtype(pa.struct(fields)) |
| 750 | + if backend == "pyarrow": |
| 751 | + import pyarrow as pa |
| 752 | + |
| 753 | + pa_fields = [(name, getattr(typ, "pyarrow_dtype", typ)) for name, typ in fields] |
| 754 | + return ArrowDtype(pa.struct(pa_fields)) |
| 755 | + else: |
| 756 | + raise ValueError(f"Unsupported backend: {backend!r}") |
0 commit comments