Skip to content

Commit 5e793d4

Browse files
committed
BUG: TimedeltaIndex.intersection
Fixes pandas-dev#17391
1 parent 27ebb3e commit 5e793d4

File tree

6 files changed

+225
-98
lines changed

6 files changed

+225
-98
lines changed

pandas/core/indexes/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,12 @@ def is_monotonic(self):
13871387
""" alias for is_monotonic_increasing (deprecated) """
13881388
return self.is_monotonic_increasing
13891389

1390+
@property
1391+
def _is_strictly_monotonic(self):
1392+
""" Checks if the index is sorted """
1393+
return (self._is_strictly_monotonic_increasing or
1394+
self._is_strictly_monotonic_decreasing)
1395+
13901396
@property
13911397
def is_monotonic_increasing(self):
13921398
"""

pandas/core/indexes/datetimelike.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
4848
from pandas.core.indexes.base import Index, _index_shared_docs
49+
from pandas.tseries.offsets import index_offsets_equal
4950
from pandas.util._decorators import Appender, cache_readonly
5051
import pandas.core.dtypes.concat as _concat
5152
import pandas.tseries.frequencies as frequencies
@@ -879,6 +880,94 @@ def astype(self, dtype, copy=True):
879880
raise TypeError(msg.format(name=type(self).__name__, dtype=dtype))
880881
return super(DatetimeIndexOpsMixin, self).astype(dtype, copy=copy)
881882

883+
def _intersect_ascending(self, other):
884+
# to make our life easier, "sort" the two ranges
885+
if self[0] <= other[0]:
886+
left, right = self, other
887+
else:
888+
left, right = other, self
889+
890+
end = min(left[-1], right[-1])
891+
start = right[0]
892+
893+
if end < start:
894+
return []
895+
return left.values[slice(*left.slice_locs(start, end))]
896+
897+
def _intersect_descending(self, other):
898+
# this is essentially a flip of _intersect_ascending
899+
if self[0] >= other[0]:
900+
left, right = self, other
901+
else:
902+
left, right = other, self
903+
904+
start = min(left[0], right[0])
905+
end = right[-1]
906+
907+
if end > start:
908+
return Index()
909+
return left.values[slice(*left.slice_locs(start, end))]
910+
911+
def intersection(self, other):
912+
"""
913+
Specialized intersection for DateTimeIndexOpsMixin objects.
914+
May be much faster than Index.intersection.
915+
916+
Parameters
917+
----------
918+
other : Index or array-like
919+
920+
Returns
921+
-------
922+
Index
923+
A shallow copied intersection between the two things passed in
924+
"""
925+
self._assert_can_do_setop(other)
926+
927+
if self.equals(other):
928+
return self._get_consensus_name(other)
929+
930+
lengths = len(self), len(other)
931+
if lengths[0] == 0:
932+
return self
933+
if lengths[1] == 0:
934+
return other
935+
936+
if not isinstance(other, Index):
937+
result = Index.intersection(self, other)
938+
return result
939+
elif (index_offsets_equal(self, other) or
940+
(not self._is_strictly_monotonic or
941+
not other._is_strictly_monotonic)):
942+
result = Index.intersection(self, other)
943+
result = self._shallow_copy(result._values, name=result.name,
944+
tz=getattr(self, 'tz', None),
945+
freq=None
946+
)
947+
if result.freq is None:
948+
result.offset = frequencies.to_offset(result.inferred_freq)
949+
return result
950+
951+
# handle intersecting things like this
952+
# idx1 = pd.to_timedelta((1, 2, 3, 4, 5, 6, 7, 8), unit='s')
953+
# idx2 = pd.to_timedelta((2, 3, 4, 8), unit='s')
954+
if lengths[0] != lengths[1] and (
955+
max(self) != max(other) or min(self) != min(other)):
956+
return Index.intersection(self, other)
957+
958+
# coerce into same order
959+
self_ascending = self.is_monotonic_increasing
960+
if self_ascending != other.is_monotonic_increasing:
961+
other = other.sort_values(ascending=self_ascending)
962+
963+
if self_ascending:
964+
intersected_slice = self._intersect_ascending(other)
965+
else:
966+
intersected_slice = self._intersect_descending(other)
967+
968+
intersected = self._shallow_copy(intersected_slice)
969+
return intersected._get_consensus_name(other)
970+
882971

883972
def _ensure_datetimelike_to_i8(other):
884973
""" helper for coercing an input scalar or array to i8 """

pandas/core/indexes/datetimes.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,62 +1125,6 @@ def _wrap_union_result(self, other, result):
11251125
raise ValueError('Passed item and index have different timezone')
11261126
return self._simple_new(result, name=name, freq=None, tz=self.tz)
11271127

1128-
def intersection(self, other):
1129-
"""
1130-
Specialized intersection for DatetimeIndex objects. May be much faster
1131-
than Index.intersection
1132-
1133-
Parameters
1134-
----------
1135-
other : DatetimeIndex or array-like
1136-
1137-
Returns
1138-
-------
1139-
y : Index or DatetimeIndex
1140-
"""
1141-
self._assert_can_do_setop(other)
1142-
if not isinstance(other, DatetimeIndex):
1143-
try:
1144-
other = DatetimeIndex(other)
1145-
except (TypeError, ValueError):
1146-
pass
1147-
result = Index.intersection(self, other)
1148-
if isinstance(result, DatetimeIndex):
1149-
if result.freq is None:
1150-
result.freq = to_offset(result.inferred_freq)
1151-
return result
1152-
1153-
elif (other.freq is None or self.freq is None or
1154-
other.freq != self.freq or
1155-
not other.freq.isAnchored() or
1156-
(not self.is_monotonic or not other.is_monotonic)):
1157-
result = Index.intersection(self, other)
1158-
result = self._shallow_copy(result._values, name=result.name,
1159-
tz=result.tz, freq=None)
1160-
if result.freq is None:
1161-
result.freq = to_offset(result.inferred_freq)
1162-
return result
1163-
1164-
if len(self) == 0:
1165-
return self
1166-
if len(other) == 0:
1167-
return other
1168-
# to make our life easier, "sort" the two ranges
1169-
if self[0] <= other[0]:
1170-
left, right = self, other
1171-
else:
1172-
left, right = other, self
1173-
1174-
end = min(left[-1], right[-1])
1175-
start = right[0]
1176-
1177-
if end < start:
1178-
return type(self)(data=[])
1179-
else:
1180-
lslice = slice(*left.slice_locs(start, end))
1181-
left_chunk = left.values[lslice]
1182-
return self._shallow_copy(left_chunk)
1183-
11841128
def _parsed_string_to_bounds(self, reso, parsed):
11851129
"""
11861130
Calculate datetime bounds for parsed time string and its resolution.

pandas/core/indexes/timedeltas.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -441,48 +441,6 @@ def _wrap_union_result(self, other, result):
441441
name = self.name if self.name == other.name else None
442442
return self._simple_new(result, name=name, freq=None)
443443

444-
def intersection(self, other):
445-
"""
446-
Specialized intersection for TimedeltaIndex objects. May be much faster
447-
than Index.intersection
448-
449-
Parameters
450-
----------
451-
other : TimedeltaIndex or array-like
452-
453-
Returns
454-
-------
455-
y : Index or TimedeltaIndex
456-
"""
457-
self._assert_can_do_setop(other)
458-
if not isinstance(other, TimedeltaIndex):
459-
try:
460-
other = TimedeltaIndex(other)
461-
except (TypeError, ValueError):
462-
pass
463-
result = Index.intersection(self, other)
464-
return result
465-
466-
if len(self) == 0:
467-
return self
468-
if len(other) == 0:
469-
return other
470-
# to make our life easier, "sort" the two ranges
471-
if self[0] <= other[0]:
472-
left, right = self, other
473-
else:
474-
left, right = other, self
475-
476-
end = min(left[-1], right[-1])
477-
start = right[0]
478-
479-
if end < start:
480-
return type(self)(data=[])
481-
else:
482-
lslice = slice(*left.slice_locs(start, end))
483-
left_chunk = left.values[lslice]
484-
return self._shallow_copy(left_chunk)
485-
486444
def _maybe_promote(self, other):
487445
if other.inferred_type == 'timedelta':
488446
other = TimedeltaIndex(other)

pandas/tests/indexes/timedeltas/test_setops.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
import pandas as pd
45
import pandas.util.testing as tm
@@ -73,3 +74,97 @@ def test_intersection_bug_1708(self):
7374
result = index_1 & index_2
7475
expected = timedelta_range('1 day 01:00:00', periods=3, freq='h')
7576
tm.assert_index_equal(result, expected)
77+
78+
79+
@pytest.mark.parametrize('idx1,idx2,expected', [
80+
(pd.to_timedelta(range(2, 6), unit='s'),
81+
pd.to_timedelta(range(3), unit='s'),
82+
TimedeltaIndex(['00:00:002'])),
83+
(pd.to_timedelta(range(3), unit='s'),
84+
pd.to_timedelta(range(2, 6), unit='s'),
85+
TimedeltaIndex(['00:00:002'])),
86+
])
87+
def test_intersection_intersects_ascending(idx1, idx2, expected):
88+
result = idx1.intersection(idx2)
89+
assert result.equals(expected)
90+
91+
92+
@pytest.mark.parametrize('idx1,idx2,expected', [
93+
(pd.to_timedelta(range(6, 3, -1), unit='s'),
94+
pd.to_timedelta(range(5, 1, -1), unit='s'),
95+
TimedeltaIndex(['00:00:05', '00:00:04'])),
96+
(pd.to_timedelta(range(5, 1, -1), unit='s'),
97+
pd.to_timedelta(range(6, 3, -1), unit='s'),
98+
TimedeltaIndex(['00:00:05', '00:00:04'])),
99+
])
100+
def test_intersection_intersects_descending(idx1, idx2, expected):
101+
# GH 17391
102+
result = idx1.intersection(idx2)
103+
assert result.equals(expected)
104+
105+
106+
def test_intersection_intersects_descending_no_intersect():
107+
idx1 = pd.to_timedelta(range(6, 4, -1), unit='s')
108+
idx2 = pd.to_timedelta(range(4, 1, -1), unit='s')
109+
result = idx1.intersection(idx2)
110+
assert len(result) == 0
111+
112+
113+
def test_intersection_intersects_len_1():
114+
idx1 = pd.to_timedelta(range(1, 2), unit='s')
115+
idx2 = pd.to_timedelta(range(1, 0, -1), unit='s')
116+
intersection = idx1.intersection(idx2)
117+
expected = TimedeltaIndex(['00:00:01'],
118+
dtype='timedelta64[ns]')
119+
tm.assert_index_equal(intersection, expected)
120+
121+
122+
def test_intersection_can_intersect_self():
123+
idx = pd.to_timedelta(range(1, 2), unit='s')
124+
result = idx.intersection(idx)
125+
tm.assert_index_equal(idx, result)
126+
127+
128+
def test_intersection_not_sorted():
129+
idx1 = pd.to_timedelta((1, 3, 2, 5, 4), unit='s')
130+
idx2 = pd.to_timedelta((1, 2, 3, 5, 4), unit='s')
131+
result = idx1.intersection(idx2)
132+
expected = idx1
133+
tm.assert_index_equal(result, expected)
134+
135+
136+
def test_intersection_not_unique():
137+
idx1 = pd.to_timedelta((1, 2, 2, 3, 3, 5), unit='s')
138+
idx2 = pd.to_timedelta((1, 2, 3, 4), unit='s')
139+
result = idx1.intersection(idx2)
140+
expected = pd.to_timedelta((1, 2, 2, 3, 3), unit='s')
141+
tm.assert_index_equal(result, expected)
142+
143+
result = idx2.intersection(idx1)
144+
expected = pd.to_timedelta((1, 2, 2, 3, 3), unit='s')
145+
tm.assert_index_equal(result, expected)
146+
147+
148+
@pytest.mark.parametrize("index1, index2, expected", [
149+
(pd.to_timedelta((1, 2, 3, 4, 5, 6, 7, 8), unit='s'),
150+
pd.to_timedelta((2, 3, 4, 8), unit='s'),
151+
pd.to_timedelta((2, 3, 4, 8), unit='s')),
152+
(pd.to_timedelta((1, 2, 3, 4, 5), unit='s'),
153+
pd.to_timedelta((2, 3, 4), unit='s'),
154+
pd.to_timedelta((2, 3, 4), unit='s')),
155+
(pd.to_timedelta((2, 4, 5, 6), unit='s'),
156+
pd.to_timedelta((2, 3, 4), unit='s'),
157+
pd.to_timedelta((2, 4), unit='s')),
158+
])
159+
def test_intersection_different_lengths(index1, index2, expected):
160+
def intersect(idx1, idx2, expected):
161+
result = idx1.intersection(idx2)
162+
tm.assert_index_equal(result, expected)
163+
result = idx2.intersection(idx1)
164+
tm.assert_index_equal(result, expected)
165+
166+
intersect(index1, index2, expected)
167+
intersect(index1.sort_values(ascending=False),
168+
index2.sort_values(ascending=False),
169+
expected.sort_values(ascending=False)
170+
)

pandas/tseries/offsets.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,41 @@ def wrapper(self, other):
112112
return wrapper
113113

114114

115+
def apply_index_wraps(func):
116+
@functools.wraps(func)
117+
def wrapper(self, other):
118+
result = func(self, other)
119+
if self.normalize:
120+
result = result.to_period('D').to_timestamp()
121+
return result
122+
return wrapper
123+
124+
125+
def _is_normalized(dt):
126+
if (dt.hour != 0 or dt.minute != 0 or dt.second != 0 or
127+
dt.microsecond != 0 or getattr(dt, 'nanosecond', 0) != 0):
128+
return False
129+
return True
130+
131+
132+
def index_offsets_equal(first, second):
133+
"""
134+
Checks if the two indexes have an offset, and if they equal each other
135+
Parameters
136+
----------
137+
first: Index
138+
second: Index
139+
140+
Returns
141+
-------
142+
bool
143+
"""
144+
first = getattr(first, 'freq', None)
145+
second = getattr(second, 'freq', None)
146+
are_offsets_equal = True
147+
if first is None or second is None or first != second:
148+
are_offsets_equal = False
149+
return are_offsets_equal
115150
# ---------------------------------------------------------------------
116151
# DateOffset
117152

0 commit comments

Comments
 (0)