@@ -143,12 +143,6 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
143143 'DatetimeIndex' )
144144 method = 'values'
145145
146- def _interp_limit (invalid , fw_limit , bw_limit ):
147- "Get idx of values that won't be filled b/c they exceed the limits."
148- for x in np .where (invalid )[0 ]:
149- if invalid [max (0 , x - fw_limit ):x + bw_limit + 1 ].all ():
150- yield x
151-
152146 valid_limit_directions = ['forward' , 'backward' , 'both' ]
153147 limit_direction = limit_direction .lower ()
154148 if limit_direction not in valid_limit_directions :
@@ -180,21 +174,29 @@ def _interp_limit(invalid, fw_limit, bw_limit):
180174
181175 # default limit is unlimited GH #16282
182176 if limit is None :
183- limit = len (xvalues )
177+ # limit = len(xvalues)
178+ pass
184179 elif not is_integer (limit ):
185180 raise ValueError ('Limit must be an integer' )
186181 elif limit < 1 :
187182 raise ValueError ('Limit must be greater than 0' )
188183
189184 # each possible limit_direction
190- if limit_direction == 'forward' :
185+ # TODO: do we need sorted?
186+ if limit_direction == 'forward' and limit is not None :
191187 violate_limit = sorted (start_nans |
192188 set (_interp_limit (invalid , limit , 0 )))
193- elif limit_direction == 'backward' :
189+ elif limit_direction == 'forward' :
190+ violate_limit = sorted (start_nans )
191+ elif limit_direction == 'backward' and limit is not None :
194192 violate_limit = sorted (end_nans |
195193 set (_interp_limit (invalid , 0 , limit )))
196- elif limit_direction == 'both' :
194+ elif limit_direction == 'backward' :
195+ violate_limit = sorted (end_nans )
196+ elif limit_direction == 'both' and limit is not None :
197197 violate_limit = sorted (_interp_limit (invalid , limit , limit ))
198+ else :
199+ violate_limit = []
198200
199201 xvalues = getattr (xvalues , 'values' , xvalues )
200202 yvalues = getattr (yvalues , 'values' , yvalues )
@@ -630,3 +632,58 @@ def fill_zeros(result, x, y, name, fill):
630632 result = result .reshape (shape )
631633
632634 return result
635+
636+
637+ def _interp_limit (invalid , fw_limit , bw_limit ):
638+ """Get idx of values that won't be filled b/c they exceed the limits.
639+
640+ This is equivalent to the more readable, but slower
641+
642+ .. code-block:: python
643+
644+ for x in np.where(invalid)[0]:
645+ if invalid[max(0, x - fw_limit):x + bw_limit + 1].all():
646+ yield x
647+ """
648+ # handle forward first; the backward direction is the same except
649+ # 1. operate on the reversed array
650+ # 2. subtract the returned indicies from N - 1
651+ N = len (invalid )
652+
653+ def inner (invalid , limit ):
654+ limit = min (limit , N )
655+ windowed = _rolling_window (invalid , limit + 1 ).all (1 )
656+ idx = (set (np .where (windowed )[0 ] + limit ) |
657+ set (np .where ((~ invalid [:limit + 1 ]).cumsum () == 0 )[0 ]))
658+ return idx
659+
660+ if fw_limit == 0 :
661+ f_idx = set (np .where (invalid )[0 ])
662+ else :
663+ f_idx = inner (invalid , fw_limit )
664+
665+ if bw_limit == 0 :
666+ # then we don't even need to care about backwards, just use forwards
667+ return f_idx
668+ else :
669+ b_idx = set (N - 1 - np .asarray (list (inner (invalid [::- 1 ], bw_limit ))))
670+ if fw_limit == 0 :
671+ return b_idx
672+ return f_idx & b_idx
673+
674+
675+ def _rolling_window (a , window ):
676+ """
677+ [True, True, False, True, False], 2 ->
678+
679+ [
680+ [True, True],
681+ [True, False],
682+ [False, True],
683+ [True, False],
684+ ]
685+ """
686+ # https://stackoverflow.com/a/6811241
687+ shape = a .shape [:- 1 ] + (a .shape [- 1 ] - window + 1 , window )
688+ strides = a .strides + (a .strides [- 1 ],)
689+ return np .lib .stride_tricks .as_strided (a , shape = shape , strides = strides )
0 commit comments