Skip to content

Commit 7866587

Browse files
authored
Add join='override' (#3175)
* Add join='override' * Add coords='skip_nondim' * Revert "Add coords='skip_nondim'" This reverts commit 8263d38. * black * black2 * join='override' concat tests. * Add whats-new.rst * Improve error message. * da error message. * Refactor + fix edge cases. * Add da test. * more darray tests. * Update docstrings. * Address review comments.
1 parent e678ec9 commit 7866587

File tree

10 files changed

+133
-3
lines changed

10 files changed

+133
-3
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ New functions/methods
5555
Enhancements
5656
~~~~~~~~~~~~
5757

58+
- Added ``join='override'``. This only checks that index sizes are equal among objects and skips
59+
checking indexes for equality. By `Deepak Cherian <https://github.com/dcherian>`_.
5860
- :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset` now support the ``join`` kwarg.
5961
It is passed down to :py:func:`~xarray.align`. By `Deepak Cherian <https://github.com/dcherian>`_.
6062
- In :py:meth:`~xarray.Dataset.to_zarr`, passing ``mode`` is not mandatory if

xarray/backends/api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ def open_mfdataset(
813813
parallel : bool, optional
814814
If True, the open and preprocess steps of this function will be
815815
performed in parallel using ``dask.delayed``. Default is False.
816-
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
816+
join : {'outer', 'inner', 'left', 'right', 'exact', 'override'}, optional
817817
String indicating how to combine differing indexes
818818
(excluding concat_dim) in objects
819819
@@ -823,6 +823,9 @@ def open_mfdataset(
823823
- 'right': use indexes from the last object with each dimension
824824
- 'exact': instead of aligning, raise `ValueError` when indexes to be
825825
aligned are not equal
826+
- 'override': if indexes are of same size, rewrite indexes to be
827+
those of the first object with that dimension. Indexes for the same
828+
dimension must have the same size in all objects.
826829
**kwargs : optional
827830
Additional arguments passed on to :py:func:`xarray.open_dataset`.
828831

xarray/core/alignment.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,34 @@ def _get_joiner(join):
3131
# We cannot return a function to "align" in this case, because it needs
3232
# access to the dimension name to give a good error message.
3333
return None
34+
elif join == "override":
35+
# We rewrite all indexes and then use join='left'
36+
return operator.itemgetter(0)
3437
else:
3538
raise ValueError("invalid value for join: %s" % join)
3639

3740

41+
def _override_indexes(objects, all_indexes, exclude):
    """Replace the indexes of every object after the first (``join='override'``).

    Parameters
    ----------
    objects : iterable
        Objects being aligned. Each one must expose ``indexes`` (iterable
        over the names of its indexed dimensions) and
        ``_overwrite_indexes(mapping)``.
    all_indexes : mapping
        Maps a dimension name to the list of indexes collected for that
        dimension; entry ``0`` is the one that gets propagated.
    exclude : container
        Dimension names that must be left untouched.

    Returns
    -------
    list
        The first object unchanged, followed by the result of
        ``_overwrite_indexes`` for each remaining object.

    Raises
    ------
    ValueError
        If the indexes collected for a non-excluded dimension do not all
        have the same size.
    """
    # Overriding only makes sense when every index along a dimension has the
    # same length, so validate all dimensions before touching any object.
    for dim, dim_indexes in all_indexes.items():
        if dim not in exclude:
            lengths = {index.size for index in dim_indexes}
            if len(lengths) != 1:
                raise ValueError(
                    "Indexes along dimension %r don't have the same length."
                    " Cannot use join='override'." % dim
                )

    objects = list(objects)
    for position, obj in enumerate(objects[1:], start=1):
        # Walk the object's *indexed* dimensions rather than all of its
        # dimensions: a dimension without an index has nothing to override
        # and may have no entry in ``all_indexes`` at all, so looking up
        # ``all_indexes[dim][0]`` for it would fail.
        # NOTE(review): assumes the caller filled ``all_indexes`` from these
        # same objects, so every indexed, non-excluded dim has an entry.
        new_indexes = {
            dim: all_indexes[dim][0] for dim in obj.indexes if dim not in exclude
        }
        objects[position] = obj._overwrite_indexes(new_indexes)

    return objects
60+
61+
3862
def align(
3963
*objects,
4064
join="inner",
@@ -57,7 +81,7 @@ def align(
5781
----------
5882
*objects : Dataset or DataArray
5983
Objects to align.
60-
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
84+
join : {'outer', 'inner', 'left', 'right', 'exact', 'override'}, optional
6185
Method for joining the indexes of the passed objects along each
6286
dimension:
6387
@@ -67,6 +91,9 @@ def align(
6791
- 'right': use indexes from the last object with each dimension
6892
- 'exact': instead of aligning, raise `ValueError` when indexes to be
6993
aligned are not equal
94+
- 'override': if indexes are of same size, rewrite indexes to be
95+
those of the first object with that dimension. Indexes for the same
96+
dimension must have the same size in all objects.
7097
copy : bool, optional
7198
If ``copy=True``, data in the return values is always copied. If
7299
``copy=False`` and reindexing is unnecessary, or can be performed with
@@ -111,6 +138,9 @@ def align(
111138
else:
112139
all_indexes[dim].append(index)
113140

141+
if join == "override":
142+
objects = _override_indexes(list(objects), all_indexes, exclude)
143+
114144
# We don't reindex over dimensions with all equal indexes for two reasons:
115145
# - It's faster for the usual case (already aligned objects).
116146
# - It ensures it's possible to do operations that don't require alignment

xarray/core/combine.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,9 @@ def combine_nested(
379379
- 'right': use indexes from the last object with each dimension
380380
- 'exact': instead of aligning, raise `ValueError` when indexes to be
381381
aligned are not equal
382+
- 'override': if indexes are of same size, rewrite indexes to be
383+
those of the first object with that dimension. Indexes for the same
384+
dimension must have the same size in all objects.
382385
383386
Returns
384387
-------
@@ -529,6 +532,9 @@ def combine_by_coords(
529532
- 'right': use indexes from the last object with each dimension
530533
- 'exact': instead of aligning, raise `ValueError` when indexes to be
531534
aligned are not equal
535+
- 'override': if indexes are of same size, rewrite indexes to be
536+
those of the first object with that dimension. Indexes for the same
537+
dimension must have the same size in all objects.
532538
533539
Returns
534540
-------
@@ -688,6 +694,9 @@ def auto_combine(
688694
- 'right': use indexes from the last object with each dimension
689695
- 'exact': instead of aligning, raise `ValueError` when indexes to be
690696
aligned are not equal
697+
- 'override': if indexes are of same size, rewrite indexes to be
698+
those of the first object with that dimension. Indexes for the same
699+
dimension must have the same size in all objects.
691700
692701
Returns
693702
-------

xarray/core/concat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def concat(
8585
- 'right': use indexes from the last object with each dimension
8686
- 'exact': instead of aligning, raise `ValueError` when indexes to be
8787
aligned are not equal
88+
- 'override': if indexes are of same size, rewrite indexes to be
89+
those of the first object with that dimension. Indexes for the same
90+
dimension must have the same size in all objects.
8891
8992
indexers, mode, concat_over : deprecated
9093

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def _replace_maybe_drop_dims(
375375
)
376376
return self._replace(variable, coords, name)
377377

378-
def _replace_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
378+
def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
379379
if not len(indexes):
380380
return self
381381
coords = self._coords.copy()

xarray/core/merge.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,9 @@ def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA):
540540
- 'right': use indexes from the last object with each dimension
541541
- 'exact': instead of aligning, raise `ValueError` when indexes to be
542542
aligned are not equal
543+
- 'override': if indexes are of same size, rewrite indexes to be
544+
those of the first object with that dimension. Indexes for the same
545+
dimension must have the same size in all objects.
543546
fill_value : scalar, optional
544547
Value to use for newly missing values
545548

xarray/tests/test_concat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,10 @@ def test_concat_join_kwarg(self):
187187
{"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)},
188188
coords={"x": [0, 1], "y": [0.0001]},
189189
)
190+
expected["override"] = Dataset(
191+
{"a": (("x", "y"), np.array([0, 0], ndmin=2).T)},
192+
coords={"x": [0, 1], "y": [0]},
193+
)
190194

191195
with raises_regex(ValueError, "indexes along dimension 'y'"):
192196
actual = concat([ds1, ds2], join="exact", dim="x")
@@ -396,6 +400,10 @@ def test_concat_join_kwarg(self):
396400
{"a": (("x", "y"), np.array([np.nan, 0], ndmin=2).T)},
397401
coords={"x": [0, 1], "y": [0.0001]},
398402
)
403+
expected["override"] = Dataset(
404+
{"a": (("x", "y"), np.array([0, 0], ndmin=2).T)},
405+
coords={"x": [0, 1], "y": [0]},
406+
)
399407

400408
with raises_regex(ValueError, "indexes along dimension 'y'"):
401409
actual = concat([ds1, ds2], join="exact", dim="x")

xarray/tests/test_dataarray.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3144,6 +3144,56 @@ def test_align_copy(self):
31443144
assert_identical(x, x2)
31453145
assert source_ndarray(x2.data) is not source_ndarray(x.data)
31463146

3147+
def test_align_override(self):
3148+
left = DataArray([1, 2, 3], dims="x", coords={"x": [0, 1, 2]})
3149+
right = DataArray(
3150+
np.arange(9).reshape((3, 3)),
3151+
dims=["x", "y"],
3152+
coords={"x": [0.1, 1.1, 2.1], "y": [1, 2, 3]},
3153+
)
3154+
3155+
expected_right = DataArray(
3156+
np.arange(9).reshape(3, 3),
3157+
dims=["x", "y"],
3158+
coords={"x": [0, 1, 2], "y": [1, 2, 3]},
3159+
)
3160+
3161+
new_left, new_right = align(left, right, join="override")
3162+
assert_identical(left, new_left)
3163+
assert_identical(new_right, expected_right)
3164+
3165+
new_left, new_right = align(left, right, exclude="x", join="override")
3166+
assert_identical(left, new_left)
3167+
assert_identical(right, new_right)
3168+
3169+
new_left, new_right = xr.align(
3170+
left.isel(x=0, drop=True), right, exclude="x", join="override"
3171+
)
3172+
assert_identical(left.isel(x=0, drop=True), new_left)
3173+
assert_identical(right, new_right)
3174+
3175+
with raises_regex(ValueError, "Indexes along dimension 'x' don't have"):
3176+
align(left.isel(x=0).expand_dims("x"), right, join="override")
3177+
3178+
@pytest.mark.parametrize(
3179+
"darrays",
3180+
[
3181+
[
3182+
DataArray(0),
3183+
DataArray([1], [("x", [1])]),
3184+
DataArray([2, 3], [("x", [2, 3])]),
3185+
],
3186+
[
3187+
DataArray([2, 3], [("x", [2, 3])]),
3188+
DataArray([1], [("x", [1])]),
3189+
DataArray(0),
3190+
],
3191+
],
3192+
)
3193+
def test_align_override_error(self, darrays):
3194+
with raises_regex(ValueError, "Indexes along dimension 'x' don't have"):
3195+
xr.align(*darrays, join="override")
3196+
31473197
def test_align_exclude(self):
31483198
x = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, -2]), ("b", [3, 4])])
31493199
y = DataArray([[1, 2], [3, 4]], coords=[("a", [-1, 20]), ("b", [5, 6])])

xarray/tests/test_dataset.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,6 +1921,28 @@ def test_align_exact(self):
19211921
with raises_regex(ValueError, "indexes .* not equal"):
19221922
xr.align(left, right, join="exact")
19231923

1924+
def test_align_override(self):
1925+
left = xr.Dataset(coords={"x": [0, 1, 2]})
1926+
right = xr.Dataset(coords={"x": [0.1, 1.1, 2.1], "y": [1, 2, 3]})
1927+
expected_right = xr.Dataset(coords={"x": [0, 1, 2], "y": [1, 2, 3]})
1928+
1929+
new_left, new_right = xr.align(left, right, join="override")
1930+
assert_identical(left, new_left)
1931+
assert_identical(new_right, expected_right)
1932+
1933+
new_left, new_right = xr.align(left, right, exclude="x", join="override")
1934+
assert_identical(left, new_left)
1935+
assert_identical(right, new_right)
1936+
1937+
new_left, new_right = xr.align(
1938+
left.isel(x=0, drop=True), right, exclude="x", join="override"
1939+
)
1940+
assert_identical(left.isel(x=0, drop=True), new_left)
1941+
assert_identical(right, new_right)
1942+
1943+
with raises_regex(ValueError, "Indexes along dimension 'x' don't have"):
1944+
xr.align(left.isel(x=0).expand_dims("x"), right, join="override")
1945+
19241946
def test_align_exclude(self):
19251947
x = Dataset(
19261948
{

0 commit comments

Comments
 (0)