diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 20f9fd7ca2f..828712a8d13 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -15,6 +15,19 @@ What's New .. _whats-new.0.10.0: +v0.10.0 (unreleased) +-------------------- + +Bug fixes +~~~~~~~~~ + +- Fixed unexpected behavior in ``Dataset.set_index()`` and + ``DataArray.set_index()`` introduced by Pandas 0.21.0. Setting a new + index with a single variable resulted in 1-level + ``pandas.MultiIndex`` instead of a simple ``pandas.Index`` + (:issue:`1722`). By `Benoit Bovy `_. + + v0.10.0 rc2 (13 November 2017) ------------------------------ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 56c9df0af93..9c99213dc23 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -136,6 +136,14 @@ def merge_indexes( names, labels, levels = [], [], [] current_index_variable = variables.get(dim) + for n in var_names: + var = variables[n] + if (current_index_variable is not None and + var.dims != current_index_variable.dims): + raise ValueError( + "dimension mismatch between %r %s and %r %s" + % (dim, current_index_variable.dims, n, var.dims)) + if current_index_variable is not None and append: current_index = current_index_variable.to_index() if isinstance(current_index, pd.MultiIndex): @@ -148,20 +156,19 @@ def merge_indexes( labels.append(cat.codes) levels.append(cat.categories) - for n in var_names: - names.append(n) - var = variables[n] - if ((current_index_variable is not None) and - (var.dims != current_index_variable.dims)): - raise ValueError( - "dimension mismatch between %r %s and %r %s" - % (dim, current_index_variable.dims, n, var.dims)) - else: + if not len(names) and len(var_names) == 1: + idx = pd.Index(variables[var_names[0]].values) + + else: + for n in var_names: + names.append(n) + var = variables[n] cat = pd.Categorical(var.values, ordered=True) labels.append(cat.codes) levels.append(cat.categories) - idx = pd.MultiIndex(labels=labels, levels=levels, names=names) + idx = pd.MultiIndex(labels=labels, levels=levels, names=names) + vars_to_replace[dim] = IndexVariable(dim, idx) vars_to_remove.extend(var_names) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 0ffdeb61419..192624b64f3 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2007,6 +2007,12 @@ def test_set_index(self): ds.set_index(x=mindex.names, inplace=True) self.assertDatasetIdentical(ds, expected) + # ensure set_index with no existing index and a single data var given + # doesn't return multi-index + ds = Dataset(data_vars={'x_var': ('x', [0, 1, 2])}) + expected = Dataset(coords={'x': [0, 1, 2]}) + self.assertDataArrayIdentical(ds.set_index(x='x_var'), expected) + def test_reset_index(self): ds = create_test_multiindex() mindex = ds['x'].to_index()