diff --git a/recipes/finetune_llm.py b/recipes/finetune_llm.py
index 285b476fa8..e0e6feeea6 100644
--- a/recipes/finetune_llm.py
+++ b/recipes/finetune_llm.py
@@ -54,12 +54,6 @@ def recipe(kwargs):
     # ---- Initialize components ---- #
     logger = get_logger()
 
-    # ---- Initialize seed ---- #
-    world_size, rank = get_world_size_and_rank()
-    if kwargs["seed"] is not None:
-        # Ensure that seed is different per rank (and its dataloader workers)
-        seed(kwargs["seed"] + rank)
-
     # ---- Initialize distributed process group ---- #
     device = init_from_env(device_type=kwargs["device"])
     # TODO: only supporting devices specified as "cpu", "cuda", or "cuda:n" currently
@@ -68,6 +62,14 @@ def recipe(kwargs):
         if kwargs["device"] in ("cpu", "cuda")
         else kwargs["device"].split(":")[0]
     )
+
+    # ---- Initialize seed ---- #
+    # Fetch world size and rank after distributed process group initialization
+    world_size, rank = get_world_size_and_rank()
+    if kwargs["seed"] is not None:
+        # Ensure that seed is different per rank (and its dataloader workers)
+        seed(kwargs["seed"] + rank)
+
     tokenizer = get_tokenizer(kwargs["tokenizer"], path=kwargs["tokenizer_checkpoint"])
     logger(msg=f"Loaded tokenizer from {kwargs['tokenizer_checkpoint']}")
 
diff --git a/tests/torchtune/utils/test_env.py b/tests/torchtune/utils/test_env.py
index d2e6e023ea..4170de369d 100644
--- a/tests/torchtune/utils/test_env.py
+++ b/tests/torchtune/utils/test_env.py
@@ -15,6 +15,7 @@
 from torchtune.utils.device import _get_device_from_env
 from torchtune.utils.env import (
     _get_process_group_backend_from_device,
+    get_world_size_and_rank,
     init_from_env,
     seed,
 )
@@ -64,6 +65,16 @@ def _test_worker_fn(init_pg_explicit: bool) -> torch.device:
             )
         return device
 
+    @staticmethod
+    def _test_world_size_with_cpu_device(expected_world_size: int) -> None:
+        torch.distributed.init_process_group(backend="gloo")
+        init_from_env(device_type="cpu")
+        world_size, _ = get_world_size_and_rank()
+        if world_size != expected_world_size:
+            raise AssertionError(
+                f"Unexpected world size: received {world_size}, expected {expected_world_size}"
+            )
+
     def _test_launch_worker(
         self,
         num_processes: int,
@@ -92,6 +103,13 @@ def test_get_process_group_backend_gpu(self) -> None:
         pg_backend = _get_process_group_backend_from_device(device)
         assert pg_backend == "nccl"
 
+    def test_world_size_with_cpu(self) -> None:
+        desired_world_size = 4
+        lc = get_pet_launch_config(desired_world_size)
+        launcher.elastic_launch(lc, entrypoint=self._test_world_size_with_cpu_device)(
+            desired_world_size
+        )
+
     def test_seed_range(self) -> None:
         """
         Verify that exceptions are raised on input values
diff --git a/torchtune/utils/env.py b/torchtune/utils/env.py
index d567e21afe..525527039f 100644
--- a/torchtune/utils/env.py
+++ b/torchtune/utils/env.py
@@ -119,7 +119,7 @@ def get_world_size_and_rank() -> Tuple[int, int]:
     """
     if not torch.distributed.is_available() or not torch.distributed.is_initialized():
         return 1, 0
-    return torch.distributed.world_size(), torch.distributed.get_rank()
+    return torch.distributed.get_world_size(), torch.distributed.get_rank()
 
 
 def seed(seed: int, debug_mode: Optional[Union[str, int]] = None) -> None:
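
A standalone sketch (not part of the patch) of the behavior the reordering in `finetune_llm.py` depends on: until `torch.distributed.init_process_group` has run, `get_world_size_and_rank()` falls back to `(1, 0)`, so seeding before process-group initialization would give every rank the identical seed. The mirrored helper and the `set_seed_per_rank` wrapper below are illustrative names, not torchtune APIs.

```python
from typing import Tuple

import torch
import torch.distributed as dist


def get_world_size_and_rank() -> Tuple[int, int]:
    # Mirrors torchtune.utils.env.get_world_size_and_rank after the fix:
    # note dist.get_world_size() -- torch.distributed has no world_size().
    if not dist.is_available() or not dist.is_initialized():
        return 1, 0
    return dist.get_world_size(), dist.get_rank()


def set_seed_per_rank(base_seed: int) -> int:
    # Offset the base seed by rank so each rank (and its dataloader
    # workers) draws a distinct random stream. If this runs before
    # init_process_group, rank falls back to 0 on every process and all
    # ranks share one seed -- the ordering bug the patch fixes.
    _, rank = get_world_size_and_rank()
    torch.manual_seed(base_seed + rank)
    return base_seed + rank


if __name__ == "__main__":
    # Single process: prints 42. Launched via torchrun with an initialized
    # process group, each rank would instead seed with 42 + its rank.
    print(set_seed_per_rank(42))
```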