Skip to content

Commit 489c843

Browse files
committed
Merge branch 'main' into topk
* main:
  - Fix upstream-dev tests (#421)
  - Bump codecov/codecov-action from 5.1.2 to 5.3.1 (#420)
  - optimize cohorts yet again (#419)
2 parents a5bcc5b + 109fe42 commit 489c843

File tree

7 files changed

+40
-28
lines changed

7 files changed

+40
-28
lines changed

.github/workflows/ci-additional.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ jobs:
7777
--ignore flox/tests \
7878
--cov=./ --cov-report=xml
7979
- name: Upload code coverage to Codecov
80-
uses: codecov/codecov-action@v5.1.2
80+
uses: codecov/codecov-action@v5.3.1
8181
with:
8282
file: ./coverage.xml
8383
flags: unittests
@@ -132,7 +132,7 @@ jobs:
132132
python -m mypy --install-types --non-interactive --cache-dir=.mypy_cache/ --cobertura-xml-report mypy_report
133133
134134
- name: Upload mypy coverage to Codecov
135-
uses: codecov/codecov-action@v5.1.2
135+
uses: codecov/codecov-action@v5.3.1
136136
with:
137137
file: mypy_report/cobertura.xml
138138
flags: mypy

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676
python -c "import xarray; xarray.show_versions()"
7777
pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci --log-disable=flox
7878
- name: Upload code coverage to Codecov
79-
uses: codecov/codecov-action@v5.1.2
79+
uses: codecov/codecov-action@v5.3.1
8080
with:
8181
file: ./coverage.xml
8282
flags: unittests

.github/workflows/upstream-dev-ci.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ jobs:
9393
id: status
9494
run: |
9595
pytest -rf -n auto --cov=./ --cov-report=xml \
96-
--report-log output-${{ matrix.python-version }}-log.jsonl
96+
--report-log output-${{ matrix.python-version }}-log.jsonl \
97+
--hypothesis-profile ci
9798
- name: Generate and publish the report
9899
if: |
99100
failure()

ci/upstream-dev-env.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ name: flox-tests
22
channels:
33
- conda-forge
44
dependencies:
5+
- asv_runner # for test_asv
56
- cachey
67
- codecov
78
- pooch
9+
- hypothesis
810
- toolz
911
# - numpy
1012
# - pandas

flox/core.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,7 +1470,7 @@ def _reduce_blockwise(
14701470
return result
14711471

14721472

1473-
def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
1473+
def _normalize_indexes(ndim: int, flatblocks: Sequence[int], blkshape: tuple[int, ...]) -> tuple:
14741474
"""
14751475
.blocks accessor can only accept one iterable at a time,
14761476
but can handle multiple slices.
@@ -1488,20 +1488,23 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
14881488
if i.ndim == 0:
14891489
normalized.append(i.item())
14901490
else:
1491-
if np.array_equal(i, np.arange(blkshape[ax])):
1491+
if len(i) == blkshape[ax] and np.array_equal(i, np.arange(blkshape[ax])):
14921492
normalized.append(slice(None))
1493-
elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
1494-
normalized.append(slice(i[0], i[-1] + 1))
1493+
elif _issorted(i) and np.array_equal(i, np.arange(i[0], i[-1] + 1)):
1494+
start = None if i[0] == 0 else i[0]
1495+
stop = i[-1] + 1
1496+
stop = None if stop == blkshape[ax] else stop
1497+
normalized.append(slice(start, stop))
14951498
else:
14961499
normalized.append(list(i))
1497-
full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)
1500+
full_normalized = (slice(None),) * (ndim - len(normalized)) + tuple(normalized)
14981501

14991502
# has no iterables
15001503
noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
15011504
# has all iterables
15021505
alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")}
15031506

1504-
mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values())))
1507+
mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values()))) # type: ignore[arg-type, var-annotated]
15051508

15061509
full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter))
15071510

@@ -1528,7 +1531,6 @@ def subset_to_blocks(
15281531
-------
15291532
dask.array
15301533
"""
1531-
from dask.array.slicing import normalize_index
15321534
from dask.base import tokenize
15331535

15341536
if blkshape is None:
@@ -1537,10 +1539,9 @@ def subset_to_blocks(
15371539
if chunks_as_array is None:
15381540
chunks_as_array = tuple(np.array(c) for c in array.chunks)
15391541

1540-
index = _normalize_indexes(array, flatblocks, blkshape)
1542+
index = _normalize_indexes(array.ndim, flatblocks, blkshape)
15411543

15421544
# The rest is copied from dask.array.core.py with slight modifications
1543-
index = normalize_index(index, array.numblocks)
15441545
index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)
15451546

15461547
name = "groupby-cohort-" + tokenize(array, index)

flox/dask_array_ops.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import builtins
22
import math
3-
from functools import partial
3+
from functools import lru_cache, partial
44
from itertools import product
55
from numbers import Integral
66

@@ -84,14 +84,8 @@ def partial_reduce(
8484
axis: tuple[int, ...],
8585
block_index: int | None = None,
8686
):
87-
numblocks = tuple(len(c) for c in chunks)
88-
ndim = len(numblocks)
89-
parts = [list(partition_all(split_every.get(i, 1), range(n))) for (i, n) in enumerate(numblocks)]
90-
keys = product(*map(range, map(len, parts)))
91-
out_chunks = [
92-
tuple(1 for p in partition_all(split_every[i], c)) if i in split_every else c
93-
for (i, c) in enumerate(chunks)
94-
]
87+
ndim = len(chunks)
88+
keys, parts, out_chunks = get_parts(tuple(split_every.items()), chunks)
9589
for k, p in zip(keys, product(*parts)):
9690
free = {i: j[0] for (i, j) in enumerate(p) if len(j) == 1 and i not in split_every}
9791
dummy = dict(i for i in enumerate(p) if i[0] in split_every)
@@ -101,3 +95,17 @@ def partial_reduce(
10195
k = (*k[:-1], block_index)
10296
dsk[(name,) + k] = (func, g)
10397
return dsk, out_chunks
98+
99+
100+
@lru_cache
101+
def get_parts(split_every_items, chunks):
102+
numblocks = tuple(len(c) for c in chunks)
103+
split_every = dict(split_every_items)
104+
105+
parts = [list(partition_all(split_every.get(i, 1), range(n))) for (i, n) in enumerate(numblocks)]
106+
keys = tuple(product(*map(range, map(len, parts))))
107+
out_chunks = tuple(
108+
tuple(1 for p in partition_all(split_every[i], c)) if i in split_every else c
109+
for (i, c) in enumerate(chunks)
110+
)
111+
return keys, parts, out_chunks

tests/test_core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from collections.abc import Callable
77
from functools import partial, reduce
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any
99
from unittest.mock import MagicMock, patch
1010

1111
import numpy as np
@@ -1551,7 +1551,7 @@ def test_normalize_block_indexing_1d(flatblocks, expected):
15511551
nblocks = 5
15521552
array = dask.array.ones((nblocks,), chunks=(1,))
15531553
expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
1554-
actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
1554+
actual = _normalize_indexes(array.ndim, flatblocks, array.blocks.shape)
15551555
assert_equal_tuple(expected, actual)
15561556

15571557

@@ -1563,17 +1563,17 @@ def test_normalize_block_indexing_1d(flatblocks, expected):
15631563
((1, 2, 3), (0, slice(1, 4))),
15641564
((1, 3), (0, [1, 3])),
15651565
((0, 1, 3), (0, [0, 1, 3])),
1566-
(tuple(range(10)), (slice(0, 2), slice(None))),
1567-
((0, 1, 3, 5, 6, 8), (slice(0, 2), [0, 1, 3])),
1566+
(tuple(range(10)), (slice(None, 2), slice(None))),
1567+
((0, 1, 3, 5, 6, 8), (slice(None, 2), [0, 1, 3])),
15681568
((0, 3, 4, 5, 6, 8, 24), np.ix_([0, 1, 4], [0, 1, 3, 4])),
15691569
),
15701570
)
1571-
def test_normalize_block_indexing_2d(flatblocks, expected):
1571+
def test_normalize_block_indexing_2d(flatblocks: tuple[int, ...], expected: tuple[Any, ...]) -> None:
15721572
nblocks = 5
15731573
ndim = 2
15741574
array = dask.array.ones((nblocks,) * ndim, chunks=(1,) * ndim)
15751575
expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
1576-
actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
1576+
actual = _normalize_indexes(array.ndim, flatblocks, array.blocks.shape)
15771577
assert_equal_tuple(expected, actual)
15781578

15791579

0 commit comments

Comments
 (0)