Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 4 additions & 29 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from typing_extensions import TypeAlias

from packaging.version import parse as parse_version
from tlz import first, groupby, merge, partition_all, valmap
from tlz import first, groupby, merge, valmap

import dask
from dask.base import collections_to_dsk, tokenize
Expand Down Expand Up @@ -2291,34 +2291,9 @@ def map(
)
total_length = sum(len(x) for x in iterables)
if batch_size and batch_size > 1 and total_length > batch_size:
batches = list(
zip(*(partition_all(batch_size, iterable) for iterable in iterables))
)
keys: list[list[Any]] | list[Any]
if isinstance(key, list):
keys = [list(element) for element in partition_all(batch_size, key)]
else:
keys = [key for _ in range(len(batches))]
return sum(
(
self.map(
func,
*batch,
key=key,
workers=workers,
retries=retries,
priority=priority,
allow_other_workers=allow_other_workers,
fifo_timeout=fifo_timeout,
resources=resources,
actor=actor,
actors=actors,
pure=pure,
**kwargs,
)
for key, batch in zip(keys, batches)
),
[],
warnings.warn(
'The argument "batch_size" is ignored and will be removed in a future version.',
DeprecationWarning,
)

key = key or funcname(func)
Expand Down
23 changes: 14 additions & 9 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,20 @@ async def test_map_retries(c, s, a, b):

@gen_cluster(client=True)
async def test_map_batch_size(c, s, a, b):
result = c.map(inc, range(100), batch_size=10)
with pytest.deprecated_call(match="batch_size"):
result = c.map(inc, range(100), batch_size=10)
result = await c.gather(result)
assert result == list(range(1, 101))

result = c.map(add, range(100), range(100), batch_size=10)
with pytest.deprecated_call(match="batch_size"):
result = c.map(add, range(100), range(100), batch_size=10)
result = await c.gather(result)
assert result == list(range(0, 200, 2))

# mismatch shape
result = c.map(add, range(100, 200), range(10), batch_size=2)

with pytest.deprecated_call(match="batch_size"):
result = c.map(add, range(100, 200), range(10), batch_size=2)
result = await c.gather(result)
assert result == list(range(100, 120, 2))

Expand All @@ -261,12 +265,13 @@ async def test_map_batch_size(c, s, a, b):
async def test_custom_key_with_batches(c, s, a, b):
"""Test of <https://github.com/dask/distributed/issues/4588>"""

futs = c.map(
lambda x: x**2,
range(10),
batch_size=5,
key=[str(x) for x in range(10)],
)
with pytest.deprecated_call(match="batch_size"):
futs = c.map(
lambda x: x**2,
range(10),
batch_size=5,
key=[str(x) for x in range(10)],
)
assert len(futs) == 10
await wait(futs)

Expand Down
Loading