Skip to content

Commit 762c961

Browse files
committed
Accept coordinates with MultiIndex (solve issue pydata#3008)
1 parent 788cd60 commit 762c961

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

xarray/core/coordinates.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
cast,
1414
)
1515

16+
import numpy as np
1617
import pandas as pd
1718

1819
from . import formatting, indexing
@@ -106,9 +107,47 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
106107
(dim,) = ordered_dims
107108
return self._data.get_index(dim) # type: ignore
108109
else:
110+
from pandas.core.arrays.categorical import factorize_from_iterable
111+
109112
indexes = [self._data.get_index(k) for k in ordered_dims] # type: ignore
110-
names = list(ordered_dims)
111-
return pd.MultiIndex.from_product(indexes, names=names)
113+
114+
# compute the sizes of the repeat and tile for the cartesian product
115+
# (taken from pandas.core.reshape.util)
116+
lenX = np.fromiter((len(index) for index in indexes), dtype=np.intp)
117+
cumprodX = np.cumproduct(lenX)
118+
119+
if cumprodX[-1] != 0:
120+
# sizes of the repeats
121+
b = cumprodX[-1] / cumprodX
122+
else:
123+
# if any factor is empty, the cartesian product is empty
124+
b = np.zeros_like(cumprodX)
125+
126+
# sizes of the tiles
127+
a = np.roll(cumprodX, 1)
128+
a[0] = 1
129+
130+
# loop over the indexes
131+
# for each MultiIndex or Index compute the cartesian product of the codes
132+
133+
code_list = []
134+
level_list = []
135+
names = []
136+
137+
for i, index in enumerate(indexes):
138+
if isinstance(index, pd.MultiIndex):
139+
codes, levels = index.codes, index.levels
140+
else:
141+
code, level = factorize_from_iterable(index)
142+
codes = [code]
143+
levels = [level]
144+
145+
# compute the cartesian product
146+
code_list += [np.tile(np.repeat(code, b[i]), a[i]) for code in codes]
147+
level_list += levels
148+
names += index.names
149+
150+
return pd.MultiIndex(level_list, code_list, names=names)
112151

113152
def update(self, other: Mapping[Hashable, Any]) -> None:
114153
other_vars = getattr(other, "variables", other)

0 commit comments

Comments
 (0)