Skip to content

Commit 0711eb0

Browse files
committed
bugfix.
1 parent 4ee2963 commit 0711eb0

File tree

4 files changed

+83
-57
lines changed

4 files changed

+83
-57
lines changed

xarray/core/concat.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -197,34 +197,36 @@ def process_subset_opt(opt, subset):
197197
equals[k] = getattr(variables[0], compat)(
198198
var, equiv=lazy_array_equiv
199199
)
200-
if not equals[k]:
200+
if equals[k] is not True:
201+
# exit early if we know these are not equal or that
202+
# equality cannot be determined i.e. one or all of
203+
# the variables wraps a numpy array
201204
break
202205

203-
if equals[k] is not None:
204-
if equals[k] is False:
205-
concat_over.add(k)
206-
continue
207-
208-
# Compare the variable of all datasets vs. the one
209-
# of the first dataset. Perform the minimum amount of
210-
# loads in order to avoid multiple loads from disk
211-
# while keeping the RAM footprint low.
212-
v_lhs = datasets[0].variables[k].load()
213-
# We'll need to know later on if variables are equal.
214-
computed = []
215-
for ds_rhs in datasets[1:]:
216-
v_rhs = ds_rhs.variables[k].compute()
217-
computed.append(v_rhs)
218-
if not getattr(v_lhs, compat)(v_rhs):
219-
concat_over.add(k)
220-
equals[k] = False
221-
# computed variables are not to be re-computed
222-
# again in the future
223-
for ds, v in zip(datasets[1:], computed):
224-
ds.variables[k].data = v.data
225-
break
226-
else:
227-
equals[k] = True
206+
if equals[k] is False:
207+
concat_over.add(k)
208+
209+
elif equals[k] is None:
210+
# Compare the variable of all datasets vs. the one
211+
# of the first dataset. Perform the minimum amount of
212+
# loads in order to avoid multiple loads from disk
213+
# while keeping the RAM footprint low.
214+
v_lhs = datasets[0].variables[k].load()
215+
# We'll need to know later on if variables are equal.
216+
computed = []
217+
for ds_rhs in datasets[1:]:
218+
v_rhs = ds_rhs.variables[k].compute()
219+
computed.append(v_rhs)
220+
if not getattr(v_lhs, compat)(v_rhs):
221+
concat_over.add(k)
222+
equals[k] = False
223+
# computed variables are not to be re-computed
224+
# again in the future
225+
for ds, v in zip(datasets[1:], computed):
226+
ds.variables[k].data = v.data
227+
break
228+
else:
229+
equals[k] = True
228230

229231
elif opt == "all":
230232
concat_over.update(

xarray/core/duck_array_ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ def as_shared_dtype(scalars_or_arrays):
176176

177177
def lazy_array_equiv(arr1, arr2):
178178
"""Like array_equal, but doesn't actually compare values.
179-
Returns True or False when equality can be determined without computing.
180-
Returns None when equality cannot determined (e.g. one or both of arr1, arr2 are numpy arrays)
179+
Returns True when arr1, arr2 identical or their dask names are equal.
180+
Returns False when shapes are not equal.
181+
Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays;
182+
or their dask names are not equal
181183
"""
182184
if arr1 is arr2:
183185
return True
@@ -193,6 +195,8 @@ def lazy_array_equiv(arr1, arr2):
193195
# GH3068
194196
if arr1.name == arr2.name:
195197
return True
198+
else:
199+
return None
196200
return None
197201

198202

xarray/core/merge.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,11 @@ def unique_variable(
127127
# first check without comparing values i.e. no computes
128128
for var in variables[1:]:
129129
equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
130-
if not equals:
130+
if equals is not True:
131131
break
132132

133-
# now compare values with minimum number of computes
134-
if not equals:
133+
if equals is None:
134+
# now compare values with minimum number of computes
135135
out = out.compute()
136136
for var in variables[1:]:
137137
equals = getattr(out, compat)(var)

xarray/tests/test_dask.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,32 +1202,6 @@ def test_identical_coords_no_computes():
12021202
assert_identical(c, a)
12031203

12041204

1205-
def test_lazy_array_equiv():
1206-
lons1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1207-
lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1208-
var1 = lons1.variable
1209-
var2 = lons2.variable
1210-
with raise_if_dask_computes():
1211-
lons1.equals(lons2)
1212-
with raise_if_dask_computes():
1213-
var1.equals(var2 / 2, equiv=lazy_array_equiv)
1214-
assert var1.equals(var2.compute(), equiv=lazy_array_equiv) is None
1215-
assert var1.compute().equals(var2.compute(), equiv=lazy_array_equiv) is None
1216-
1217-
with raise_if_dask_computes():
1218-
assert lons1.equals(lons1.transpose("y", "x"))
1219-
1220-
with raise_if_dask_computes():
1221-
for compat in [
1222-
"broadcast_equals",
1223-
"equals",
1224-
"override",
1225-
"identical",
1226-
"no_conflicts",
1227-
]:
1228-
xr.merge([lons1, lons2], compat=compat)
1229-
1230-
12311205
@pytest.mark.parametrize(
12321206
"obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()]
12331207
)
@@ -1315,3 +1289,49 @@ def test_normalize_token_with_backend(map_ds):
13151289
map_ds.to_netcdf(tmp_file)
13161290
read = xr.open_dataset(tmp_file)
13171291
assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)
1292+
1293+
1294+
@pytest.mark.parametrize(
1295+
"compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
1296+
)
1297+
def test_lazy_array_equiv_variables(compat):
1298+
var1 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
1299+
var2 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
1300+
var3 = xr.Variable(("y", "x"), da.zeros((20, 10), chunks=2))
1301+
1302+
with raise_if_dask_computes():
1303+
assert getattr(var1, compat)(var2, equiv=lazy_array_equiv)
1304+
# values are actually equal, but we don't know that till we compute, return None
1305+
with raise_if_dask_computes():
1306+
assert getattr(var1, compat)(var2 / 2, equiv=lazy_array_equiv) is None
1307+
1308+
# shapes are not equal, return False without computes
1309+
with raise_if_dask_computes():
1310+
assert getattr(var1, compat)(var3, equiv=lazy_array_equiv) is False
1311+
1312+
# if one or both arrays are numpy, return None
1313+
assert getattr(var1, compat)(var2.compute(), equiv=lazy_array_equiv) is None
1314+
assert (
1315+
getattr(var1.compute(), compat)(var2.compute(), equiv=lazy_array_equiv) is None
1316+
)
1317+
1318+
with raise_if_dask_computes():
1319+
assert getattr(var1, compat)(var2.transpose("y", "x"))
1320+
1321+
1322+
@pytest.mark.parametrize(
1323+
"compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
1324+
)
1325+
def test_lazy_array_equiv_merge(compat):
1326+
da1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1327+
da2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1328+
da3 = xr.DataArray(da.ones((20, 10), chunks=2), dims=("y", "x"))
1329+
1330+
with raise_if_dask_computes():
1331+
xr.merge([da1, da2], compat=compat)
1332+
# shapes are not equal; no computes necessary
1333+
with raise_if_dask_computes(max_computes=0):
1334+
with pytest.raises(ValueError):
1335+
xr.merge([da1, da3], compat=compat)
1336+
with raise_if_dask_computes(max_computes=2):
1337+
xr.merge([da1, da2 / 2], compat=compat)

0 commit comments

Comments
 (0)