Skip to content

Commit 95a86a9

Browse files
authored
PERF: DTA/TDA _simple_new disallow i8 values (#40116)
1 parent 4503dac commit 95a86a9

File tree

10 files changed: 37 additions and 38 deletions.

pandas/core/arrays/datetimelike.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -465,15 +465,15 @@ def view(self, dtype: Optional[Dtype] = None) -> ArrayLike:
465465
dtype = pandas_dtype(dtype)
466466
if isinstance(dtype, (PeriodDtype, DatetimeTZDtype)):
467467
cls = dtype.construct_array_type()
468-
return cls._simple_new(self.asi8, dtype=dtype)
468+
return cls(self.asi8, dtype=dtype)
469469
elif dtype == "M8[ns]":
470470
from pandas.core.arrays import DatetimeArray
471471

472-
return DatetimeArray._simple_new(self.asi8, dtype=dtype)
472+
return DatetimeArray(self.asi8, dtype=dtype)
473473
elif dtype == "m8[ns]":
474474
from pandas.core.arrays import TimedeltaArray
475475

476-
return TimedeltaArray._simple_new(self.asi8.view("m8[ns]"), dtype=dtype)
476+
return TimedeltaArray(self.asi8, dtype=dtype)
477477
return self._ndarray.view(dtype=dtype)
478478

479479
# ------------------------------------------------------------------
@@ -1102,10 +1102,10 @@ def _add_timedeltalike_scalar(self, other):
11021102
return type(self)(new_values, dtype=self.dtype)
11031103

11041104
inc = delta_to_nanoseconds(other)
1105-
new_values = checked_add_with_arr(self.asi8, inc, arr_mask=self._isnan).view(
1106-
"i8"
1107-
)
1105+
new_values = checked_add_with_arr(self.asi8, inc, arr_mask=self._isnan)
1106+
new_values = new_values.view("i8")
11081107
new_values = self._maybe_mask_results(new_values)
1108+
new_values = new_values.view(self._ndarray.dtype)
11091109

11101110
new_freq = None
11111111
if isinstance(self.freq, Tick) or is_period_dtype(self.dtype):
@@ -1700,6 +1700,7 @@ def _round(self, freq, mode, ambiguous, nonexistent):
17001700
nanos = to_offset(freq).nanos
17011701
result = round_nsint64(values, mode, nanos)
17021702
result = self._maybe_mask_results(result, fill_value=iNaT)
1703+
result = result.view(self._ndarray.dtype)
17031704
return self._simple_new(result, dtype=self.dtype)
17041705

17051706
@Appender((_round_doc + _round_example).format(op="round"))

pandas/core/arrays/datetimes.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,7 @@ def _simple_new(
315315
cls, values, freq: Optional[BaseOffset] = None, dtype=DT64NS_DTYPE
316316
) -> DatetimeArray:
317317
assert isinstance(values, np.ndarray)
318-
if values.dtype != DT64NS_DTYPE:
319-
assert values.dtype == "i8"
320-
values = values.view(DT64NS_DTYPE)
318+
assert values.dtype == DT64NS_DTYPE
321319

322320
result = object.__new__(cls)
323321
result._ndarray = values
@@ -439,6 +437,7 @@ def _generate_range(
439437
values = np.array([x.value for x in xdr], dtype=np.int64)
440438

441439
_tz = start.tz if start is not None else end.tz
440+
values = values.view("M8[ns]")
442441
index = cls._simple_new(values, freq=freq, dtype=tz_to_dtype(_tz))
443442

444443
if tz is not None and index.tz is None:
@@ -464,9 +463,8 @@ def _generate_range(
464463
+ start.value
465464
)
466465
dtype = tz_to_dtype(tz)
467-
index = cls._simple_new(
468-
arr.astype("M8[ns]", copy=False), freq=None, dtype=dtype
469-
)
466+
arr = arr.astype("M8[ns]", copy=False)
467+
index = cls._simple_new(arr, freq=None, dtype=dtype)
470468

471469
if not left_closed and len(index) and index[0] == start:
472470
# TODO: overload DatetimeLikeArrayMixin.__getitem__
@@ -476,7 +474,7 @@ def _generate_range(
476474
index = cast(DatetimeArray, index[:-1])
477475

478476
dtype = tz_to_dtype(tz)
479-
return cls._simple_new(index.asi8, freq=freq, dtype=dtype)
477+
return cls._simple_new(index._ndarray, freq=freq, dtype=dtype)
480478

481479
# -----------------------------------------------------------------
482480
# DatetimeLike Interface
@@ -710,7 +708,7 @@ def _add_offset(self, offset):
710708
values = self.tz_localize(None)
711709
else:
712710
values = self
713-
result = offset._apply_array(values)
711+
result = offset._apply_array(values).view("M8[ns]")
714712
result = DatetimeArray._simple_new(result)
715713
result = result.tz_localize(self.tz)
716714

@@ -833,7 +831,7 @@ def tz_convert(self, tz):
833831

834832
# No conversion since timestamps are all UTC to begin with
835833
dtype = tz_to_dtype(tz)
836-
return self._simple_new(self.asi8, dtype=dtype, freq=self.freq)
834+
return self._simple_new(self._ndarray, dtype=dtype, freq=self.freq)
837835

838836
@dtl.ravel_compat
839837
def tz_localize(self, tz, ambiguous="raise", nonexistent="raise"):

pandas/core/arrays/timedeltas.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,11 @@ def _simple_new(
230230
) -> TimedeltaArray:
231231
assert dtype == TD64NS_DTYPE, dtype
232232
assert isinstance(values, np.ndarray), type(values)
233-
if values.dtype != TD64NS_DTYPE:
234-
assert values.dtype == "i8"
235-
values = values.view(TD64NS_DTYPE)
233+
assert values.dtype == TD64NS_DTYPE
236234

237235
result = object.__new__(cls)
238236
result._ndarray = values
239-
result._freq = to_offset(freq)
237+
result._freq = freq
240238
result._dtype = TD64NS_DTYPE
241239
return result
242240

@@ -318,7 +316,7 @@ def _generate_range(cls, start, end, periods, freq, closed=None):
318316
if not right_closed:
319317
index = index[:-1]
320318

321-
return cls._simple_new(index, freq=freq)
319+
return cls._simple_new(index.view("m8[ns]"), freq=freq)
322320

323321
# ----------------------------------------------------------------
324322
# DatetimeLike Interface

pandas/core/dtypes/cast.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ def maybe_downcast_to_dtype(
288288
i8values = result.astype("i8", copy=False)
289289
cls = dtype.construct_array_type()
290290
# equiv: DatetimeArray(i8values).tz_localize("UTC").tz_convert(dtype.tz)
291-
result = cls._simple_new(i8values, dtype=dtype)
291+
dt64values = i8values.view("M8[ns]")
292+
result = cls._simple_new(dt64values, dtype=dtype)
292293
else:
293294
result = result.astype(dtype)
294295

pandas/core/groupby/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def _ea_wrap_cython_operation(
542542
return res_values
543543

544544
res_values = res_values.astype("i8", copy=False)
545-
result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype)
545+
result = type(orig_values)(res_values, dtype=orig_values.dtype)
546546
return result
547547

548548
elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):

pandas/core/indexes/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def view(self, cls=None):
820820
arr = self._data.view("i8")
821821
idx_cls = self._dtype_to_subclass(dtype)
822822
arr_cls = idx_cls._data_cls
823-
arr = arr_cls._simple_new(self._data.view("i8"), dtype=dtype)
823+
arr = arr_cls(self._data.view("i8"), dtype=dtype)
824824
return idx_cls._simple_new(arr, name=self.name)
825825

826826
result = self._data.view(cls)

pandas/core/indexes/datetimelike.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def wrapper(left, right):
100100
join_index = orig_left._from_backing_data(join_index)
101101

102102
return join_index, left_indexer, right_indexer
103+
103104
return results
104105

105106
return wrapper
@@ -645,7 +646,8 @@ def _get_join_freq(self, other):
645646

646647
def _wrap_joined_index(self, joined: np.ndarray, other):
647648
assert other.dtype == self.dtype, (other.dtype, self.dtype)
648-
649+
assert joined.dtype == "i8" or joined.dtype == self.dtype, joined.dtype
650+
joined = joined.view(self._data._ndarray.dtype)
649651
result = super()._wrap_joined_index(joined, other)
650652
result._data._freq = self._get_join_freq(other)
651653
return result

pandas/core/nanops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1743,8 +1743,9 @@ def na_accum_func(values: ArrayLike, accum_func, *, skipna: bool) -> ArrayLike:
17431743
result = result.view(orig_dtype)
17441744
else:
17451745
# DatetimeArray
1746+
# TODO: have this case go through a DTA method?
17461747
result = type(values)._simple_new( # type: ignore[attr-defined]
1747-
result, dtype=orig_dtype
1748+
result.view("M8[ns]"), dtype=orig_dtype
17481749
)
17491750

17501751
elif skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)):

pandas/tests/arrays/test_datetimelike.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,10 @@ def arr1d(self):
8585
arr = self.array_cls(data, freq="D")
8686
return arr
8787

88-
def test_compare_len1_raises(self):
88+
def test_compare_len1_raises(self, arr1d):
8989
# make sure we raise when comparing with different lengths, specific
9090
# to the case where one has length-1, which numpy would broadcast
91-
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
92-
93-
arr = self.array_cls._simple_new(data, freq="D")
91+
arr = arr1d
9492
idx = self.index_cls(arr)
9593

9694
with pytest.raises(ValueError, match="Lengths must match"):
@@ -153,7 +151,9 @@ def test_take(self):
153151
data = np.arange(100, dtype="i8") * 24 * 3600 * 10 ** 9
154152
np.random.shuffle(data)
155153

156-
arr = self.array_cls._simple_new(data, freq="D")
154+
freq = None if self.array_cls is not PeriodArray else "D"
155+
156+
arr = self.array_cls(data, freq=freq)
157157
idx = self.index_cls._simple_new(arr)
158158

159159
takers = [1, 4, 94]
@@ -172,7 +172,7 @@ def test_take(self):
172172
def test_take_fill_raises(self, fill_value):
173173
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
174174

175-
arr = self.array_cls._simple_new(data, freq="D")
175+
arr = self.array_cls(data, freq="D")
176176

177177
msg = f"value should be a '{arr._scalar_type.__name__}' or 'NaT'. Got"
178178
with pytest.raises(TypeError, match=msg):
@@ -181,7 +181,7 @@ def test_take_fill_raises(self, fill_value):
181181
def test_take_fill(self):
182182
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
183183

184-
arr = self.array_cls._simple_new(data, freq="D")
184+
arr = self.array_cls(data, freq="D")
185185

186186
result = arr.take([-1, 1], allow_fill=True, fill_value=None)
187187
assert result[0] is pd.NaT
@@ -202,10 +202,8 @@ def test_take_fill_str(self, arr1d):
202202
with pytest.raises(TypeError, match=msg):
203203
arr1d.take([-1, 1], allow_fill=True, fill_value="foo")
204204

205-
def test_concat_same_type(self):
206-
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
207-
208-
arr = self.array_cls._simple_new(data, freq="D")
205+
def test_concat_same_type(self, arr1d):
206+
arr = arr1d
209207
idx = self.index_cls(arr)
210208
idx = idx.insert(0, pd.NaT)
211209
arr = self.array_cls(idx)

pandas/tests/indexes/test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_get_unique_index(self, index_flat):
175175
vals = index[[0] * 5]._data
176176
vals[0] = pd.NaT
177177
elif needs_i8_conversion(index.dtype):
178-
vals = index.asi8[[0] * 5]
178+
vals = index._data._ndarray[[0] * 5]
179179
vals[0] = iNaT
180180
else:
181181
vals = index.values[[0] * 5]
@@ -184,7 +184,7 @@ def test_get_unique_index(self, index_flat):
184184
vals_unique = vals[:2]
185185
if index.dtype.kind in ["m", "M"]:
186186
# i.e. needs_i8_conversion but not period_dtype, as above
187-
vals = type(index._data)._simple_new(vals, dtype=index.dtype)
187+
vals = type(index._data)(vals, dtype=index.dtype)
188188
vals_unique = type(index._data)._simple_new(vals_unique, dtype=index.dtype)
189189
idx_nan = index._shallow_copy(vals)
190190
idx_unique_nan = index._shallow_copy(vals_unique)

0 commit comments

Comments
 (0)