diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx
index e1a2a0142c52e..dd4ef98ada07b 100644
--- a/pandas/_libs/lib.pyx
+++ b/pandas/_libs/lib.pyx
@@ -107,6 +107,8 @@ from pandas._libs.tslibs.period cimport is_period_object
 from pandas._libs.tslibs.timedeltas cimport convert_to_timedelta64
 from pandas._libs.tslibs.timezones cimport tz_compare
 
+from pandas.core.dtypes.base import _registry
+
 # constants that will be compared to potentially arbitrarily large
 # python int
 cdef:
@@ -1693,6 +1695,11 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
         if is_interval_array(values):
             return "interval"
 
+    # Registered extension dtypes get the chance to claim the scalar type.
+    reg_dtype = _registry.match_scalar(val)
+    if reg_dtype is not None:
+        return str(reg_dtype)
+
     cnp.PyArray_ITER_RESET(it)
     for i in range(n):
         val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py
index d8a42d83b6c54..ba396295efa5e 100644
--- a/pandas/core/dtypes/base.py
+++ b/pandas/core/dtypes/base.py
@@ -444,6 +444,43 @@ def _can_fast_transpose(self) -> bool:
         """
         return False
 
+    @classmethod
+    def is_unambiguous_scalar(cls, scalar: object) -> bool:
+        """
+        Whether ``scalar`` can only be described by this dtype.
+
+        The base implementation returns False, so a dtype takes part in
+        scalar-based inference only if it overrides this classmethod.
+
+        Parameters
+        ----------
+        scalar : object
+            Value to inspect.
+
+        Returns
+        -------
+        bool
+        """
+        return False
+
+    @classmethod
+    def construct_from_scalar(cls, scalar: object):
+        """
+        Construct the dtype describing ``scalar``.
+
+        The base implementation ignores ``scalar`` and returns ``cls()``.
+        """
+        return cls()
+
+    @property
+    def is_external_dtype(self) -> bool:
+        """
+        Whether this dtype is defined outside of the pandas package.
+        """
+        # Compare the top-level package of the defining module; checking for
+        # a "pandas.c" prefix would flag pandas' own core dtypes instead.
+        return type(self).__module__.split(".")[0] != "pandas"
+
 
 class StorageExtensionDtype(ExtensionDtype):
     """ExtensionDtype that may be backed by more than one implementation."""
@@ -582,5 +619,27 @@ def find(
 
         return None
 
+    def match_scalar(
+        self, scalar: Any
+    ) -> type_t[ExtensionDtype] | ExtensionDtype | None:
+        """
+        Return the dtype of the first registered type that claims ``scalar``.
+
+        Parameters
+        ----------
+        scalar : Any
+            Value to match against every registered dtype.
+
+        Returns
+        -------
+        ExtensionDtype or None
+            Result of ``construct_from_scalar`` for the first dtype whose
+            ``is_unambiguous_scalar`` accepts the value, else None.
+        """
+        for dtype in self.dtypes:
+            if dtype.is_unambiguous_scalar(scalar):
+                return dtype.construct_from_scalar(scalar)
+        return None
+
 
 _registry = Registry()
diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index 6ba07b1761557..bcbd933a33c12 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -44,6 +44,7 @@
     LossySetitemError,
 )
 
+from pandas.core.dtypes.base import _registry
 from pandas.core.dtypes.common import (
     ensure_int8,
     ensure_int16,
@@ -857,6 +858,11 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
         subtype = infer_dtype_from_scalar(val.left)[0]
         dtype = IntervalDtype(subtype=subtype, closed=val.closed)
 
+    # A registered extension dtype that claims val overrides the result above.
+    reg_dtype = _registry.match_scalar(val)
+    if reg_dtype is not None:
+        dtype = reg_dtype
+
     return dtype, val
 
 
@@ -913,6 +919,25 @@ def infer_dtype_from_array(arr) -> tuple[DtypeObj, ArrayLike]:
     inferred = lib.infer_dtype(arr, skipna=False)
     if inferred in ["string", "bytes", "mixed", "mixed-integer"]:
         return (np.dtype(np.object_), arr)
+    elif inferred not in [
+        "empty",
+        "integer",
+        "floating",
+        "integer-na",
+        "mixed-integer-float",
+        "datetime",
+        "period",
+        "timedelta",
+        "time",
+        "date",
+    ]:
+        # Any other label is treated as a dtype string; this is how a dtype
+        # name returned by lib.infer_dtype for a registered scalar gets here.
+        # NOTE(review): presumably labels that are not dtype names (e.g.
+        # "decimal") make pandas_dtype_func raise - confirm and guard.
+        arr_dtype = pandas_dtype_func(inferred)
+        if isinstance(arr_dtype, ExtensionDtype):
+            return arr_dtype, arr
 
     arr = np.asarray(arr)
     return arr.dtype, arr
diff --git a/pandas/core/series.py b/pandas/core/series.py
index 4f79e30f48f3c..65cf8fbb3c5a4 100644
--- a/pandas/core/series.py
+++ b/pandas/core/series.py
@@ -501,6 +501,15 @@ def __init__(
             elif copy:
                 data = data.copy()
         else:
+            if dtype is None:
+                # Let a third-party extension dtype claim the data before it
+                # is handed to sanitize_array.
+                inferred_dtype = infer_dtype_from(data)[0]
+                if (
+                    isinstance(inferred_dtype, ExtensionDtype)
+                    and inferred_dtype.is_external_dtype
+                ):
+                    dtype = inferred_dtype
             data = sanitize_array(data, index, dtype, copy)
 
             data = SingleBlockManager.from_array(data, index, refs=refs)
diff --git a/pandas/tests/dtypes/cast/test_infer_dtype.py b/pandas/tests/dtypes/cast/test_infer_dtype.py
index 679031a625c2d..0ef5a4344ffd5 100644
--- a/pandas/tests/dtypes/cast/test_infer_dtype.py
+++ b/pandas/tests/dtypes/cast/test_infer_dtype.py
@@ -23,7 +23,36 @@
     Timestamp,
     date_range,
 )
+from pandas.core.dtypes.dtypes import ExtensionDtype, register_extension_dtype
+
+
+class MockScalar:
+    """Scalar type claimed only by MockDtype."""
+
+
+@register_extension_dtype
+class MockDtype(ExtensionDtype):
+    """Minimal extension dtype that opts in to scalar-based inference."""
+
+    @property
+    def name(self):
+        return "MockDtype"
+
+    @classmethod
+    def is_unambiguous_scalar(cls, scalar):
+        return isinstance(scalar, MockScalar)
+
+    @classmethod
+    def construct_from_string(cls, string: str):
+        if not isinstance(string, str):
+            raise TypeError(
+                f"'construct_from_string' expects a string, got {type(string)}"
+            )
+
+        if string == cls.__name__:
+            return cls()
+        raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
 
 
 def test_infer_dtype_from_int_scalar(any_int_numpy_dtype):
     # Test that infer_dtype_from_scalar is
@@ -157,6 +186,7 @@ def test_infer_dtype_from_scalar_errors():
         (np.datetime64("2016-01-01"), np.dtype("M8[s]")),
         (Timestamp("20160101"), np.dtype("M8[s]")),
         (Timestamp("20160101", tz="UTC"), "datetime64[s, UTC]"),
+        (MockScalar(), MockDtype()),
     ],
 )
 def test_infer_dtype_from_scalar(value, expected, using_infer_string):
@@ -189,6 +219,7 @@ def test_infer_dtype_from_scalar(value, expected, using_infer_string):
             Series(date_range("20160101", periods=3, tz="US/Eastern")),
             "datetime64[ns, US/Eastern]",
         ),
+        ([MockScalar()], MockDtype()),
     ],
 )
 def test_infer_dtype_from_array(arr, expected, using_infer_string):
diff --git a/pandas/tests/dtypes/test_inference.py b/pandas/tests/dtypes/test_inference.py
index da444b55490f0..691db81c67010 100644
--- a/pandas/tests/dtypes/test_inference.py
+++ b/pandas/tests/dtypes/test_inference.py
@@ -75,6 +75,7 @@
     FloatingArray,
     IntegerArray,
 )
+from pandas.core.dtypes.dtypes import ExtensionDtype, register_extension_dtype
 
 
 @pytest.fixture(params=[True, False], ids=str)
@@ -2025,3 +2026,23 @@ def test_find_result_type_int_int(right, result):
 def test_find_result_type_floats(right, result):
     left_dtype = np.dtype("float16")
     assert find_result_type(left_dtype, right) == result
+
+
+def test_infer_dtype_extensiondtype():
+    # lib.infer_dtype reports the name of a registered dtype that claims
+    # the scalars in the array.
+    class MockScalar:
+        pass
+
+    @register_extension_dtype
+    class MockDtype(ExtensionDtype):
+        @property
+        def name(self):
+            return "MockDtype"
+
+        @classmethod
+        def is_unambiguous_scalar(cls, scalar):
+            return isinstance(scalar, MockScalar)
+
+    arr = [MockScalar()]
+    assert lib.infer_dtype(arr, skipna=True) == "MockDtype"
diff --git a/pandas/tests/series/test_constructors.py b/pandas/tests/series/test_constructors.py
index 1771a4dfdb71f..aeec100d6931e 100644
--- a/pandas/tests/series/test_constructors.py
+++ b/pandas/tests/series/test_constructors.py
@@ -44,10 +44,72 @@
 from pandas.core.arrays import (
+    ExtensionArray,
     IntegerArray,
     IntervalArray,
     period_array,
 )
+from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
+from pandas.core.dtypes.dtypes import ExtensionDtype, register_extension_dtype
 from pandas.core.internals.blocks import NumpyBlock
 
+
+class MockScalar:
+    """Scalar type claimed only by MockDtype."""
+
+
+@register_extension_dtype
+class MockDtype(ExtensionDtype):
+    """Extension dtype that is inferred from MockScalar values."""
+
+    type = MockScalar
+
+    @property
+    def name(self):
+        return "MockDtype"
+
+    @classmethod
+    def is_unambiguous_scalar(cls, scalar):
+        return isinstance(scalar, MockScalar)
+
+    @classmethod
+    def construct_from_string(cls, string: str):
+        if not isinstance(string, str):
+            raise TypeError(
+                f"'construct_from_string' expects a string, got {type(string)}"
+            )
+
+        if string == cls.__name__:
+            return cls()
+        raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
+
+    @classmethod
+    def construct_array_type(cls):
+        """
+        Return the array type associated with this dtype.
+
+        Returns
+        -------
+        type
+        """
+        return MockArray
+
+    @property
+    def is_external_dtype(self):
+        # This module lives under the pandas package, so opt in explicitly.
+        return True
+
+
+class MockArray(NDArrayBackedExtensionArray):
+    """Array type backing MockDtype; stores one placeholder per scalar."""
+
+    dtype = MockDtype()
+
+    @classmethod
+    def _from_sequence(cls, scalars, *, dtype=None, copy=False):
+        # np.ndarray(...) would interpret the list as a shape; np.array
+        # builds an array holding one element per input scalar.
+        values = np.array([0 for _ in scalars], dtype=object)
+        return cls(values, "O")
+
 
 class TestSeriesConstructors:
     def test_from_ints_with_non_nano_dt64_dtype(self, index_or_series):
@@ -152,6 +214,13 @@ def test_scalar_extension_dtype(self, ea_scalar_and_dtype):
         assert ser.dtype == ea_dtype
         tm.assert_series_equal(ser, expected)
 
+    def test_unambiguous_scalar(self):
+        # A scalar claimed by a registered dtype determines the Series dtype.
+        ea_scalar, ea_dtype = MockScalar(), MockDtype()
+
+        ser = Series(ea_scalar, index=range(3))
+        assert ser.dtype == ea_dtype
+
     def test_constructor(self, datetime_series, using_infer_string):
         empty_series = Series()
         assert datetime_series.index._is_all_dates