1+ import operator
12from operator import le , lt
23import textwrap
34from typing import TYPE_CHECKING , Optional , Tuple , Union , cast
1213 IntervalMixin ,
1314 intervals_to_interval_bounds ,
1415)
16+ from pandas ._libs .missing import NA
1517from pandas ._typing import ArrayLike , Dtype
1618from pandas .compat .numpy import function as nv
1719from pandas .util ._decorators import Appender
4850from pandas .core .construction import array , extract_array
4951from pandas .core .indexers import check_array_indexer
5052from pandas .core .indexes .base import ensure_index
51- from pandas .core .ops import unpack_zerodim_and_defer
53+ from pandas .core .ops import invalid_comparison , unpack_zerodim_and_defer
5254
5355if TYPE_CHECKING :
5456 from pandas import Index
@@ -520,16 +522,15 @@ def __setitem__(self, key, value):
520522 self ._left [key ] = value_left
521523 self ._right [key ] = value_right
522524
523- @unpack_zerodim_and_defer ("__eq__" )
524- def __eq__ (self , other ):
525+ def _cmp_method (self , other , op ):
525526 # ensure pandas array for list-like and eliminate non-interval scalars
526527 if is_list_like (other ):
527528 if len (self ) != len (other ):
528529 raise ValueError ("Lengths must match to compare" )
529530 other = array (other )
530531 elif not isinstance (other , Interval ):
531532 # non-interval scalar -> no matches
532- return np . zeros ( len ( self ), dtype = bool )
533+ return invalid_comparison ( self , other , op )
533534
534535 # determine the dtype of the elements we want to compare
535536 if isinstance (other , Interval ):
@@ -543,35 +544,79 @@ def __eq__(self, other):
543544 # extract intervals if we have interval categories with matching closed
544545 if is_interval_dtype (other_dtype ):
545546 if self .closed != other .categories .closed :
546- return np .zeros (len (self ), dtype = bool )
547+ return invalid_comparison (self , other , op )
548+
547549 other = other .categories .take (
548550 other .codes , allow_fill = True , fill_value = other .categories ._na_value
549551 )
550552
551553 # interval-like -> need same closed and matching endpoints
552554 if is_interval_dtype (other_dtype ):
553555 if self .closed != other .closed :
554- return np .zeros (len (self ), dtype = bool )
555- return (self ._left == other .left ) & (self ._right == other .right )
556+ return invalid_comparison (self , other , op )
557+ elif not isinstance (other , Interval ):
558+ other = type (self )(other )
559+
560+ if op is operator .eq :
561+ return (self ._left == other .left ) & (self ._right == other .right )
562+ elif op is operator .ne :
563+ return (self ._left != other .left ) | (self ._right != other .right )
564+ elif op is operator .gt :
565+ return (self ._left > other .left ) | (
566+ (self ._left == other .left ) & (self ._right > other .right )
567+ )
568+ elif op is operator .ge :
569+ return (self == other ) | (self > other )
570+ elif op is operator .lt :
571+ return (self ._left < other .left ) | (
572+ (self ._left == other .left ) & (self ._right < other .right )
573+ )
574+ else :
575+ # operator.lt
576+ return (self == other ) | (self < other )
556577
557578 # non-interval/non-object dtype -> no matches
558579 if not is_object_dtype (other_dtype ):
559- return np . zeros ( len ( self ), dtype = bool )
580+ return invalid_comparison ( self , other , op )
560581
561582 # object dtype -> iteratively check for intervals
562583 result = np .zeros (len (self ), dtype = bool )
563584 for i , obj in enumerate (other ):
564- # need object to be an Interval with same closed and endpoints
565- if (
566- isinstance ( obj , Interval )
567- and self . closed == obj . closed
568- and self . _left [ i ] == obj . left
569- and self . _right [ i ] == obj . right
570- ):
571- result [ i ] = True
572-
585+ try :
586+ result [ i ] = op ( self [ i ], obj )
587+ except TypeError :
588+ if obj is NA :
589+ # comparison with np.nan returns NA
590+ # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
591+ result [ i ] = op is operator . ne
592+ else :
593+ raise
573594 return result
574595
596+ @unpack_zerodim_and_defer ("__eq__" )
597+ def __eq__ (self , other ):
598+ return self ._cmp_method (other , operator .eq )
599+
600+ @unpack_zerodim_and_defer ("__ne__" )
601+ def __ne__ (self , other ):
602+ return self ._cmp_method (other , operator .ne )
603+
604+ @unpack_zerodim_and_defer ("__gt__" )
605+ def __gt__ (self , other ):
606+ return self ._cmp_method (other , operator .gt )
607+
608+ @unpack_zerodim_and_defer ("__ge__" )
609+ def __ge__ (self , other ):
610+ return self ._cmp_method (other , operator .ge )
611+
612+ @unpack_zerodim_and_defer ("__lt__" )
613+ def __lt__ (self , other ):
614+ return self ._cmp_method (other , operator .lt )
615+
616+ @unpack_zerodim_and_defer ("__le__" )
617+ def __le__ (self , other ):
618+ return self ._cmp_method (other , operator .le )
619+
575620 def fillna (self , value = None , method = None , limit = None ):
576621 """
577622 Fill NA/NaN values using the specified method.
0 commit comments