@@ -43,44 +43,52 @@ def f(x, *args):
4343 )
4444 tm .assert_series_equal (result , expected )
4545
46+ @pytest .mark .parametrize (
47+ "data" , [DataFrame (np .eye (5 )), Series (range (5 ), name = "foo" )]
48+ )
4649 def test_numba_vs_cython_rolling_methods (
47- self , nogil , parallel , nopython , arithmetic_numba_supported_operators
50+ self , data , nogil , parallel , nopython , arithmetic_numba_supported_operators
4851 ):
4952
5053 method = arithmetic_numba_supported_operators
5154
5255 engine_kwargs = {"nogil" : nogil , "parallel" : parallel , "nopython" : nopython }
5356
54- df = DataFrame (np .eye (5 ))
55- roll = df .rolling (2 )
57+ roll = data .rolling (2 )
5658 result = getattr (roll , method )(engine = "numba" , engine_kwargs = engine_kwargs )
5759 expected = getattr (roll , method )(engine = "cython" )
5860
5961 # Check the cache
60- assert (getattr (np , f"nan{ method } " ), "Rolling_apply_single" ) in NUMBA_FUNC_CACHE
62+ if method != "mean" :
63+ assert (
64+ getattr (np , f"nan{ method } " ),
65+ "Rolling_apply_single" ,
66+ ) in NUMBA_FUNC_CACHE
6167
62- tm .assert_frame_equal (result , expected )
68+ tm .assert_equal (result , expected )
6369
70+ @pytest .mark .parametrize ("data" , [DataFrame (np .eye (5 )), Series (range (5 ))])
6471 def test_numba_vs_cython_expanding_methods (
65- self , nogil , parallel , nopython , arithmetic_numba_supported_operators
72+ self , data , nogil , parallel , nopython , arithmetic_numba_supported_operators
6673 ):
6774
6875 method = arithmetic_numba_supported_operators
6976
7077 engine_kwargs = {"nogil" : nogil , "parallel" : parallel , "nopython" : nopython }
7178
72- df = DataFrame (np .eye (5 ))
73- expand = df .expanding ()
79+ data = DataFrame (np .eye (5 ))
80+ expand = data .expanding ()
7481 result = getattr (expand , method )(engine = "numba" , engine_kwargs = engine_kwargs )
7582 expected = getattr (expand , method )(engine = "cython" )
7683
7784 # Check the cache
78- assert (
79- getattr (np , f"nan{ method } " ),
80- "Expanding_apply_single" ,
81- ) in NUMBA_FUNC_CACHE
85+ if method != "mean" :
86+ assert (
87+ getattr (np , f"nan{ method } " ),
88+ "Expanding_apply_single" ,
89+ ) in NUMBA_FUNC_CACHE
8290
83- tm .assert_frame_equal (result , expected )
91+ tm .assert_equal (result , expected )
8492
8593 @pytest .mark .parametrize ("jit" , [True , False ])
8694 def test_cache_apply (self , jit , nogil , parallel , nopython ):
0 commit comments