Skip to content

Commit 5eee684

Browse files
committed
Fix
1 parent 5484e4a commit 5eee684

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

flox/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,6 +2425,9 @@ def groupby_reduce(
24252425
if method == "cohorts" and any_by_dask:
24262426
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
24272427

2428+
if not is_duck_array(array):
2429+
array = np.asarray(array)
2430+
24282431
reindex = _validate_reindex(
24292432
reindex,
24302433
func,
@@ -2435,8 +2438,6 @@ def groupby_reduce(
24352438
array.dtype,
24362439
)
24372440

2438-
if not is_duck_array(array):
2439-
array = np.asarray(array)
24402441
is_bool_array = np.issubdtype(array.dtype, bool)
24412442
array = array.astype(np.intp) if is_bool_array else array
24422443

tests/test_core.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,18 +1590,36 @@ def test_validate_reindex_map_reduce(
15901590
dask_expected, reindex, func, expected_groups, any_by_dask
15911591
) -> None:
15921592
actual = _validate_reindex(
1593-
reindex, func, "map-reduce", expected_groups, any_by_dask, is_dask_array=True
1593+
reindex,
1594+
func,
1595+
"map-reduce",
1596+
expected_groups,
1597+
any_by_dask,
1598+
is_dask_array=True,
1599+
array_dtype=np.dtype("int32"),
15941600
)
15951601
assert actual is dask_expected
15961602

15971603
# always reindex with all numpy inputs
15981604
actual = _validate_reindex(
1599-
reindex, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
1605+
reindex,
1606+
func,
1607+
"map-reduce",
1608+
expected_groups,
1609+
any_by_dask=False,
1610+
is_dask_array=False,
1611+
array_dtype=np.dtype("int32"),
16001612
)
16011613
assert actual
16021614

16031615
actual = _validate_reindex(
1604-
True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
1616+
True,
1617+
func,
1618+
"map-reduce",
1619+
expected_groups,
1620+
any_by_dask=False,
1621+
is_dask_array=False,
1622+
array_dtype=np.dtype("int32"),
16051623
)
16061624
assert actual
16071625

@@ -1611,19 +1629,37 @@ def test_validate_reindex() -> None:
16111629
for method in methods:
16121630
with pytest.raises(NotImplementedError):
16131631
_validate_reindex(
1614-
True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True
1632+
True,
1633+
"argmax",
1634+
method,
1635+
expected_groups=None,
1636+
any_by_dask=False,
1637+
is_dask_array=True,
1638+
array_dtype=np.dtype("int32"),
16151639
)
16161640

16171641
methods: list[T_Method] = ["blockwise", "cohorts"]
16181642
for method in methods:
16191643
with pytest.raises(ValueError):
16201644
_validate_reindex(
1621-
True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True
1645+
True,
1646+
"sum",
1647+
method,
1648+
expected_groups=None,
1649+
any_by_dask=False,
1650+
is_dask_array=True,
1651+
array_dtype=np.dtype("int32"),
16221652
)
16231653

16241654
for func in ["sum", "argmax"]:
16251655
actual = _validate_reindex(
1626-
None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True
1656+
None,
1657+
func,
1658+
method,
1659+
expected_groups=None,
1660+
any_by_dask=False,
1661+
is_dask_array=True,
1662+
array_dtype=np.dtype("int32"),
16271663
)
16281664
assert actual is False
16291665

@@ -1635,6 +1671,7 @@ def test_validate_reindex() -> None:
16351671
expected_groups=np.array([1, 2, 3]),
16361672
any_by_dask=False,
16371673
is_dask_array=True,
1674+
array_dtype=np.dtype("int32"),
16381675
)
16391676

16401677
assert _validate_reindex(
@@ -1644,6 +1681,7 @@ def test_validate_reindex() -> None:
16441681
expected_groups=np.array([1, 2, 3]),
16451682
any_by_dask=True,
16461683
is_dask_array=True,
1684+
array_dtype=np.dtype("int32"),
16471685
)
16481686
assert _validate_reindex(
16491687
None,
@@ -1652,8 +1690,24 @@ def test_validate_reindex() -> None:
16521690
expected_groups=np.array([1, 2, 3]),
16531691
any_by_dask=True,
16541692
is_dask_array=True,
1693+
array_dtype=np.dtype("int32"),
16551694
)
16561695

1696+
kwargs = dict(
1697+
method="blockwise",
1698+
expected_groups=np.array([1, 2, 3]),
1699+
any_by_dask=True,
1700+
is_dask_array=True,
1701+
)
1702+
1703+
for func in ["nanfirst", "nanlast"]:
1704+
assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore
1705+
assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore
1706+
1707+
for func in ["first", "last"]:
1708+
assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore
1709+
assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore
1710+
16571711

16581712
@requires_dask
16591713
def test_1d_blockwise_sort_optimization():

0 commit comments

Comments
 (0)