
Commit 3be09bc

Fix apply to only call func once on the first column/row
1 parent 911e19b commit 3be09bc

3 files changed: +51 −23 lines
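Context for the change: the fast Cython reduction path evaluates the function on the first row/column to decide whether the result reduces; before this commit, the pure-Python fallback then evaluated it again on that same element, so side effects ran twice there. A minimal sketch of the intended behaviour, assuming a pandas build that includes this fix (the `record` helper and its counter are illustrative only, not part of the commit):

import pandas as pd

calls = []

def record(row):
    # side effect: remember every row the function is called on
    calls.append(row.name)
    return [1, 2, 3]  # non-reducing result, so apply falls back to the Python path

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
df.apply(record, axis=1)

# With this fix each row should be visited exactly once; previously the first
# row could appear twice because the fallback re-ran the function on it.
print(calls)  # expected: [0, 1]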

pandas/_libs/reduction.pyx

Lines changed: 17 additions & 4 deletions
@@ -107,6 +107,7 @@ cdef class Reducer:
 
         result = np.empty(self.nresults, dtype='O')
         it = <flatiter>PyArray_IterNew(result)
+        partial_result = None
 
         try:
             for i in range(self.nresults):
@@ -134,21 +135,33 @@
                 res = self.f(chunk)
 
                 # TODO: reason for not squeezing here?
-                res = _extract_result(res, squeeze=False)
+                extracted_res = _extract_result(res, squeeze=False)
                 if i == 0:
                     # On the first pass, we check the output shape to see
                     # if this looks like a reduction.
-                    _check_result_array(res, len(self.dummy))
+                    # if it does not, return the computed value to be used by the pure python implementation,
+                    # so the function won't be called twice on the same object (and side effects would occur twice)
+                    try:
+                        _check_result_array(extracted_res, len(self.dummy))
+                    except ValueError as err:
+                        if "Function does not reduce" not in str(err):
+                            # catch only the specific exception
+                            raise
 
-                PyArray_SETITEM(result, PyArray_ITER_DATA(it), res)
+                        partial_result = copy(res)
+                        break
+
+
+                PyArray_SETITEM(result, PyArray_ITER_DATA(it), extracted_res)
                 chunk.data = chunk.data + self.increment
                 PyArray_ITER_NEXT(it)
+
         finally:
             # so we don't free the wrong memory
             chunk.data = dummy_buf
 
         result = maybe_convert_objects(result)
-        return result
+        return result, partial_result
 
 
 cdef class _BaseGrouper:
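With this change, Reducer.get_result (and therefore libreduction.compute_reduction) returns a (result, partial_result) pair instead of raising when the first evaluation turns out not to be a reduction. The following is a rough pure-Python sketch of that control flow, not the Cython code itself; reduce_chunks, func and check_result_array are illustrative stand-ins for the method, self.f and _check_result_array:

from copy import copy

def reduce_chunks(chunks, func, check_result_array):
    # Sketch only: reduce each chunk, but if the first result shows this is
    # not a reduction, keep that value instead of discarding it so the
    # caller's fallback does not have to call func on chunk 0 again.
    results = []
    partial_result = None
    for i, chunk in enumerate(chunks):
        res = func(chunk)
        if i == 0:
            try:
                check_result_array(res)
            except ValueError as err:
                if "Function does not reduce" not in str(err):
                    raise
                partial_result = copy(res)
                break
        results.append(res)
    return results, partial_result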

pandas/core/apply.py

Lines changed: 31 additions & 18 deletions
@@ -18,6 +18,9 @@
 
 from pandas.core.construction import create_series_with_explicit_dtype
 
+from pandas.core.series import Series
+from pandas import DataFrame
+
 if TYPE_CHECKING:
     from pandas import DataFrame, Series, Index
 
@@ -220,14 +223,13 @@ def apply_empty_result(self):
 
     def apply_raw(self):
         """ apply to the values as a numpy array """
-        try:
-            result = libreduction.compute_reduction(self.values, self.f, axis=self.axis)
-        except ValueError as err:
-            if "Function does not reduce" not in str(err):
-                # catch only ValueError raised intentionally in libreduction
-                raise
-            # We expect np.apply_along_axis to give a two-dimensional result, or
-            # also raise.
+        result, partial_result = libreduction.compute_reduction(
+            self.values, self.f, axis=self.axis
+        )
+
+        # A non None partial_result means that the reduction was unsuccessful
+        # We expect np.apply_along_axis to give a two-dimensional result, or raise.
+        if partial_result is not None:
             result = np.apply_along_axis(self.f, self.axis, self.values)
 
         # TODO: mixed type case
@@ -265,6 +267,7 @@ def apply_broadcast(self, target: "DataFrame") -> "DataFrame":
 
     def apply_standard(self):
 
+        partial_result = None
         # try to reduce first (by default)
         # this only matters if the reduction in values is of different dtype
         # e.g. if we want to apply to a SparseFrame, then can't directly reduce
@@ -292,13 +295,9 @@
             )
 
             try:
-                result = libreduction.compute_reduction(
+                result, partial_result = libreduction.compute_reduction(
                     values, self.f, axis=self.axis, dummy=dummy, labels=labels
                 )
-            except ValueError as err:
-                if "Function does not reduce" not in str(err):
-                    # catch only ValueError raised intentionally in libreduction
-                    raise
             except TypeError:
                 # e.g. test_apply_ignore_failures we just ignore
                 if not self.ignore_failures:
@@ -307,23 +306,36 @@
                 # reached via numexpr; fall back to python implementation
                 pass
             else:
-                return self.obj._constructor_sliced(result, index=labels)
+                # this means that the reduction was successful
+                if partial_result is None:
+                    return self.obj._constructor_sliced(result, index=labels)
+                else:
+                    if isinstance(partial_result, Series):
+                        partial_result = DataFrame.infer_objects(partial_result)
 
         # compute the result using the series generator
-        results, res_index = self.apply_series_generator()
+        results, res_index = self.apply_series_generator(partial_result)
 
         # wrap results
        return self.wrap_results(results, res_index)
 
-    def apply_series_generator(self) -> Tuple[ResType, "Index"]:
+    def apply_series_generator(self, partial_result=None) -> Tuple[ResType, "Index"]:
         series_gen = self.series_generator
         res_index = self.result_index
 
         keys = []
         results = {}
+
+        # If a partial result was already computed, use it instead of running on the first element again
+        series_gen_enumeration = enumerate(series_gen)
+        if partial_result is not None:
+            i, v = next(series_gen_enumeration)
+            results[i] = partial_result
+            keys.append(v.name)
+
         if self.ignore_failures:
             successes = []
-            for i, v in enumerate(series_gen):
+            for i, v in series_gen_enumeration:
                 try:
                     results[i] = self.f(v)
                 except Exception:
@@ -337,7 +349,8 @@ def apply_series_generator(self) -> Tuple[ResType, "Index"]:
             res_index = res_index.take(successes)
 
         else:
-            for i, v in enumerate(series_gen):
+            for i, v in series_gen_enumeration:
+
                 results[i] = self.f(v)
                 keys.append(v.name)
 
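A minimal sketch (names are illustrative, not pandas internals) of how the seeded enumeration in apply_series_generator avoids re-running the function on the first element: the precomputed partial_result fills slot 0, and the same enumerate object is reused so the loop continues from element 1.

def run_series_generator(series_gen, func, partial_result=None):
    # Mirrors the non-ignore_failures branch above in plain Python.
    results = {}
    keys = []
    gen = enumerate(series_gen)
    if partial_result is not None:
        # consume element 0 without calling func on it again
        i, v = next(gen)
        results[i] = partial_result
        keys.append(v.name)
    for i, v in gen:
        results[i] = func(v)
        keys.append(v.name)
    return results, keys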

pandas/tests/frame/test_apply.py

Lines changed: 3 additions & 1 deletion
@@ -718,7 +718,9 @@ def apply_list(row):
 
     def test_apply_noreduction_tzaware_object(self):
         # https://github.com/pandas-dev/pandas/issues/31505
-        df = pd.DataFrame({"foo": [pd.Timestamp("2020", tz="UTC")]}, dtype="object")
+        df = pd.DataFrame(
+            {"foo": [pd.Timestamp("2020", tz="UTC")]}, dtype="datetime64[ns, UTC]"
+        )
         result = df.apply(lambda x: x)
         tm.assert_frame_equal(result, df)
         result = df.apply(lambda x: x.copy())
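For reference, the updated regression test exercises the non-reducing, tz-aware case end to end; run interactively it amounts to roughly the following (assuming the pandas._testing helpers used by the test suite):

import pandas as pd
import pandas._testing as tm

df = pd.DataFrame(
    {"foo": [pd.Timestamp("2020", tz="UTC")]}, dtype="datetime64[ns, UTC]"
)
result = df.apply(lambda x: x)     # identity apply; no reduction takes place
tm.assert_frame_equal(result, df)  # the frame should round-trip unchanged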
