@@ -29,6 +29,7 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
+from datetime import timedelta
 from multiprocessing import shared_memory
 from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
@@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
     _ENABLE_CUSTOM_ALL_REDUCE = enable


-def init_distributed_environment(
-    world_size: int = -1,
-    rank: int = -1,
-    distributed_init_method: str = "env://",
-    local_rank: int = -1,
-    backend: str = "nccl",
-):
+def init_distributed_environment(world_size: int = -1,
+                                 rank: int = -1,
+                                 distributed_init_method: str = "env://",
+                                 local_rank: int = -1,
+                                 backend: str = "nccl",
+                                 timeout: Optional[timedelta] = None):
     logger.debug(
         "world_size=%d rank=%d local_rank=%d "
         "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
@@ -1020,7 +1020,8 @@ def init_distributed_environment(
         backend=backend,
         init_method=distributed_init_method,
         world_size=world_size,
-        rank=rank)
+        rank=rank,
+        timeout=timeout)
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
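
For context, a minimal caller sketch of the new `timeout` parameter. This is an illustration, not code from this PR: it assumes `init_distributed_environment` is importable from `vllm.distributed`, and uses the `gloo` backend only so the sketch can run without GPUs.

```python
# Hypothetical usage sketch: pass an explicit timeout so that the underlying
# torch.distributed.init_process_group() call gives up early instead of
# waiting for the backend's default timeout.
import os
from datetime import timedelta

from vllm.distributed import init_distributed_environment  # assumed import path

# distributed_init_method defaults to "env://", so a rendezvous address is needed.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

init_distributed_environment(
    world_size=1,
    rank=0,
    local_rank=0,
    backend="gloo",                 # "nccl" remains the default
    timeout=timedelta(minutes=5),   # omit (None) to keep torch's default
)
```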