1010from dataclasses import dataclass , field
1111from datetime import timedelta
1212from html import escape
13- from typing import TYPE_CHECKING , Any , overload
13+ from typing import TYPE_CHECKING , Any , cast , overload
1414
1515import numpy as np
1616import pandas as pd
17+ from numpy .typing import DTypeLike
1718from packaging .version import Version
1819
1920from xarray .core import duck_array_ops
2021from xarray .core .coordinate_transform import CoordinateTransform
22+ from xarray .core .extension_array import PandasExtensionArray
2123from xarray .core .nputils import NumpyVIndexAdapter
2224from xarray .core .options import OPTIONS
2325from xarray .core .types import T_Xarray
2830 is_duck_array ,
2931 is_duck_dask_array ,
3032 is_scalar ,
33+ is_valid_numpy_dtype ,
3134 to_0d_array ,
3235)
3336from xarray .namedarray .parallelcompat import get_chunked_array_type
3437from xarray .namedarray .pycompat import array_type , integer_types , is_chunked_array
3538
3639if TYPE_CHECKING :
37- from numpy .typing import DTypeLike
38-
3940 from xarray .core .indexes import Index
4041 from xarray .core .types import Self
4142 from xarray .core .variable import Variable
@@ -1744,27 +1745,43 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
17441745 __slots__ = ("_dtype" , "array" )
17451746
17461747 array : pd .Index
1747- _dtype : np .dtype
1748+ _dtype : np .dtype | pd . api . extensions . ExtensionDtype
17481749
def __init__(
    self,
    array: pd.Index,
    dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
):
    """Wrap ``array`` and record the dtype this adapter reports.

    Parameters
    ----------
    array : pd.Index
        Index to wrap. It is passed through ``safe_cast_to_index`` before
        being stored on ``self.array``.
    dtype : DTypeLike or ExtensionDtype, optional
        Dtype to report. If None, it is taken from ``array``: a pandas
        extension dtype is kept unchanged, anything else goes through
        ``get_valid_numpy_dtype``. If given, an extension dtype is stored
        as-is and any other value is converted with ``np.dtype``.
    """
    from xarray.core.indexes import safe_cast_to_index

    self.array = safe_cast_to_index(array)

    if dtype is None:
        if pd.api.types.is_extension_array_dtype(array.dtype):
            # ``cast`` only informs the type checker; its result has to be
            # the value that is assigned, otherwise the call is a no-op.
            self._dtype = cast(pd.api.extensions.ExtensionDtype, array.dtype)
        else:
            self._dtype = get_valid_numpy_dtype(array)
    elif pd.api.types.is_extension_array_dtype(dtype):
        self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype)
    else:
        self._dtype = np.dtype(cast(DTypeLike, dtype))
17581769
@property
def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype:  # type: ignore[override]
    """Dtype stored on this adapter: a NumPy dtype or a pandas extension dtype."""
    stored = self._dtype
    return stored
17621773
17631774 def __array__ (
1764- self , dtype : np .typing .DTypeLike = None , / , * , copy : bool | None = None
1775+ self ,
1776+ dtype : np .typing .DTypeLike | None = None ,
1777+ / ,
1778+ * ,
1779+ copy : bool | None = None ,
17651780 ) -> np .ndarray :
1766- if dtype is None :
1767- dtype = self .dtype
1781+ if dtype is None and is_valid_numpy_dtype (self .dtype ):
1782+ dtype = cast (np .dtype , self .dtype )
1783+ else :
1784+ dtype = get_valid_numpy_dtype (self .array )
17681785 array = self .array
17691786 if isinstance (array , pd .PeriodIndex ):
17701787 with suppress (AttributeError ):
@@ -1776,14 +1793,18 @@ def __array__(
17761793 else :
17771794 return np .asarray (array .values , dtype = dtype )
17781795
def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
    """Return the wrapped index as a duck array.

    When the wrapped index has a pandas extension dtype, its underlying
    array is returned inside a ``PandasExtensionArray`` wrapper; otherwise
    the result of ``np.asarray(self)`` is returned.
    """
    has_extension_dtype = pd.api.types.is_extension_array_dtype(self.array)
    if not has_extension_dtype:
        return np.asarray(self)
    # Return a PandasExtensionArray wrapper type that satisfies the duck
    # array protocols instead of converting to a NumPy array.
    return PandasExtensionArray(self.array.array)
17811802
@property
def shape(self) -> _Shape:
    """One-element tuple holding the length of the wrapped index."""
    length = len(self.array)
    return (length,)
17851806
1786- def _convert_scalar (self , item ):
1807+ def _convert_scalar (self , item ) -> np . ndarray :
17871808 if item is pd .NaT :
17881809 # work around the impossibility of casting NaT with asarray
17891810 # note: it probably would be better in general to return
@@ -1799,7 +1820,10 @@ def _convert_scalar(self, item):
17991820 # numpy fails to convert pd.Timestamp to np.datetime64[ns]
18001821 item = np .asarray (item .to_datetime64 ())
18011822 elif self .dtype != object :
1802- item = np .asarray (item , dtype = self .dtype )
1823+ dtype = self .dtype
1824+ if pd .api .types .is_extension_array_dtype (dtype ):
1825+ dtype = get_valid_numpy_dtype (self .array )
1826+ item = np .asarray (item , dtype = cast (np .dtype , dtype ))
18031827
18041828 # as for numpy.ndarray indexing, we always want the result to be
18051829 # a NumPy array.
@@ -1902,6 +1926,12 @@ def copy(self, deep: bool = True) -> Self:
19021926 array = self .array .copy (deep = True ) if deep else self .array
19031927 return type (self )(array , self ._dtype )
19041928
@property
def nbytes(self) -> int:
    """Number of bytes reported for the wrapped index.

    For a pandas extension dtype the index's own ``nbytes`` is returned;
    otherwise it is the dtype's ``itemsize`` times the index length.
    """
    own_dtype = self.dtype
    if not pd.api.types.is_extension_array_dtype(own_dtype):
        return cast(np.dtype, own_dtype).itemsize * len(self.array)
    return self.array.nbytes
1934+
19051935
19061936class PandasMultiIndexingAdapter (PandasIndexingAdapter ):
19071937 """Handles explicit indexing for a pandas.MultiIndex.
@@ -1914,23 +1944,27 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter):
19141944 __slots__ = ("_dtype" , "adapter" , "array" , "level" )
19151945
19161946 array : pd .MultiIndex
1917- _dtype : np .dtype
1947+ _dtype : np .dtype | pd . api . extensions . ExtensionDtype
19181948 level : str | None
19191949
def __init__(
    self,
    array: pd.MultiIndex,
    dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
    level: str | None = None,
):
    """Wrap a ``pd.MultiIndex``, optionally remembering one level name.

    Parameters
    ----------
    array : pd.MultiIndex
        Multi-index to wrap; forwarded to the parent initializer.
    dtype : DTypeLike or ExtensionDtype, optional
        Forwarded unchanged to the parent initializer.
    level : str, optional
        Level name stored on ``self.level``.
    """
    # The parent initializer sets ``array`` and ``_dtype``.
    super().__init__(array, dtype)
    self.level = level
19281958
19291959 def __array__ (
1930- self , dtype : np .typing .DTypeLike = None , / , * , copy : bool | None = None
1960+ self ,
1961+ dtype : DTypeLike | None = None ,
1962+ / ,
1963+ * ,
1964+ copy : bool | None = None ,
19311965 ) -> np .ndarray :
19321966 if dtype is None :
1933- dtype = self .dtype
1967+ dtype = cast ( np . dtype , self .dtype )
19341968 if self .level is not None :
19351969 return np .asarray (
19361970 self .array .get_level_values (self .level ).values , dtype = dtype
0 commit comments