diff --git a/pandas/_libs/tslibs/timestamps.pyi b/pandas/_libs/tslibs/timestamps.pyi index ecddd83322bbf..0f7800480837c 100644 --- a/pandas/_libs/tslibs/timestamps.pyi +++ b/pandas/_libs/tslibs/timestamps.pyi @@ -131,9 +131,17 @@ class Timestamp(datetime): def utcoffset(self) -> timedelta | None: ... def tzname(self) -> str | None: ... def dst(self) -> timedelta | None: ... + # error: Argument 1 of "__le__" is incompatible with supertype "date"; + # supertype defines the argument type as "date" def __le__(self, other: datetime) -> bool: ... # type: ignore + # error: Argument 1 of "__lt__" is incompatible with supertype "date"; + # supertype defines the argument type as "date" def __lt__(self, other: datetime) -> bool: ... # type: ignore + # error: Argument 1 of "__ge__" is incompatible with supertype "date"; + # supertype defines the argument type as "date" def __ge__(self, other: datetime) -> bool: ... # type: ignore + # error: Argument 1 of "__gt__" is incompatible with supertype "date"; + # supertype defines the argument type as "date" def __gt__(self, other: datetime) -> bool: ... # type: ignore # error: Signature of "__add__" incompatible with supertype "date"/"datetime" @overload # type: ignore[override] diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index fa07b5fea5ea3..737de8a636b7d 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -480,7 +480,7 @@ def find( dtype_type = type(dtype) else: dtype_type = dtype - if issubclass(dtype_type, ExtensionDtype): + if dtype_type in self.dtypes: # cast needed here as mypy doesn't know we have figured # out it is an ExtensionDtype or type_t[ExtensionDtype] return cast("ExtensionDtype | type_t[ExtensionDtype]", dtype) diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 6776064342db0..802109c41567f 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1426,7 +1426,7 @@ def is_1d_only_ea_dtype(dtype: DtypeObj | None) -> bool: ) -def is_extension_array_dtype(arr_or_dtype) -> bool: +def is_extension_array_dtype(arr_or_dtype: object) -> bool: """ Check if an object is a pandas extension array type. @@ -1476,6 +1476,10 @@ def is_extension_array_dtype(arr_or_dtype) -> bool: return True elif isinstance(dtype, np.dtype): return False + elif isinstance(dtype, type) and issubclass(dtype, ExtensionDtype): + return True + elif not isinstance(dtype, str): + return False else: return registry.find(dtype) is not None diff --git a/pandas/tests/arrays/test_array.py b/pandas/tests/arrays/test_array.py index 329d28c263ff2..cb1e42258336b 100644 --- a/pandas/tests/arrays/test_array.py +++ b/pandas/tests/arrays/test_array.py @@ -386,6 +386,7 @@ def registry_without_decimal(): def test_array_not_registered(registry_without_decimal): # check we aren't on it assert registry.find("decimal") is None + assert registry.find(DecimalDtype) is None data = [decimal.Decimal("1"), decimal.Decimal("2")] result = pd.array(data, dtype=DecimalDtype)