
Commit b0db8a8

added unstack in apply
fix dataset test bug fixed apply
1 parent 807001e commit b0db8a8
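For context, a minimal sketch of what this commit enables, mirroring the new test added below (the variable names are illustrative only): grouping a DataArray by a multi-dimensional coordinate and applying a function per group now returns a result with the original dimensions, because apply() unstacks the temporary stacked dimension.

from xarray import DataArray

# Illustrative sketch: group by a 2-D coordinate ('lon') and remove each
# group's mean; the helper dimension created by stacking is unstacked again
# in apply(), so the result keeps the original ('ny', 'nx') dims.
array = DataArray([[0, 1], [2, 3]],
                  coords={'lon': (['ny', 'nx'], [[30, 40], [40, 50]]),
                          'lat': (['ny', 'nx'], [[10, 10], [20, 20]])},
                  dims=['ny', 'nx'])
anomaly = array.groupby('lon').apply(lambda x: x - x.mean(), shortcut=False)
print(anomaly.dims)   # ('ny', 'nx')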

2 files changed (+35, -21 lines)


xarray/core/groupby.py

Lines changed: 24 additions & 21 deletions
@@ -101,31 +101,22 @@ def __init__(self, obj, group, squeeze=False, grouper=None):
         """
         from .dataset import as_dataset

+        if getattr(group, 'name', None) is None:
+            raise ValueError('`group` must have a name')
+        self._stacked_dim = None
         if group.ndim != 1:
             # try to stack the dims of the group into a single dim
             # TODO: figure out how to exclude dimensions from the stacking
             # (e.g. group over space dims but leave time dim intact)
             orig_dims = group.dims
             stacked_dim_name = 'stacked_' + '_'.join(orig_dims)
-            # the copy is necessary here
+            # the copy is necessary here, otherwise read only array raises error
+            # in pandas: https://github.com/pydata/pandas/issues/12813
+            # Is there a performance penalty for calling copy?
             group = group.stack(**{stacked_dim_name: orig_dims}).copy()
-            # without it, an error is raised deep in pandas
-            ########################
-            # xarray/core/groupby.py
-            # ---> 31 inverse, values = pd.factorize(ar, sort=True)
-            # pandas/core/algorithms.pyc in factorize(values, sort, order, na_sentinel, size_hint)
-            # --> 196 labels = table.get_labels(vals, uniques, 0, na_sentinel, True)
-            # pandas/hashtable.pyx in pandas.hashtable.Float64HashTable.get_labels (pandas/hashtable.c:10302)()
-            # pandas/hashtable.so in View.MemoryView.memoryview_cwrapper (pandas/hashtable.c:29882)()
-            # pandas/hashtable.so in View.MemoryView.memoryview.__cinit__ (pandas/hashtable.c:26251)()
-            # ValueError: buffer source array is read-only
-            #######################
-            # seems related to
-            # https://github.com/pydata/pandas/issues/10043
-            # https://github.com/pydata/pandas/pull/10070
             obj = obj.stack(**{stacked_dim_name: orig_dims})
-            if getattr(group, 'name', None) is None:
-                raise ValueError('`group` must have a name')
+            self._stacked_dim = stacked_dim_name
+            self._unstacked_dims = orig_dims
         if not hasattr(group, 'dims'):
             raise ValueError("`group` must have a 'dims' attribute")
         group_dim, = group.dims
@@ -249,6 +240,13 @@ def _maybe_restore_empty_groups(self, combined):
             combined = combined.reindex(**indexers)
         return combined

+    def _maybe_unstack_array(self, arr):
+        """This gets called if we are applying on an array with a
+        multidimensional group."""
+        if self._stacked_dim is not None and self._stacked_dim in arr.dims:
+            arr = arr.unstack(self._stacked_dim)
+        return arr
+
     def fillna(self, value):
         """Fill missing values in this object by group.

@@ -358,6 +356,11 @@ def lookup_order(dimension):
         new_order = sorted(stacked.dims, key=lookup_order)
         return stacked.transpose(*new_order)

+    def _restore_multiindex(self, combined):
+        if self._stacked_dim is not None and self._stacked_dim in combined.dims:
+            combined[self._stacked_dim] = self.group[self._stacked_dim]
+        return combined
+
     def apply(self, func, shortcut=False, **kwargs):
         """Apply a function over each array in the group and concatenate them
         together into a new array.
@@ -399,23 +402,23 @@ def apply(self, func, shortcut=False, **kwargs):
             grouped = self._iter_grouped_shortcut()
         else:
             grouped = self._iter_grouped()
-        applied = (maybe_wrap_array(arr, func(arr, **kwargs)) for arr in grouped)
+        applied = (maybe_wrap_array(arr,func(arr, **kwargs)) for arr in grouped)
         combined = self._concat(applied, shortcut=shortcut)
-        result = self._maybe_restore_empty_groups(combined)
+        result = self._maybe_restore_empty_groups(
+            self._maybe_unstack_array(combined))
         return result

     def _concat(self, applied, shortcut=False):
         # peek at applied to determine which coordinate to stack over
         applied_example, applied = peek_at(applied)
         concat_dim, positions = self._infer_concat_args(applied_example)
-
         if shortcut:
             combined = self._concat_shortcut(applied, concat_dim, positions)
         else:
             combined = concat(applied, concat_dim, positions=positions)
-
         if isinstance(combined, type(self.obj)):
             combined = self._restore_dim_order(combined)
+        combined = self._restore_multiindex(combined)
         return combined

     def reduce(self, func, dim=None, axis=None, keep_attrs=False,
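A rough sketch, outside the library, of the stack/unstack round trip these helpers depend on (the dimension and variable names here are made up, not part of the commit):

from xarray import DataArray

# stack folds ('ny', 'nx') into one MultiIndex dimension; .copy() gives
# pandas.factorize a writeable buffer, working around
# https://github.com/pydata/pandas/issues/12813; unstack restores the
# original dims, which is what _maybe_unstack_array does for the result.
lon = DataArray([[30, 40], [40, 50]], dims=['ny', 'nx'], name='lon')
stacked = lon.stack(stacked_ny_nx=['ny', 'nx']).copy()
print(stacked.dims)                           # ('stacked_ny_nx',)
roundtrip = stacked.unstack('stacked_ny_nx')
print(roundtrip.dims)                         # ('ny', 'nx')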

xarray/test/test_dataarray.py

Lines changed: 11 additions & 0 deletions
@@ -1255,6 +1255,17 @@ def test_groupby_multidim(self):
         actual_sum = array.groupby(dim).sum()
         self.assertDataArrayIdentical(expected_sum, actual_sum)

+    def test_groupby_multidim_apply(self):
+        array = DataArray([[0, 1], [2, 3]],
+                          coords={'lon': (['ny', 'nx'], [[30, 40], [40, 50]]),
+                                  'lat': (['ny', 'nx'], [[10, 10], [20, 20]])},
+                          dims=['ny', 'nx'])
+        actual = array.groupby('lon').apply(
+            lambda x: x - x.mean(), shortcut=False)
+        expected = DataArray([[0., -0.5], [0.5, 0.]],
+                             coords=array.coords, dims=array.dims)
+        self.assertDataArrayIdentical(expected, actual)
+
     def make_rolling_example_array(self):
         times = pd.date_range('2000-01-01', freq='1D', periods=21)
         values = np.random.random((21, 4))
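As a sanity check on the expected values in the new test (plain NumPy, not part of the commit): lon=30 covers value 0, lon=40 covers values 1 and 2 (mean 1.5), and lon=50 covers value 3, so removing each group mean gives [[0., -0.5], [0.5, 0.]].

import numpy as np

# Reproduce the expected anomalies by hand, grouping on the 'lon' values.
values = np.array([[0., 1.], [2., 3.]])
lon = np.array([[30, 40], [40, 50]])
anomaly = np.empty_like(values)
for v in np.unique(lon):
    mask = lon == v
    anomaly[mask] = values[mask] - values[mask].mean()
print(anomaly)   # [[ 0.  -0.5]
                 #  [ 0.5  0. ]]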
