Skip to content

Commit d2de2c7

Browse files
authored
Use backend in ds fixture (#5411)
* Promote backend test fixture to conftest Also adds an example of parameterizing a test in dataset.py * Add backend to ds fixture * Fix tests The result is less elegant, hopefully temporarily
1 parent 5a14d7d commit d2de2c7

File tree

1 file changed

+36
-29
lines changed

1 file changed

+36
-29
lines changed

xarray/tests/test_dataset.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6078,11 +6078,6 @@ def test_query(self, backend, engine, parser):
60786078
# pytest tests — new tests should go here, rather than in the class.
60796079

60806080

6081-
@pytest.fixture(params=[None])
6082-
def data_set(request):
6083-
return create_test_data(request.param)
6084-
6085-
60866081
@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2])))
60876082
def test_isin(test_elements, backend):
60886083
expected = Dataset(
@@ -6153,17 +6148,18 @@ def test_constructor_raises_with_invalid_coords(unaligned_coords):
61536148
xr.DataArray([1, 2, 3], dims=["x"], coords=unaligned_coords)
61546149

61556150

6156-
def test_dir_expected_attrs(data_set):
6151+
@pytest.mark.parametrize("ds", [3], indirect=True)
6152+
def test_dir_expected_attrs(ds):
61576153

61586154
some_expected_attrs = {"pipe", "mean", "isnull", "var1", "dim2", "numbers"}
6159-
result = dir(data_set)
6155+
result = dir(ds)
61606156
assert set(result) >= some_expected_attrs
61616157

61626158

6163-
def test_dir_non_string(data_set):
6159+
def test_dir_non_string(ds):
61646160
# add a numbered key to ensure this doesn't break dir
6165-
data_set[5] = "foo"
6166-
result = dir(data_set)
6161+
ds[5] = "foo"
6162+
result = dir(ds)
61676163
assert 5 not in result
61686164

61696165
# GH2172
@@ -6173,16 +6169,16 @@ def test_dir_non_string(data_set):
61736169
dir(x2)
61746170

61756171

6176-
def test_dir_unicode(data_set):
6177-
data_set["unicode"] = "uni"
6178-
result = dir(data_set)
6172+
def test_dir_unicode(ds):
6173+
ds["unicode"] = "uni"
6174+
result = dir(ds)
61796175
assert "unicode" in result
61806176

61816177

61826178
@pytest.fixture(params=[1])
6183-
def ds(request):
6179+
def ds(request, backend):
61846180
if request.param == 1:
6185-
return Dataset(
6181+
ds = Dataset(
61866182
dict(
61876183
z1=(["y", "x"], np.random.randn(2, 8)),
61886184
z2=(["time", "y"], np.random.randn(10, 2)),
@@ -6194,21 +6190,29 @@ def ds(request):
61946190
y=range(2),
61956191
),
61966192
)
6197-
6198-
if request.param == 2:
6199-
return Dataset(
6200-
{
6201-
"z1": (["time", "y"], np.random.randn(10, 2)),
6202-
"z2": (["time"], np.random.randn(10)),
6203-
"z3": (["x", "time"], np.random.randn(8, 10)),
6204-
},
6205-
{
6206-
"x": ("x", np.linspace(0, 1.0, 8)),
6207-
"time": ("time", np.linspace(0, 1.0, 10)),
6208-
"c": ("y", ["a", "b"]),
6209-
"y": range(2),
6210-
},
6193+
elif request.param == 2:
6194+
ds = Dataset(
6195+
dict(
6196+
z1=(["time", "y"], np.random.randn(10, 2)),
6197+
z2=(["time"], np.random.randn(10)),
6198+
z3=(["x", "time"], np.random.randn(8, 10)),
6199+
),
6200+
dict(
6201+
x=("x", np.linspace(0, 1.0, 8)),
6202+
time=("time", np.linspace(0, 1.0, 10)),
6203+
c=("y", ["a", "b"]),
6204+
y=range(2),
6205+
),
62116206
)
6207+
elif request.param == 3:
6208+
ds = create_test_data()
6209+
else:
6210+
raise ValueError
6211+
6212+
if backend == "dask":
6213+
return ds.chunk()
6214+
6215+
return ds
62126216

62136217

62146218
def test_coarsen_absent_dims_error(ds):
@@ -6526,6 +6530,7 @@ def test_rolling_properties(ds):
65266530
@pytest.mark.parametrize("center", (True, False, None))
65276531
@pytest.mark.parametrize("min_periods", (1, None))
65286532
@pytest.mark.parametrize("key", ("z1", "z2"))
6533+
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
65296534
def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key):
65306535
bn = pytest.importorskip("bottleneck", minversion="1.1")
65316536

@@ -6551,13 +6556,15 @@ def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key):
65516556

65526557

65536558
@requires_numbagg
6559+
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
65546560
def test_rolling_exp(ds):
65556561

65566562
result = ds.rolling_exp(time=10, window_type="span").mean()
65576563
assert isinstance(result, Dataset)
65586564

65596565

65606566
@requires_numbagg
6567+
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
65616568
def test_rolling_exp_keep_attrs(ds):
65626569

65636570
attrs_global = {"attrs": "global"}

0 commit comments

Comments
 (0)