@@ -1590,18 +1590,36 @@ def test_validate_reindex_map_reduce(
15901590 dask_expected , reindex , func , expected_groups , any_by_dask
15911591) -> None :
15921592 actual = _validate_reindex (
1593- reindex , func , "map-reduce" , expected_groups , any_by_dask , is_dask_array = True
1593+ reindex ,
1594+ func ,
1595+ "map-reduce" ,
1596+ expected_groups ,
1597+ any_by_dask ,
1598+ is_dask_array = True ,
1599+ array_dtype = np .dtype ("int32" ),
15941600 )
15951601 assert actual is dask_expected
15961602
15971603 # always reindex with all numpy inputs
15981604 actual = _validate_reindex (
1599- reindex , func , "map-reduce" , expected_groups , any_by_dask = False , is_dask_array = False
1605+ reindex ,
1606+ func ,
1607+ "map-reduce" ,
1608+ expected_groups ,
1609+ any_by_dask = False ,
1610+ is_dask_array = False ,
1611+ array_dtype = np .dtype ("int32" ),
16001612 )
16011613 assert actual
16021614
16031615 actual = _validate_reindex (
1604- True , func , "map-reduce" , expected_groups , any_by_dask = False , is_dask_array = False
1616+ True ,
1617+ func ,
1618+ "map-reduce" ,
1619+ expected_groups ,
1620+ any_by_dask = False ,
1621+ is_dask_array = False ,
1622+ array_dtype = np .dtype ("int32" ),
16051623 )
16061624 assert actual
16071625
@@ -1611,19 +1629,37 @@ def test_validate_reindex() -> None:
16111629 for method in methods :
16121630 with pytest .raises (NotImplementedError ):
16131631 _validate_reindex (
1614- True , "argmax" , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1632+ True ,
1633+ "argmax" ,
1634+ method ,
1635+ expected_groups = None ,
1636+ any_by_dask = False ,
1637+ is_dask_array = True ,
1638+ array_dtype = np .dtype ("int32" ),
16151639 )
16161640
16171641 methods : list [T_Method ] = ["blockwise" , "cohorts" ]
16181642 for method in methods :
16191643 with pytest .raises (ValueError ):
16201644 _validate_reindex (
1621- True , "sum" , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1645+ True ,
1646+ "sum" ,
1647+ method ,
1648+ expected_groups = None ,
1649+ any_by_dask = False ,
1650+ is_dask_array = True ,
1651+ array_dtype = np .dtype ("int32" ),
16221652 )
16231653
16241654 for func in ["sum" , "argmax" ]:
16251655 actual = _validate_reindex (
1626- None , func , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1656+ None ,
1657+ func ,
1658+ method ,
1659+ expected_groups = None ,
1660+ any_by_dask = False ,
1661+ is_dask_array = True ,
1662+ array_dtype = np .dtype ("int32" ),
16271663 )
16281664 assert actual is False
16291665
@@ -1635,6 +1671,7 @@ def test_validate_reindex() -> None:
16351671 expected_groups = np .array ([1 , 2 , 3 ]),
16361672 any_by_dask = False ,
16371673 is_dask_array = True ,
1674+ array_dtype = np .dtype ("int32" ),
16381675 )
16391676
16401677 assert _validate_reindex (
@@ -1644,6 +1681,7 @@ def test_validate_reindex() -> None:
16441681 expected_groups = np .array ([1 , 2 , 3 ]),
16451682 any_by_dask = True ,
16461683 is_dask_array = True ,
1684+ array_dtype = np .dtype ("int32" ),
16471685 )
16481686 assert _validate_reindex (
16491687 None ,
@@ -1652,8 +1690,24 @@ def test_validate_reindex() -> None:
16521690 expected_groups = np .array ([1 , 2 , 3 ]),
16531691 any_by_dask = True ,
16541692 is_dask_array = True ,
1693+ array_dtype = np .dtype ("int32" ),
16551694 )
16561695
1696+ kwargs = dict (
1697+ method = "blockwise" ,
1698+ expected_groups = np .array ([1 , 2 , 3 ]),
1699+ any_by_dask = True ,
1700+ is_dask_array = True ,
1701+ )
1702+
1703+ for func in ["nanfirst" , "nanlast" ]:
1704+ assert not _validate_reindex (None , func , array_dtype = np .dtype ("int32" ), ** kwargs ) # type: ignore
1705+ assert _validate_reindex (None , func , array_dtype = np .dtype ("float32" ), ** kwargs ) # type: ignore
1706+
1707+ for func in ["first" , "last" ]:
1708+ assert not _validate_reindex (None , func , array_dtype = np .dtype ("int32" ), ** kwargs ) # type: ignore
1709+ assert not _validate_reindex (None , func , array_dtype = np .dtype ("float32" ), ** kwargs ) # type: ignore
1710+
16571711
16581712@requires_dask
16591713def test_1d_blockwise_sort_optimization ():
0 commit comments