diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 76814403af385..6a49f9f670aab 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -6,8 +6,9 @@ import numpy as np -from pandas._libs import NaT, iNaT, lib +from pandas._libs import NaT, iNaT, join as libjoin, lib from pandas._libs.algos import unique_deltas +from pandas._libs.tslibs import timezones from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import Appender, cache_readonly @@ -72,6 +73,32 @@ def method(self, other): return method +def _join_i8_wrapper(joinf, with_indexers: bool = True): + """ + Create the join wrapper methods. + """ + + @staticmethod # type: ignore + def wrapper(left, right): + if isinstance(left, (np.ndarray, ABCIndex, ABCSeries, DatetimeLikeArrayMixin)): + left = left.view("i8") + if isinstance(right, (np.ndarray, ABCIndex, ABCSeries, DatetimeLikeArrayMixin)): + right = right.view("i8") + + results = joinf(left, right) + if with_indexers: + # dtype should be timedelta64[ns] for TimedeltaIndex + # and datetime64[ns] for DatetimeIndex + dtype = left.dtype.base + + join_index, left_indexer, right_indexer = results + join_index = join_index.view(dtype) + return join_index, left_indexer, right_indexer + return results + + return wrapper + + class DatetimeIndexOpsMixin(ExtensionOpsMixin): """ Common ops mixin to support a unified interface datetimelike Index. @@ -208,32 +235,6 @@ def equals(self, other): return np.array_equal(self.asi8, other.asi8) - @staticmethod - def _join_i8_wrapper(joinf, dtype, with_indexers=True): - """ - Create the join wrapper methods. - """ - from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin - - @staticmethod - def wrapper(left, right): - if isinstance( - left, (np.ndarray, ABCIndex, ABCSeries, DatetimeLikeArrayMixin) - ): - left = left.view("i8") - if isinstance( - right, (np.ndarray, ABCIndex, ABCSeries, DatetimeLikeArrayMixin) - ): - right = right.view("i8") - results = joinf(left, right) - if with_indexers: - join_index, left_indexer, right_indexer = results - join_index = join_index.view(dtype) - return join_index, left_indexer, right_indexer - return results - - return wrapper - def _ensure_localized( self, arg, ambiguous="raise", nonexistent="raise", from_utc=False ): @@ -853,6 +854,75 @@ def _can_fast_union(self, other) -> bool: # this will raise return False + # -------------------------------------------------------------------- + # Join Methods + _join_precedence = 10 + + _inner_indexer = _join_i8_wrapper(libjoin.inner_join_indexer) + _outer_indexer = _join_i8_wrapper(libjoin.outer_join_indexer) + _left_indexer = _join_i8_wrapper(libjoin.left_join_indexer) + _left_indexer_unique = _join_i8_wrapper( + libjoin.left_join_indexer_unique, with_indexers=False + ) + + def join( + self, other, how: str = "left", level=None, return_indexers=False, sort=False + ): + """ + See Index.join + """ + if self._is_convertible_to_index_for_join(other): + try: + other = type(self)(other) + except (TypeError, ValueError): + pass + + this, other = self._maybe_utc_convert(other) + return Index.join( + this, + other, + how=how, + level=level, + return_indexers=return_indexers, + sort=sort, + ) + + def _maybe_utc_convert(self, other): + this = self + if not hasattr(self, "tz"): + return this, other + + if isinstance(other, type(self)): + if self.tz is not None: + if other.tz is None: + raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex") + elif other.tz is not None: + raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex") + + if not timezones.tz_compare(self.tz, other.tz): + this = self.tz_convert("UTC") + other = other.tz_convert("UTC") + return this, other + + @classmethod + def _is_convertible_to_index_for_join(cls, other: Index) -> bool: + """ + return a boolean whether I can attempt conversion to a + DatetimeIndex/TimedeltaIndex + """ + if isinstance(other, cls): + return False + elif len(other) > 0 and other.inferred_type not in ( + "floating", + "mixed-integer", + "integer", + "integer-na", + "mixed-integer-float", + "mixed", + ): + return True + return False + def wrap_arithmetic_op(self, other, result): if result is NotImplemented: diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index 53d2ed22cd631..fafa9e95a5963 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -5,7 +5,6 @@ import numpy as np from pandas._libs import NaT, Timestamp, index as libindex, lib, tslib as libts -import pandas._libs.join as libjoin from pandas._libs.tslibs import ccalendar, fields, parsing, timezones from pandas.util._decorators import Appender, Substitution, cache_readonly @@ -32,7 +31,6 @@ import pandas.core.common as com from pandas.core.indexes.base import Index, maybe_extract_name from pandas.core.indexes.datetimelike import ( - DatetimeIndexOpsMixin, DatetimelikeDelegateMixin, DatetimeTimedeltaMixin, ) @@ -195,17 +193,6 @@ class DatetimeIndex(DatetimeTimedeltaMixin, DatetimeDelegateMixin): """ _typ = "datetimeindex" - _join_precedence = 10 - - def _join_i8_wrapper(joinf, **kwargs): - return DatetimeIndexOpsMixin._join_i8_wrapper(joinf, dtype="M8[ns]", **kwargs) - - _inner_indexer = _join_i8_wrapper(libjoin.inner_join_indexer) - _outer_indexer = _join_i8_wrapper(libjoin.outer_join_indexer) - _left_indexer = _join_i8_wrapper(libjoin.left_join_indexer) - _left_indexer_unique = _join_i8_wrapper( - libjoin.left_join_indexer_unique, with_indexers=False - ) _engine_type = libindex.DatetimeEngine _supports_partial_string_indexing = True @@ -645,54 +632,6 @@ def snap(self, freq="S"): # we know it conforms; skip check return DatetimeIndex._simple_new(snapped, name=self.name, tz=self.tz, freq=freq) - def join( - self, other, how: str = "left", level=None, return_indexers=False, sort=False - ): - """ - See Index.join - """ - if ( - not isinstance(other, DatetimeIndex) - and len(other) > 0 - and other.inferred_type - not in ( - "floating", - "integer", - "integer-na", - "mixed-integer", - "mixed-integer-float", - "mixed", - ) - ): - try: - other = DatetimeIndex(other) - except (TypeError, ValueError): - pass - - this, other = self._maybe_utc_convert(other) - return Index.join( - this, - other, - how=how, - level=level, - return_indexers=return_indexers, - sort=sort, - ) - - def _maybe_utc_convert(self, other): - this = self - if isinstance(other, DatetimeIndex): - if self.tz is not None: - if other.tz is None: - raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex") - elif other.tz is not None: - raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex") - - if not timezones.tz_compare(self.tz, other.tz): - this = self.tz_convert("UTC") - other = other.tz_convert("UTC") - return this, other - def _wrap_joined_index(self, joined, other): name = get_op_result_name(self, other) if ( diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index 65c3ece6000fc..e6790d092778f 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -3,7 +3,7 @@ import numpy as np -from pandas._libs import NaT, Timedelta, index as libindex, join as libjoin, lib +from pandas._libs import NaT, Timedelta, index as libindex, lib from pandas.util._decorators import Appender, Substitution from pandas.core.dtypes.common import ( @@ -121,17 +121,6 @@ class TimedeltaIndex( """ _typ = "timedeltaindex" - _join_precedence = 10 - - def _join_i8_wrapper(joinf, **kwargs): - return DatetimeIndexOpsMixin._join_i8_wrapper(joinf, dtype="m8[ns]", **kwargs) - - _inner_indexer = _join_i8_wrapper(libjoin.inner_join_indexer) - _outer_indexer = _join_i8_wrapper(libjoin.outer_join_indexer) - _left_indexer = _join_i8_wrapper(libjoin.left_join_indexer) - _left_indexer_unique = _join_i8_wrapper( - libjoin.left_join_indexer_unique, with_indexers=False - ) _engine_type = libindex.TimedeltaEngine @@ -294,25 +283,6 @@ def _union(self, other, sort): result._set_freq("infer") return result - def join(self, other, how="left", level=None, return_indexers=False, sort=False): - """ - See Index.join - """ - if _is_convertible_to_index(other): - try: - other = TimedeltaIndex(other) - except (TypeError, ValueError): - pass - - return Index.join( - self, - other, - how=how, - level=level, - return_indexers=return_indexers, - sort=sort, - ) - def _wrap_joined_index(self, joined, other): name = get_op_result_name(self, other) if ( @@ -569,24 +539,6 @@ def delete(self, loc): TimedeltaIndex._add_datetimelike_methods() -def _is_convertible_to_index(other) -> bool: - """ - return a boolean whether I can attempt conversion to a TimedeltaIndex - """ - if isinstance(other, TimedeltaIndex): - return True - elif len(other) > 0 and other.inferred_type not in ( - "floating", - "mixed-integer", - "integer", - "integer-na", - "mixed-integer-float", - "mixed", - ): - return True - return False - - def timedelta_range( start=None, end=None, periods=None, freq=None, name=None, closed=None ) -> TimedeltaIndex: