14 changes: 8 additions & 6 deletions recipes/finetune_llm.py
@@ -54,12 +54,6 @@ def recipe(kwargs):
# ---- Initialize components ---- #
logger = get_logger()

# ---- Initialize seed ---- #
world_size, rank = get_world_size_and_rank()
if kwargs["seed"] is not None:
# Ensure that seed is different per rank (and its dataloader workers)
seed(kwargs["seed"] + rank)

# ---- Initialize distributed process group ---- #
device = init_from_env(device_type=kwargs["device"])
# TODO: only supporting devices specified as "cpu", "cuda", or "cuda:n" currently
@@ -68,6 +62,14 @@ def recipe(kwargs):
if kwargs["device"] in ("cpu", "cuda")
else kwargs["device"].split(":")[0]
)

# ---- Initialize seed ---- #
# Fetch world size and rank after distributed process group initialization
world_size, rank = get_world_size_and_rank()
if kwargs["seed"] is not None:
# Ensure that seed is different per rank (and its dataloader workers)
seed(kwargs["seed"] + rank)

tokenizer = get_tokenizer(kwargs["tokenizer"], path=kwargs["tokenizer_checkpoint"])
logger(msg=f"Loaded tokenizer from {kwargs['tokenizer_checkpoint']}")

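The move matters because `get_world_size_and_rank()` falls back to `(1, 0)` whenever the default process group is not yet initialized, so seeding before `init_from_env` would hand every rank the same seed. A minimal sketch of the corrected ordering, assuming the recipe keeps the shape shown in the diff above (helper import paths taken from the test file below):

```python
# Sketch of the corrected initialization order: process group first, then
# rank-dependent seeding.
from torchtune.utils.env import get_world_size_and_rank, init_from_env, seed


def recipe(kwargs):
    # 1. Initialize the distributed process group (and resolve the device).
    device = init_from_env(device_type=kwargs["device"])

    # 2. Only now does the rank mean anything; before initialization the
    #    helper falls back to (1, 0) for every process.
    world_size, rank = get_world_size_and_rank()

    # 3. Offset the seed by rank so each rank (and its dataloader workers)
    #    gets a different random stream.
    if kwargs["seed"] is not None:
        seed(kwargs["seed"] + rank)
```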
18 changes: 18 additions & 0 deletions tests/torchtune/utils/test_env.py
@@ -15,6 +15,7 @@
from torchtune.utils.device import _get_device_from_env
from torchtune.utils.env import (
_get_process_group_backend_from_device,
get_world_size_and_rank,
init_from_env,
seed,
)
@@ -64,6 +65,16 @@ def _test_worker_fn(init_pg_explicit: bool) -> torch.device:
)
return device

@staticmethod
def _test_world_size_with_cpu_device(expected_world_size: int) -> None:
torch.distributed.init_process_group(backend="gloo")
init_from_env(device_type="cpu")
world_size, _ = get_world_size_and_rank()
if world_size != expected_world_size:
Contributor:

nit: is there a pytest API to assert this to avoid this boilerplate?

Contributor Author:

I tried a regular assert, but the actual and expected values don't get printed when it fails inside this subprocess function. I will leave this as is for now, following the convention used in this test, and file a GitHub issue to clean it up later.

Contributor Author:

Created #210

raise AssertionError(
f"Expected different world size: received {world_size}, expected {expected_world_size}"
)

def _test_launch_worker(
self,
num_processes: int,
@@ -92,6 +103,13 @@ def test_get_process_group_backend_gpu(self) -> None:
pg_backend = _get_process_group_backend_from_device(device)
assert pg_backend == "nccl"

def test_world_size_with_cpu(self) -> None:
desired_world_size = 4
lc = get_pet_launch_config(desired_world_size)
launcher.elastic_launch(lc, entrypoint=self._test_world_size_with_cpu_device)(
desired_world_size
)

def test_seed_range(self) -> None:
"""
Verify that exceptions are raised on input values
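For context on the review thread above: the worker runs in subprocesses spawned by `elastic_launch`, so a failing bare `assert` there surfaces only as a non-zero exit code rather than a readable expected-vs-actual message, which is why the explicit `AssertionError` with a formatted string is kept. A hedged sketch of the launch pattern used by the new test (the import location of `get_pet_launch_config` is an assumption; everything else mirrors the diff):

```python
# Sketch of how the CPU world-size test drives the worker across processes.
# get_pet_launch_config is assumed to build a single-node launch config for N
# local workers, as the existing tests in this file already rely on.
import torch
import torch.distributed.launcher as launcher

from tests.test_utils import get_pet_launch_config  # assumed helper location
from torchtune.utils.env import get_world_size_and_rank, init_from_env


def _worker(expected_world_size: int) -> None:
    # Every spawned process joins the same gloo group ...
    torch.distributed.init_process_group(backend="gloo")
    init_from_env(device_type="cpu")
    world_size, _ = get_world_size_and_rank()
    # ... and raises with a formatted message, because a bare assert in a
    # subprocess would not show the received/expected values.
    if world_size != expected_world_size:
        raise AssertionError(
            f"received {world_size}, expected {expected_world_size}"
        )


if __name__ == "__main__":
    lc = get_pet_launch_config(4)  # 4 local worker processes
    launcher.elastic_launch(lc, entrypoint=_worker)(4)
```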
2 changes: 1 addition & 1 deletion torchtune/utils/env.py
@@ -119,7 +119,7 @@ def get_world_size_and_rank() -> Tuple[int, int]:
"""
if not torch.distributed.is_available() or not torch.distributed.is_initialized():
return 1, 0
- return torch.distributed.world_size(), torch.distributed.get_rank()
+ return torch.distributed.get_world_size(), torch.distributed.get_rank()


def seed(seed: int, debug_mode: Optional[Union[str, int]] = None) -> None:
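The one-line fix swaps `torch.distributed.world_size()`, which is not part of the `torch.distributed` API, for the correct `torch.distributed.get_world_size()`. A hedged usage sketch of the helper's two code paths, assuming the script is launched with `torchrun` on a CPU/gloo setup so the rendezvous environment variables are already set:

```python
# Usage sketch for get_world_size_and_rank, assuming launch via e.g.
#   torchrun --nproc_per_node=4 this_script.py
import torch
from torchtune.utils.env import get_world_size_and_rank


def main() -> None:
    # Before init_process_group the helper returns the single-process fallback.
    assert get_world_size_and_rank() == (1, 0)

    torch.distributed.init_process_group(backend="gloo")
    world_size, rank = get_world_size_and_rank()
    # With the corrected torch.distributed.get_world_size() call this reports
    # the launched topology, e.g. (4, 0), (4, 1), (4, 2), (4, 3).
    print(f"world_size={world_size}, rank={rank}")


if __name__ == "__main__":
    main()
```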