Skip to content

Commit ee8967f

Browse files
authored
REF: move towards making _apply_blockwise actually block-wise (#35730)
* REF: move towards making _apply_blockwise actually block-wise
* mypy fixup
* mypy fixup
* Series->_constructor
* dummy commit to force CI
1 parent 23b1717 commit ee8967f

File tree

1 file changed

+52
-23
lines changed

1 file changed

+52
-23
lines changed

pandas/core/window/rolling.py

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,23 @@
66
from functools import partial
77
import inspect
88
from textwrap import dedent
9-
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
9+
from typing import (
10+
TYPE_CHECKING,
11+
Callable,
12+
Dict,
13+
List,
14+
Optional,
15+
Set,
16+
Tuple,
17+
Type,
18+
Union,
19+
)
1020

1121
import numpy as np
1222

1323
from pandas._libs.tslibs import BaseOffset, to_offset
1424
import pandas._libs.window.aggregations as window_aggregations
15-
from pandas._typing import ArrayLike, Axis, FrameOrSeries, Label
25+
from pandas._typing import ArrayLike, Axis, FrameOrSeriesUnion, Label
1626
from pandas.compat._optional import import_optional_dependency
1727
from pandas.compat.numpy import function as nv
1828
from pandas.util._decorators import Appender, Substitution, cache_readonly, doc
@@ -55,6 +65,9 @@
5565
)
5666
from pandas.core.window.numba_ import generate_numba_apply_func
5767

68+
if TYPE_CHECKING:
69+
from pandas import Series
70+
5871

5972
def calculate_center_offset(window) -> int:
6073
"""
@@ -145,7 +158,7 @@ class _Window(PandasObject, ShallowMixin, SelectionMixin):
145158

146159
def __init__(
147160
self,
148-
obj: FrameOrSeries,
161+
obj: FrameOrSeriesUnion,
149162
window=None,
150163
min_periods: Optional[int] = None,
151164
center: bool = False,
@@ -219,7 +232,7 @@ def _validate_get_window_bounds_signature(window: BaseIndexer) -> None:
219232
f"get_window_bounds"
220233
)
221234

222-
def _create_blocks(self, obj: FrameOrSeries):
235+
def _create_blocks(self, obj: FrameOrSeriesUnion):
223236
"""
224237
Split data into blocks & return conformed data.
225238
"""
@@ -381,7 +394,7 @@ def _wrap_result(self, result, block=None, obj=None):
381394
return type(obj)(result, index=index, columns=block.columns)
382395
return result
383396

384-
def _wrap_results(self, results, obj, skipped: List[int]) -> FrameOrSeries:
397+
def _wrap_results(self, results, obj, skipped: List[int]) -> FrameOrSeriesUnion:
385398
"""
386399
Wrap the results.
387400
@@ -394,22 +407,23 @@ def _wrap_results(self, results, obj, skipped: List[int]) -> FrameOrSeries:
394407
"""
395408
from pandas import Series, concat
396409

410+
if obj.ndim == 1:
411+
if not results:
412+
raise DataError("No numeric types to aggregate")
413+
assert len(results) == 1
414+
return Series(results[0], index=obj.index, name=obj.name)
415+
397416
exclude: List[Label] = []
398-
if obj.ndim == 2:
399-
orig_blocks = list(obj._to_dict_of_blocks(copy=False).values())
400-
for i in skipped:
401-
exclude.extend(orig_blocks[i].columns)
402-
else:
403-
orig_blocks = [obj]
417+
orig_blocks = list(obj._to_dict_of_blocks(copy=False).values())
418+
for i in skipped:
419+
exclude.extend(orig_blocks[i].columns)
404420

405421
kept_blocks = [blk for i, blk in enumerate(orig_blocks) if i not in skipped]
406422

407423
final = []
408424
for result, block in zip(results, kept_blocks):
409425

410-
result = self._wrap_result(result, block=block, obj=obj)
411-
if result.ndim == 1:
412-
return result
426+
result = type(obj)(result, index=obj.index, columns=block.columns)
413427
final.append(result)
414428

415429
exclude = exclude or []
@@ -488,13 +502,31 @@ def _get_window_indexer(self, window: int) -> BaseIndexer:
488502
return VariableWindowIndexer(index_array=self._on.asi8, window_size=window)
489503
return FixedWindowIndexer(window_size=window)
490504

505+
def _apply_series(self, homogeneous_func: Callable[..., ArrayLike]) -> "Series":
506+
"""
507+
Series version of _apply_blockwise
508+
"""
509+
_, obj = self._create_blocks(self._selected_obj)
510+
values = obj.values
511+
512+
try:
513+
values = self._prep_values(obj.values)
514+
except (TypeError, NotImplementedError) as err:
515+
raise DataError("No numeric types to aggregate") from err
516+
517+
result = homogeneous_func(values)
518+
return obj._constructor(result, index=obj.index, name=obj.name)
519+
491520
def _apply_blockwise(
492521
self, homogeneous_func: Callable[..., ArrayLike]
493-
) -> FrameOrSeries:
522+
) -> FrameOrSeriesUnion:
494523
"""
495524
Apply the given function to the DataFrame broken down into homogeneous
496525
sub-frames.
497526
"""
527+
if self._selected_obj.ndim == 1:
528+
return self._apply_series(homogeneous_func)
529+
498530
# This isn't quite blockwise, since `blocks` is actually a collection
499531
# of homogeneous DataFrames.
500532
blocks, obj = self._create_blocks(self._selected_obj)
@@ -505,12 +537,9 @@ def _apply_blockwise(
505537
try:
506538
values = self._prep_values(b.values)
507539

508-
except (TypeError, NotImplementedError) as err:
509-
if isinstance(obj, ABCDataFrame):
510-
skipped.append(i)
511-
continue
512-
else:
513-
raise DataError("No numeric types to aggregate") from err
540+
except (TypeError, NotImplementedError):
541+
skipped.append(i)
542+
continue
514543

515544
result = homogeneous_func(values)
516545
results.append(result)
@@ -2234,7 +2263,7 @@ def _apply(
22342263
def _constructor(self):
22352264
return Rolling
22362265

2237-
def _create_blocks(self, obj: FrameOrSeries):
2266+
def _create_blocks(self, obj: FrameOrSeriesUnion):
22382267
"""
22392268
Split data into blocks & return conformed data.
22402269
"""
@@ -2275,7 +2304,7 @@ def _get_window_indexer(self, window: int) -> GroupbyRollingIndexer:
22752304
if isinstance(self.window, BaseIndexer):
22762305
rolling_indexer = type(self.window)
22772306
indexer_kwargs = self.window.__dict__
2278-
assert isinstance(indexer_kwargs, dict)
2307+
assert isinstance(indexer_kwargs, dict) # for mypy
22792308
# We'll be using the index of each group later
22802309
indexer_kwargs.pop("index_array", None)
22812310
elif self.is_freq_type:

0 commit comments

Comments
 (0)