@@ -1797,12 +1797,13 @@ def dask_groupby_agg(
17971797 output_chunks = new_dims_shape + reduced .chunks [: - len (axis )] + group_chunks
17981798 new_axes = dict (zip (new_inds , new_dims_shape ))
17991799
1800- if method == "blockwise" and len (axis ) > 1 :
1801- # The final results are available but the blocks along axes
1802- # need to be reshaped to axis=-1
1803- # I don't know that this is possible with blockwise
1804- # All other code paths benefit from an unmaterialized Blockwise layer
1805- reduced = _collapse_blocks_along_axes (reduced , axis , group_chunks )
1800+ if method == "blockwise" :
1801+ if len (axis ) > 1 :
1802+ # The final results are available but the blocks along axes
1803+ # need to be reshaped to axis=-1
1804+ # I don't know that this is possible with blockwise
1805+ # All other code paths benefit from an unmaterialized Blockwise layer
1806+ reduced = _collapse_blocks_along_axes (reduced , axis , group_chunks )
18061807
18071808 # Can't use map_blocks because it forces concatenate=True along drop_axes,
18081809 result = dask .array .blockwise (
@@ -1817,7 +1818,6 @@ def dask_groupby_agg(
18171818 concatenate = False ,
18181819 new_axes = new_axes ,
18191820 )
1820-
18211821 return (result , groups )
18221822
18231823
@@ -2663,10 +2663,18 @@ def groupby_reduce(
26632663 groups = (groups [0 ][sorted_idx ],)
26642664
26652665 if factorize_early :
2666+ assert len (groups ) == 1
2667+ (groups_ ,) = groups
26662668 # nan group labels are factorized to -1, and preserved
26672669 # now we get rid of them by reindexing
2668- # This also handles bins with no data
2669- result = reindex_ (result , from_ = groups [0 ], to = expected_ , fill_value = fill_value ).reshape (
2670+ # First, for "blockwise", we can have -1 repeated in different blocks
2671+ # This breaks the reindexing so remove those first.
2672+ if method == "blockwise" and (mask := groups_ == - 1 ).sum (axis = - 1 ) > 1 :
2673+ result = result [..., ~ mask ]
2674+ groups_ = groups_ [..., ~ mask ]
2675+
2676+ # This reindex also handles bins with no data
2677+ result = reindex_ (result , from_ = groups_ , to = expected_ , fill_value = fill_value ).reshape (
26702678 result .shape [:- 1 ] + grp_shape
26712679 )
26722680 groups = final_groups
0 commit comments