Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:


def _is_first_last_reduction(func: T_Agg) -> bool:
return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
if isinstance(func, Aggregation):
func = func.name
return func in ["nanfirst", "nanlast", "first", "last"]


def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
Expand Down Expand Up @@ -1642,7 +1644,12 @@ def dask_groupby_agg(
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)
labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
do_grouped_combine = (
_is_arg_reduction(agg)
or labels_are_unknown
or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
)
do_simple_combine = not do_grouped_combine

if method == "blockwise":
# use the "non dask" code path, but applied blockwise
Expand Down Expand Up @@ -1986,8 +1993,13 @@ def _validate_reindex(
expected_groups,
any_by_dask: bool,
is_dask_array: bool,
array_dtype: Any,
) -> bool | None:
# logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
def first_or_last():
return func in ["first", "last"] or (
_is_first_last_reduction(func) and array_dtype.kind != "f"
)

all_numpy = not is_dask_array and not any_by_dask
if reindex is True and not all_numpy:
Expand All @@ -1997,7 +2009,7 @@ def _validate_reindex(
raise ValueError(
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
)
if func in ["first", "last"]:
if first_or_last():
raise ValueError("reindex must be None or False when func is 'first' or 'last.")

if reindex is None:
Expand All @@ -2008,9 +2020,10 @@ def _validate_reindex(
if all_numpy:
return True

if func in ["first", "last"]:
if first_or_last():
# have to do the grouped_combine since there's no good fill_value
reindex = False
# Also needed for nanfirst, nanlast with no-NaN dtypes
return False

if method == "blockwise":
# for grouping by dask arrays, we set reindex=True
Expand Down Expand Up @@ -2413,7 +2426,13 @@ def groupby_reduce(
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

reindex = _validate_reindex(
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
reindex,
func,
method,
expected_groups,
any_by_dask,
is_duck_dask_array(array),
array.dtype,
)

if not is_duck_array(array):
Expand Down Expand Up @@ -2601,7 +2620,7 @@ def groupby_reduce(

# TODO: clean this up
reindex = _validate_reindex(
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
)

if TYPE_CHECKING:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset():
)


@pytest.mark.parametrize("group_idx", [[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]])
@pytest.mark.parametrize(
"func",
[
# "first", "last",
"nanfirst",
"nanlast",
],
)
@pytest.mark.parametrize(
"chunks",
[
None,
pytest.param(1, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
pytest.param(2, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
pytest.param(3, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
],
)
def test_first_last_useless(func, chunks, group_idx):
array = np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int8)
if chunks is not None:
array = dask.array.from_array(array, chunks=chunks)
actual, _ = groupby_reduce(array, np.array(group_idx), func=func, engine="numpy")
expected = np.array([[0, 0], [0, 0]], dtype=np.int8)
assert_equal(actual, expected)


@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
@pytest.mark.parametrize("axis", [(0, 1)])
def test_first_last_disallowed(axis, func):
Expand Down