Skip to content

Commit 85a788d

Browse files
OddBlokeJelleZijlstra
authored andcommitted
multiprocessing Pool (and context manager) fixes/improvements (#1562)
* Use typing.ContextManager for multiprocessing context managers Prior to this commit, the types for __enter__ and __exit__ were not fully defined; this addresses that. * Move Pool class stub to multiprocessing.pool This is where the class is actually defined in the stdlib. * Ensure that __enter__ on Pool subclasses returns the subclass This ensures that: ```py class MyPool(Pool): def my_method(self): pass with MyPool() as pool: pool.my_method() ``` type-checks correctly. * Update the signature of BaseContext.Pool to match Pool.__init__ * Restore multiprocessing.Pool as a function And also add comments to note that it should have an identical signature to multiprocessing.context.BaseContext.Pool (because it is just that method partially applied).
1 parent 8700993 commit 85a788d

File tree

4 files changed

+45
-67
lines changed

4 files changed

+45
-67
lines changed

stdlib/3/multiprocessing/__init__.pyi

Lines changed: 13 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# Stubs for multiprocessing
22

3-
from typing import Any, Callable, Iterable, Mapping, Optional, Dict, List, Union, TypeVar
3+
from typing import (
4+
Any, Callable, ContextManager, Iterable, Mapping, Optional, Dict, List,
5+
Union, TypeVar,
6+
)
47

58
from logging import Logger
9+
from multiprocessing import pool
610
from multiprocessing.context import BaseContext
711
from multiprocessing.managers import SyncManager
812
from multiprocessing.pool import AsyncResult
@@ -12,11 +16,9 @@ import queue
1216

1317
_T = TypeVar('_T')
1418

15-
class Lock():
19+
class Lock(ContextManager[Lock]):
1620
def acquire(self, block: bool = ..., timeout: int = ...) -> None: ...
1721
def release(self) -> None: ...
18-
def __enter__(self) -> 'Lock': ...
19-
def __exit__(self, exc_type, exc_value, tb) -> None: ...
2022

2123
class Event(object):
2224
def __init__(self, *, ctx: BaseContext) -> None: ...
@@ -25,54 +27,13 @@ class Event(object):
2527
def clear(self) -> None: ...
2628
def wait(self, timeout: Optional[int] = ...) -> bool: ...
2729

28-
class Pool():
29-
def __init__(self, processes: Optional[int] = ...,
30-
initializer: Optional[Callable[..., None]] = ...,
31-
initargs: Iterable[Any] = ...,
32-
maxtasksperchild: Optional[int] = ...,
33-
context: Optional[Any] = None) -> None: ...
34-
def apply(self,
35-
func: Callable[..., Any],
36-
args: Iterable[Any] = ...,
37-
kwds: Dict[str, Any]=...) -> Any: ...
38-
def apply_async(self,
39-
func: Callable[..., Any],
40-
args: Iterable[Any] = ...,
41-
kwds: Dict[str, Any] = ...,
42-
callback: Optional[Callable[..., None]] = None,
43-
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
44-
def map(self,
45-
func: Callable[..., Any],
46-
iterable: Iterable[Any] = ...,
47-
chunksize: Optional[int] = ...) -> List[Any]: ...
48-
def map_async(self, func: Callable[..., Any],
49-
iterable: Iterable[Any] = ...,
50-
chunksize: Optional[int] = ...,
51-
callback: Optional[Callable[..., None]] = None,
52-
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
53-
def imap(self,
54-
func: Callable[..., Any],
55-
iterable: Iterable[Any] = ...,
56-
chunksize: Optional[int] = None) -> Iterable[Any]: ...
57-
def imap_unordered(self,
58-
func: Callable[..., Any],
59-
iterable: Iterable[Any] = ...,
60-
chunksize: Optional[int] = None) -> Iterable[Any]: ...
61-
def starmap(self,
62-
func: Callable[..., Any],
63-
iterable: Iterable[Iterable[Any]] = ...,
64-
chunksize: Optional[int] = None) -> List[Any]: ...
65-
def starmap_async(self,
66-
func: Callable[..., Any],
67-
iterable: Iterable[Iterable[Any]] = ...,
68-
chunksize: Optional[int] = ...,
69-
callback: Optional[Callable[..., None]] = None,
70-
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
71-
def close(self) -> None: ...
72-
def terminate(self) -> None: ...
73-
def join(self) -> None: ...
74-
def __enter__(self) -> 'Pool': ...
75-
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
30+
# N.B. This is generated at runtime by partially applying
31+
# multiprocessing.context.BaseContext.Pool, so the two signatures should be
32+
# identical (modulo self).
33+
def Pool(processes: Optional[int] = ...,
34+
initializer: Optional[Callable[..., Any]] = ...,
35+
initargs: Iterable[Any] = ...,
36+
maxtasksperchild: Optional[int] = ...) -> pool.Pool: ...
7637

7738
class Process():
7839
name: str

stdlib/3/multiprocessing/context.pyi

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from logging import Logger
44
import multiprocessing
55
import sys
6-
from typing import Any, Callable, Optional, List, Sequence, Tuple, Type, Union
6+
from typing import (
7+
Any, Callable, Iterable, Optional, List, Sequence, Tuple, Type, Union,
8+
)
79

810
class ProcessError(Exception): ...
911

@@ -49,13 +51,16 @@ class BaseContext(object):
4951
def JoinableQueue(self, maxsize: int = ...) -> Any: ...
5052
# TODO: change return to SimpleQueue once a stub exists in multiprocessing.queues
5153
def SimpleQueue(self) -> Any: ...
54+
# N.B. This method is partially applied at runtime to generate
55+
# multiprocessing.Pool, so the two signatures should be identical (modulo
56+
# self).
5257
def Pool(
5358
self,
5459
processes: Optional[int] = ...,
5560
initializer: Optional[Callable[..., Any]] = ...,
56-
initargs: Tuple = ...,
61+
initargs: Iterable[Any] = ...,
5762
maxtasksperchild: Optional[int] = ...
58-
) -> multiprocessing.Pool: ...
63+
) -> multiprocessing.pool.Pool: ...
5964
# TODO: typecode_or_type param is a ctype with a base class of _SimpleCData or array.typecode Need to figure out
6065
# how to handle the ctype
6166
# TODO: change return to RawValue once a stub exists in multiprocessing.sharedctypes

stdlib/3/multiprocessing/managers.pyi

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import queue
66
import threading
77
from typing import (
8-
Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, TypeVar
8+
Any, Callable, ContextManager, Dict, Iterable, List, Mapping, Optional,
9+
Sequence, TypeVar,
910
)
1011

1112
_T = TypeVar('_T')
@@ -16,13 +17,11 @@ class Namespace: ...
1617

1718
_Namespace = Namespace
1819

19-
class BaseManager:
20+
class BaseManager(ContextManager[BaseManager]):
2021
def register(self, typeid: str, callable: Any = ...) -> None: ...
2122
def shutdown(self) -> None: ...
2223
def start(self, initializer: Optional[Callable[..., Any]] = ...,
2324
initargs: Iterable[Any] = ...) -> None: ...
24-
def __enter__(self) -> 'BaseManager': ...
25-
def __exit__(self, exc_type, exc_value, tb) -> None: ...
2625

2726
class SyncManager(BaseManager):
2827
def BoundedSemaphore(self, value: Any = ...) -> threading.BoundedSemaphore: ...

stdlib/3/multiprocessing/pool.pyi

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,25 @@
22

33
# NOTE: These are incomplete!
44

5-
from typing import Any, Callable, Iterable, Mapping, Optional, Dict, List
5+
from typing import (
6+
Any, Callable, ContextManager, Iterable, Mapping, Optional, Dict, List,
7+
TypeVar,
8+
)
9+
10+
_T = TypeVar('_T', bound='Pool')
611

712
class AsyncResult():
813
def get(self, timeout: float = ...) -> Any: ...
914
def wait(self, timeout: float = ...) -> None: ...
1015
def ready(self) -> bool: ...
1116
def successful(self) -> bool: ...
1217

13-
class ThreadPool():
14-
def __init__(self, processes: Optional[int] = None,
15-
initializer: Optional[Callable[..., Any]] = None,
16-
initargs: Iterable[Any] = ...) -> None: ...
18+
class Pool(ContextManager[Pool]):
19+
def __init__(self, processes: Optional[int] = ...,
20+
initializer: Optional[Callable[..., None]] = ...,
21+
initargs: Iterable[Any] = ...,
22+
maxtasksperchild: Optional[int] = ...,
23+
context: Optional[Any] = None) -> None: ...
1724
def apply(self,
1825
func: Callable[..., Any],
1926
args: Iterable[Any] = ...,
@@ -30,7 +37,7 @@ class ThreadPool():
3037
chunksize: Optional[int] = None) -> List[Any]: ...
3138
def map_async(self, func: Callable[..., Any],
3239
iterable: Iterable[Any] = ...,
33-
chunksize: Optional[Optional[int]] = None,
40+
chunksize: Optional[int] = None,
3441
callback: Optional[Callable[..., None]] = None,
3542
error_callback: Optional[Callable[[BaseException], None]] = None) -> AsyncResult: ...
3643
def imap(self,
@@ -54,5 +61,11 @@ class ThreadPool():
5461
def close(self) -> None: ...
5562
def terminate(self) -> None: ...
5663
def join(self) -> None: ...
57-
def __enter__(self) -> 'ThreadPool': ...
58-
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
64+
def __enter__(self: _T) -> _T: ...
65+
66+
67+
class ThreadPool(Pool, ContextManager[ThreadPool]):
68+
69+
def __init__(self, processes: Optional[int] = None,
70+
initializer: Optional[Callable[..., Any]] = None,
71+
initargs: Iterable[Any] = ...) -> None: ...

0 commit comments

Comments
 (0)