Skip to content

Commit 2038d53

Browse files
authored
ENH: PandasArray ops use core.ops functions (#36484)
1 parent cf6ad46 commit 2038d53

File tree

8 files changed

+186
-91
lines changed

8 files changed

+186
-91
lines changed

pandas/core/arrays/numpy_.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -362,19 +362,29 @@ def __invert__(self):
362362

363363
@classmethod
364364
def _create_arithmetic_method(cls, op):
365+
366+
pd_op = ops.get_array_op(op)
367+
365368
@ops.unpack_zerodim_and_defer(op.__name__)
366369
def arithmetic_method(self, other):
367370
if isinstance(other, cls):
368371
other = other._ndarray
369372

370-
with np.errstate(all="ignore"):
371-
result = op(self._ndarray, other)
373+
result = pd_op(self._ndarray, other)
372374

373-
if op is divmod:
375+
if op is divmod or op is ops.rdivmod:
374376
a, b = result
375-
return cls(a), cls(b)
376-
377-
return cls(result)
377+
if isinstance(a, np.ndarray):
378+
# for e.g. op vs TimedeltaArray, we may already
379+
# have an ExtensionArray, in which case we do not wrap
380+
return cls(a), cls(b)
381+
return a, b
382+
383+
if isinstance(result, np.ndarray):
384+
# for e.g. multiplication vs TimedeltaArray, we may already
385+
# have an ExtensionArray, in which case we do not wrap
386+
return cls(result)
387+
return result
378388

379389
return compat.set_function_name(arithmetic_method, f"__{op.__name__}__", cls)
380390

pandas/tests/arithmetic/common.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66

7-
from pandas import DataFrame, Index, Series
7+
from pandas import DataFrame, Index, Series, array as pd_array
88
import pandas._testing as tm
99

1010

@@ -49,12 +49,12 @@ def assert_invalid_comparison(left, right, box):
4949
----------
5050
left : np.ndarray, ExtensionArray, Index, or Series
5151
right : object
52-
box : {pd.DataFrame, pd.Series, pd.Index, tm.to_array}
52+
box : {pd.DataFrame, pd.Series, pd.Index, pd.array, tm.to_array}
5353
"""
5454
# Not for tznaive-tzaware comparison
5555

5656
# Note: not quite the same as how we do this for tm.box_expected
57-
xbox = box if box is not Index else np.array
57+
xbox = box if box not in [Index, pd_array] else np.array
5858

5959
result = left == right
6060
expected = xbox(np.zeros(result.shape, dtype=np.bool_))

pandas/tests/arithmetic/conftest.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
import pytest
33

44
import pandas as pd
5-
import pandas._testing as tm
65

76
# ------------------------------------------------------------------
87
# Helper Functions
@@ -56,7 +55,7 @@ def one(request):
5655

5756
zeros = [
5857
box_cls([0] * 5, dtype=dtype)
59-
for box_cls in [pd.Index, np.array]
58+
for box_cls in [pd.Index, np.array, pd.array]
6059
for dtype in [np.int64, np.uint64, np.float64]
6160
]
6261
zeros.extend(
@@ -231,7 +230,7 @@ def box(request):
231230
return request.param
232231

233232

234-
@pytest.fixture(params=[pd.Index, pd.Series, pd.DataFrame, tm.to_array], ids=id_func)
233+
@pytest.fixture(params=[pd.Index, pd.Series, pd.DataFrame, pd.array], ids=id_func)
235234
def box_with_array(request):
236235
"""
237236
Fixture to test behavior for Index, Series, DataFrame, and pandas Array

pandas/tests/arithmetic/test_datetime64.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,9 @@ def test_compare_zerodim(self, tz_naive_fixture, box_with_array):
4848
# Test comparison with zero-dimensional array is unboxed
4949
tz = tz_naive_fixture
5050
box = box_with_array
51-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
51+
xbox = (
52+
box_with_array if box_with_array not in [pd.Index, pd.array] else np.ndarray
53+
)
5254
dti = date_range("20130101", periods=3, tz=tz)
5355

5456
other = np.array(dti.to_numpy()[0])
@@ -135,7 +137,7 @@ def test_dt64arr_nat_comparison(self, tz_naive_fixture, box_with_array):
135137
# GH#22242, GH#22163 DataFrame considered NaT == ts incorrectly
136138
tz = tz_naive_fixture
137139
box = box_with_array
138-
xbox = box if box is not pd.Index else np.ndarray
140+
xbox = box if box not in [pd.Index, pd.array] else np.ndarray
139141

140142
ts = pd.Timestamp.now(tz)
141143
ser = pd.Series([ts, pd.NaT])
@@ -203,6 +205,8 @@ def test_nat_comparisons(self, dtype, index_or_series, reverse, pair):
203205
def test_comparison_invalid(self, tz_naive_fixture, box_with_array):
204206
# GH#4968
205207
# invalid date/int comparisons
208+
if box_with_array is pd.array:
209+
pytest.xfail("assert_invalid_comparison doesnt handle BooleanArray yet")
206210
tz = tz_naive_fixture
207211
ser = Series(range(5))
208212
ser2 = Series(pd.date_range("20010101", periods=5, tz=tz))
@@ -226,8 +230,12 @@ def test_nat_comparisons_scalar(self, dtype, data, box_with_array):
226230
# dont bother testing ndarray comparison methods as this fails
227231
# on older numpys (since they check object identity)
228232
return
233+
if box_with_array is pd.array and dtype is object:
234+
pytest.xfail("reversed comparisons give BooleanArray, not ndarray")
229235

230-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
236+
xbox = (
237+
box_with_array if box_with_array not in [pd.Index, pd.array] else np.ndarray
238+
)
231239

232240
left = Series(data, dtype=dtype)
233241
left = tm.box_expected(left, box_with_array)
@@ -299,7 +307,9 @@ def test_timestamp_compare_series(self, left, right):
299307

300308
def test_dt64arr_timestamp_equality(self, box_with_array):
301309
# GH#11034
302-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
310+
xbox = (
311+
box_with_array if box_with_array not in [pd.Index, pd.array] else np.ndarray
312+
)
303313

304314
ser = pd.Series([pd.Timestamp("2000-01-29 01:59:00"), "NaT"])
305315
ser = tm.box_expected(ser, box_with_array)
@@ -388,7 +398,9 @@ def test_dti_cmp_nat(self, dtype, box_with_array):
388398
# on older numpys (since they check object identity)
389399
return
390400

391-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
401+
xbox = (
402+
box_with_array if box_with_array not in [pd.Index, pd.array] else np.ndarray
403+
)
392404

393405
left = pd.DatetimeIndex(
394406
[pd.Timestamp("2011-01-01"), pd.NaT, pd.Timestamp("2011-01-03")]

pandas/tests/arithmetic/test_numeric.py

Lines changed: 36 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,9 @@ def test_compare_invalid(self):
8989
b.name = pd.Timestamp("2000-01-01")
9090
tm.assert_series_equal(a / b, 1 / (b / a))
9191

92-
def test_numeric_cmp_string_numexpr_path(self, box):
92+
def test_numeric_cmp_string_numexpr_path(self, box_with_array):
9393
# GH#36377, GH#35700
94+
box = box_with_array
9495
xbox = box if box is not pd.Index else np.ndarray
9596

9697
obj = pd.Series(np.random.randn(10 ** 5))
@@ -183,10 +184,14 @@ def test_ops_series(self):
183184
],
184185
ids=lambda x: type(x).__name__,
185186
)
186-
def test_numeric_arr_mul_tdscalar(self, scalar_td, numeric_idx, box):
187+
def test_numeric_arr_mul_tdscalar(self, scalar_td, numeric_idx, box_with_array):
187188
# GH#19333
189+
box = box_with_array
190+
if box is pd.array:
191+
pytest.xfail(
192+
"we get a PandasArray[timedelta64[ns]] instead of TimedeltaArray"
193+
)
188194
index = numeric_idx
189-
190195
expected = pd.TimedeltaIndex([pd.Timedelta(days=n) for n in range(5)])
191196

192197
index = tm.box_expected(index, box)
@@ -207,7 +212,11 @@ def test_numeric_arr_mul_tdscalar(self, scalar_td, numeric_idx, box):
207212
],
208213
ids=lambda x: type(x).__name__,
209214
)
210-
def test_numeric_arr_mul_tdscalar_numexpr_path(self, scalar_td, box):
215+
def test_numeric_arr_mul_tdscalar_numexpr_path(self, scalar_td, box_with_array):
216+
box = box_with_array
217+
if box is pd.array:
218+
pytest.xfail("IntegerArray.__mul__ doesnt handle timedeltas")
219+
211220
arr = np.arange(2 * 10 ** 4).astype(np.int64)
212221
obj = tm.box_expected(arr, box, transpose=False)
213222

@@ -220,7 +229,11 @@ def test_numeric_arr_mul_tdscalar_numexpr_path(self, scalar_td, box):
220229
result = scalar_td * obj
221230
tm.assert_equal(result, expected)
222231

223-
def test_numeric_arr_rdiv_tdscalar(self, three_days, numeric_idx, box):
232+
def test_numeric_arr_rdiv_tdscalar(self, three_days, numeric_idx, box_with_array):
233+
box = box_with_array
234+
if box is pd.array:
235+
pytest.xfail("We get PandasArray[td64] instead of TimedeltaArray")
236+
224237
index = numeric_idx[1:3]
225238

226239
expected = TimedeltaIndex(["3 Days", "36 Hours"])
@@ -248,7 +261,11 @@ def test_numeric_arr_rdiv_tdscalar(self, three_days, numeric_idx, box):
248261
pd.offsets.Second(0),
249262
],
250263
)
251-
def test_add_sub_timedeltalike_invalid(self, numeric_idx, other, box):
264+
def test_add_sub_timedeltalike_invalid(self, numeric_idx, other, box_with_array):
265+
box = box_with_array
266+
if box is pd.array:
267+
pytest.xfail("PandasArray[int].__add__ doesnt raise on td64")
268+
252269
left = tm.box_expected(numeric_idx, box)
253270
msg = (
254271
"unsupported operand type|"
@@ -276,16 +293,21 @@ def test_add_sub_timedeltalike_invalid(self, numeric_idx, other, box):
276293
],
277294
)
278295
@pytest.mark.filterwarnings("ignore:elementwise comp:DeprecationWarning")
279-
def test_add_sub_datetimelike_invalid(self, numeric_idx, other, box):
296+
def test_add_sub_datetimelike_invalid(self, numeric_idx, other, box_with_array):
280297
# GH#28080 numeric+datetime64 should raise; Timestamp raises
281298
# NullFrequencyError instead of TypeError so is excluded.
299+
box = box_with_array
282300
left = tm.box_expected(numeric_idx, box)
283301

284-
msg = (
285-
"unsupported operand type|"
286-
"Cannot (add|subtract) NaT (to|from) ndarray|"
287-
"Addition/subtraction of integers and integer-arrays|"
288-
"Concatenation operation is not implemented for NumPy arrays"
302+
msg = "|".join(
303+
[
304+
"unsupported operand type",
305+
"Cannot (add|subtract) NaT (to|from) ndarray",
306+
"Addition/subtraction of integers and integer-arrays",
307+
"Concatenation operation is not implemented for NumPy arrays",
308+
# pd.array vs np.datetime64 case
309+
r"operand type\(s\) all returned NotImplemented from __array_ufunc__",
310+
]
289311
)
290312
with pytest.raises(TypeError, match=msg):
291313
left + other
@@ -568,8 +590,9 @@ class TestMultiplicationDivision:
568590
# __mul__, __rmul__, __div__, __rdiv__, __floordiv__, __rfloordiv__
569591
# for non-timestamp/timedelta/period dtypes
570592

571-
def test_divide_decimal(self, box):
593+
def test_divide_decimal(self, box_with_array):
572594
# resolves issue GH#9787
595+
box = box_with_array
573596
ser = Series([Decimal(10)])
574597
expected = Series([Decimal(5)])
575598

pandas/tests/arithmetic/test_object.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -104,22 +104,22 @@ def test_add_extension_scalar(self, other, box_with_array, op):
104104
result = op(arr, other)
105105
tm.assert_equal(result, expected)
106106

107-
def test_objarr_add_str(self, box):
107+
def test_objarr_add_str(self, box_with_array):
108108
ser = pd.Series(["x", np.nan, "x"])
109109
expected = pd.Series(["xa", np.nan, "xa"])
110110

111-
ser = tm.box_expected(ser, box)
112-
expected = tm.box_expected(expected, box)
111+
ser = tm.box_expected(ser, box_with_array)
112+
expected = tm.box_expected(expected, box_with_array)
113113

114114
result = ser + "a"
115115
tm.assert_equal(result, expected)
116116

117-
def test_objarr_radd_str(self, box):
117+
def test_objarr_radd_str(self, box_with_array):
118118
ser = pd.Series(["x", np.nan, "x"])
119119
expected = pd.Series(["ax", np.nan, "ax"])
120120

121-
ser = tm.box_expected(ser, box)
122-
expected = tm.box_expected(expected, box)
121+
ser = tm.box_expected(ser, box_with_array)
122+
expected = tm.box_expected(expected, box_with_array)
123123

124124
result = "a" + ser
125125
tm.assert_equal(result, expected)

pandas/tests/arithmetic/test_period.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,9 @@ class TestPeriodArrayLikeComparisons:
2828

2929
def test_compare_zerodim(self, box_with_array):
3030
# GH#26689 make sure we unbox zero-dimensional arrays
31-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
31+
xbox = (
32+
box_with_array if box_with_array not in [pd.Index, pd.array] else np.ndarray
33+
)
3234

3335
pi = pd.period_range("2000", periods=4)
3436
other = np.array(pi.to_numpy()[0])
@@ -68,7 +70,7 @@ def test_compare_object_dtype(self, box_with_array, other_box):
6870
pi = pd.period_range("2000", periods=5)
6971
parr = tm.box_expected(pi, box_with_array)
7072

71-
xbox = np.ndarray if box_with_array is pd.Index else box_with_array
73+
xbox = np.ndarray if box_with_array in [pd.Index, pd.array] else box_with_array
7274

7375
other = other_box(pi)
7476

@@ -175,7 +177,9 @@ def test_pi_cmp_period(self):
175177

176178
# TODO: moved from test_datetime64; de-duplicate with version below
177179
def test_parr_cmp_period_scalar2(self, box_with_array):
178-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
180+
xbox = (
181+
box_with_array if box_with_array not in [pd.Index, pd.array] else np.ndarray
182+
)
179183

180184
pi = pd.period_range("2000-01-01", periods=10, freq="D")
181185

@@ -196,7 +200,7 @@ def test_parr_cmp_period_scalar2(self, box_with_array):
196200
@pytest.mark.parametrize("freq", ["M", "2M", "3M"])
197201
def test_parr_cmp_period_scalar(self, freq, box_with_array):
198202
# GH#13200
199-
xbox = np.ndarray if box_with_array is pd.Index else box_with_array
203+
xbox = np.ndarray if box_with_array in [pd.Index, pd.array] else box_with_array
200204

201205
base = PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq=freq)
202206
base = tm.box_expected(base, box_with_array)
@@ -235,7 +239,7 @@ def test_parr_cmp_period_scalar(self, freq, box_with_array):
235239
@pytest.mark.parametrize("freq", ["M", "2M", "3M"])
236240
def test_parr_cmp_pi(self, freq, box_with_array):
237241
# GH#13200
238-
xbox = np.ndarray if box_with_array is pd.Index else box_with_array
242+
xbox = np.ndarray if box_with_array in [pd.Index, pd.array] else box_with_array
239243

240244
base = PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq=freq)
241245
base = tm.box_expected(base, box_with_array)
@@ -284,7 +288,7 @@ def test_parr_cmp_pi_mismatched_freq_raises(self, freq, box_with_array):
284288
# TODO: Could parametrize over boxes for idx?
285289
idx = PeriodIndex(["2011", "2012", "2013", "2014"], freq="A")
286290
rev_msg = r"Input has different freq=(M|2M|3M) from PeriodArray\(freq=A-DEC\)"
287-
idx_msg = rev_msg if box_with_array is tm.to_array else msg
291+
idx_msg = rev_msg if box_with_array in [tm.to_array, pd.array] else msg
288292
with pytest.raises(IncompatibleFrequency, match=idx_msg):
289293
base <= idx
290294

@@ -298,7 +302,7 @@ def test_parr_cmp_pi_mismatched_freq_raises(self, freq, box_with_array):
298302

299303
idx = PeriodIndex(["2011", "2012", "2013", "2014"], freq="4M")
300304
rev_msg = r"Input has different freq=(M|2M|3M) from PeriodArray\(freq=4M\)"
301-
idx_msg = rev_msg if box_with_array is tm.to_array else msg
305+
idx_msg = rev_msg if box_with_array in [tm.to_array, pd.array] else msg
302306
with pytest.raises(IncompatibleFrequency, match=idx_msg):
303307
base <= idx
304308

@@ -779,7 +783,7 @@ def test_pi_add_sub_td64_array_tick(self):
779783
@pytest.mark.parametrize("tdi_freq", [None, "H"])
780784
def test_parr_sub_td64array(self, box_with_array, tdi_freq, pi_freq):
781785
box = box_with_array
782-
xbox = box if box is not tm.to_array else pd.Index
786+
xbox = box if box not in [pd.array, tm.to_array] else pd.Index
783787

784788
tdi = TimedeltaIndex(["1 hours", "2 hours"], freq=tdi_freq)
785789
dti = Timestamp("2018-03-07 17:16:40") + tdi

0 commit comments

Comments
 (0)