Skip to content

Commit 55c8897

Browse files
committed
Fix bug with NaNs in by and method='blockwise'
xref pydata/xarray#9320
1 parent f0ce343 commit 55c8897

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

flox/core.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2663,10 +2663,18 @@ def groupby_reduce(
26632663
groups = (groups[0][sorted_idx],)
26642664

26652665
if factorize_early:
2666+
assert len(groups) == 1
2667+
(groups_,) = groups
26662668
# nan group labels are factorized to -1, and preserved
26672669
# now we get rid of them by reindexing
2668-
# This also handles bins with no data
2669-
result = reindex_(result, from_=groups[0], to=expected_, fill_value=fill_value).reshape(
2670+
# First, for "blockwise", we can have -1 repeated in different blocks
2671+
# This breaks the reindexing so remove those first.
2672+
if method == "blockwise" and (mask := groups_ == -1).sum(axis=-1) > 1:
2673+
result = result[..., ~mask]
2674+
groups_ = groups_[..., ~mask]
2675+
2676+
# This reindex also handles bins with no data
2677+
result = reindex_(result, from_=groups_, to=expected_, fill_value=fill_value).reshape(
26702678
result.shape[:-1] + grp_shape
26712679
)
26722680
groups = final_groups

tests/test_core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,3 +1928,17 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
19281928
expected = flox.groupby_scan(array.compute(), by, func=func)
19291929
actual = flox.groupby_scan(array, by, func=func)
19301930
assert_equal(expected, actual)
1931+
1932+
1933+
@requires_dask
1934+
def test_blockwise_nans():
1935+
array = dask.array.ones((1, 10), chunks=2)
1936+
by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4])
1937+
actual, actual_groups = flox.groupby_reduce(
1938+
array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)
1939+
)
1940+
expected, expected_groups = flox.groupby_reduce(
1941+
array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5)
1942+
)
1943+
assert_equal(expected_groups, actual_groups)
1944+
assert_equal(expected, actual)

0 commit comments

Comments
 (0)