1+ from __future__ import annotations
2+
13from distutils .version import LooseVersion
2- from typing import TYPE_CHECKING , Generic , Hashable , Mapping , TypeVar , Union
4+ from typing import Generic , Hashable , Mapping , Union
35
46import numpy as np
57
68from .options import _get_keep_attrs
79from .pdcompat import count_not_none
810from .pycompat import is_duck_dask_array
9-
10- if TYPE_CHECKING :
11- from .dataarray import DataArray # noqa: F401
12- from .dataset import Dataset # noqa: F401
13-
14- T_DSorDA = TypeVar ("T_DSorDA" , "DataArray" , "Dataset" )
11+ from .types import T_Xarray
1512
1613
1714def _get_alpha (com = None , span = None , halflife = None , alpha = None ):
@@ -79,7 +76,7 @@ def _get_center_of_mass(comass, span, halflife, alpha):
7976 return float (comass )
8077
8178
82- class RollingExp (Generic [T_DSorDA ]):
79+ class RollingExp (Generic [T_Xarray ]):
8380 """
8481 Exponentially-weighted moving window object.
8582 Similar to EWM in pandas
@@ -103,16 +100,16 @@ class RollingExp(Generic[T_DSorDA]):
103100
104101 def __init__ (
105102 self ,
106- obj : T_DSorDA ,
103+ obj : T_Xarray ,
107104 windows : Mapping [Hashable , Union [int , float ]],
108105 window_type : str = "span" ,
109106 ):
110- self .obj : T_DSorDA = obj
107+ self .obj : T_Xarray = obj
111108 dim , window = next (iter (windows .items ()))
112109 self .dim = dim
113110 self .alpha = _get_alpha (** {window_type : window })
114111
115- def mean (self , keep_attrs : bool = None ) -> T_DSorDA :
112+ def mean (self , keep_attrs : bool = None ) -> T_Xarray :
116113 """
117114 Exponentially weighted moving average.
118115
@@ -139,7 +136,7 @@ def mean(self, keep_attrs: bool = None) -> T_DSorDA:
139136 move_exp_nanmean , dim = self .dim , alpha = self .alpha , keep_attrs = keep_attrs
140137 )
141138
142- def sum (self , keep_attrs : bool = None ) -> T_DSorDA :
139+ def sum (self , keep_attrs : bool = None ) -> T_Xarray :
143140 """
144141 Exponentially weighted moving sum.
145142
0 commit comments