From 090cb3d5590103e4be2e69383ede9d2d408010f8 Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Fri, 12 Jan 2024 21:02:54 -0800
Subject: [PATCH 1/7] Fetch world size, rank after distributed setup

---
 recipes/finetune_llm.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/recipes/finetune_llm.py b/recipes/finetune_llm.py
index 285b476fa8..6cf5197505 100644
--- a/recipes/finetune_llm.py
+++ b/recipes/finetune_llm.py
@@ -54,12 +54,6 @@ def recipe(kwargs):
     # ---- Initialize components ---- #
     logger = get_logger()
 
-    # ---- Initialize seed ---- #
-    world_size, rank = get_world_size_and_rank()
-    if kwargs["seed"] is not None:
-        # Ensure that seed is different per rank (and its dataloader workers)
-        seed(kwargs["seed"] + rank)
-
     # ---- Initialize distributed process group ---- #
     device = init_from_env(device_type=kwargs["device"])
     # TODO: only supporting devices specified as "cpu", "cuda", or "cuda:n" currently
@@ -68,6 +62,13 @@ def recipe(kwargs):
         if kwargs["device"] in ("cpu", "cuda")
         else kwargs["device"].split(":")[0]
     )
+
+    # ---- Initialize seed ---- #
+    world_size, rank = get_world_size_and_rank()
+    if kwargs["seed"] is not None:
+        # Ensure that seed is different per rank (and its dataloader workers)
+        seed(kwargs["seed"] + rank)
+
     tokenizer = get_tokenizer(kwargs["tokenizer"], path=kwargs["tokenizer_checkpoint"])
     logger(msg=f"Loaded tokenizer from {kwargs['tokenizer_checkpoint']}")
 

From ee9f3bc9a8b721438bdb53dba2f888825ec5e867 Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Sat, 13 Jan 2024 20:19:00 -0800
Subject: [PATCH 2/7] Address PR comment

---
 torchtune/utils/env.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtune/utils/env.py b/torchtune/utils/env.py
index d567e21afe..525527039f 100644
--- a/torchtune/utils/env.py
+++ b/torchtune/utils/env.py
@@ -119,7 +119,7 @@ def get_world_size_and_rank() -> Tuple[int, int]:
     """
     if not torch.distributed.is_available() or not torch.distributed.is_initialized():
         return 1, 0
-    return torch.distributed.world_size(), torch.distributed.get_rank()
+    return torch.distributed.get_world_size(), torch.distributed.get_rank()
 
 
 def seed(seed: int, debug_mode: Optional[Union[str, int]] = None) -> None:

From 74217f915d52aaec01a307fe0e4aad69819bba9f Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Wed, 17 Jan 2024 04:37:54 -0800
Subject: [PATCH 3/7] test changes

---
 tests/torchtune/utils/test_env.py | 4 ++++
 torchtune/utils/device.py         | 1 +
 2 files changed, 5 insertions(+)

diff --git a/tests/torchtune/utils/test_env.py b/tests/torchtune/utils/test_env.py
index d2e6e023ea..75b97587d9 100644
--- a/tests/torchtune/utils/test_env.py
+++ b/tests/torchtune/utils/test_env.py
@@ -15,6 +15,7 @@
 from torchtune.utils.device import _get_device_from_env
 from torchtune.utils.env import (
     _get_process_group_backend_from_device,
+    get_world_size_and_rank,
     init_from_env,
     seed,
 )
@@ -62,6 +63,9 @@ def _test_worker_fn(init_pg_explicit: bool) -> torch.device:
             raise AssertionError(
                 f"Expected different process group backend: received {pg_backend}, expected {expected_pg_backend}"
             )
+        world_size, rank = get_world_size_and_rank()
+        if world_size != 2:
+            raise AssertionError(f"Expected world size of 2, received {world_size}")
         return device
 
     def _test_launch_worker(
diff --git a/torchtune/utils/device.py b/torchtune/utils/device.py
index 9f8d0b9bcc..a9adf4ce3a 100644
--- a/torchtune/utils/device.py
+++ b/torchtune/utils/device.py
@@ -23,6 +23,7 @@ def _get_device_from_env() -> torch.device:
     Returns:
         device
     """
+    return torch.device("cpu")
     if torch.cuda.is_available():
         local_rank = int(os.environ.get("LOCAL_RANK", "0"))
         if local_rank >= torch.cuda.device_count():

From 35403e05803227bd680e8c641bacfa5f0bb5d1bd Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Wed, 17 Jan 2024 07:51:34 -0800
Subject: [PATCH 4/7] Add unit test

---
 tests/torchtune/utils/test_env.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/torchtune/utils/test_env.py b/tests/torchtune/utils/test_env.py
index 75b97587d9..c6d6c1a5a3 100644
--- a/tests/torchtune/utils/test_env.py
+++ b/tests/torchtune/utils/test_env.py
@@ -68,6 +68,16 @@ def _test_worker_fn(init_pg_explicit: bool) -> torch.device:
             raise AssertionError(f"Expected world size of 2, received {world_size}")
         return device
 
+    @staticmethod
+    def _test_world_size_with_cpu_device(expected_world_size: int) -> None:
+        torch.distributed.init_process_group(backend="gloo")
+        init_from_env(device_type="cpu")
+        world_size, _ = get_world_size_and_rank()
+        if world_size != expected_world_size:
+            raise AssertionError(
+                f"Expected different world size: received {world_size}, expected {expected_world_size}"
+            )
+
     def _test_launch_worker(
         self,
         num_processes: int,
@@ -96,6 +106,13 @@ def test_get_process_group_backend_gpu(self) -> None:
         pg_backend = _get_process_group_backend_from_device(device)
         assert pg_backend == "nccl"
 
+    def test_world_size_with_cpu(self) -> None:
+        desired_world_size = 4
+        lc = get_pet_launch_config(desired_world_size)
+        launcher.elastic_launch(lc, entrypoint=self._test_world_size_with_cpu_device)(
+            desired_world_size
+        )
+
     def test_seed_range(self) -> None:
         """
         Verify that exceptions are raised on input values

From fc99effca450c8c8f8542c1fe612e05d19ad44c8 Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Wed, 17 Jan 2024 07:52:52 -0800
Subject: [PATCH 5/7] Revert unintended changes

---
 tests/torchtune/utils/test_env.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/torchtune/utils/test_env.py b/tests/torchtune/utils/test_env.py
index c6d6c1a5a3..4170de369d 100644
--- a/tests/torchtune/utils/test_env.py
+++ b/tests/torchtune/utils/test_env.py
@@ -63,9 +63,6 @@ def _test_worker_fn(init_pg_explicit: bool) -> torch.device:
             raise AssertionError(
                 f"Expected different process group backend: received {pg_backend}, expected {expected_pg_backend}"
             )
-        world_size, rank = get_world_size_and_rank()
-        if world_size != 2:
-            raise AssertionError(f"Expected world size of 2, received {world_size}")
         return device
 
     @staticmethod

From fa19d0fbec0af97bb4a199bf3ff1c13df6980dea Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Wed, 17 Jan 2024 07:54:22 -0800
Subject: [PATCH 6/7] revert

---
 torchtune/utils/device.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchtune/utils/device.py b/torchtune/utils/device.py
index a9adf4ce3a..9f8d0b9bcc 100644
--- a/torchtune/utils/device.py
+++ b/torchtune/utils/device.py
@@ -23,7 +23,6 @@ def _get_device_from_env() -> torch.device:
     Returns:
         device
     """
-    return torch.device("cpu")
     if torch.cuda.is_available():
         local_rank = int(os.environ.get("LOCAL_RANK", "0"))
         if local_rank >= torch.cuda.device_count():

From 4d35009b4708b30da73ae0cec41a987f2a070cad Mon Sep 17 00:00:00 2001
From: Gokul Gunasekaran
Date: Wed, 17 Jan 2024 08:01:44 -0800
Subject: [PATCH 7/7] Add comment

---
 recipes/finetune_llm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/recipes/finetune_llm.py b/recipes/finetune_llm.py
index 6cf5197505..e0e6feeea6 100644
--- a/recipes/finetune_llm.py
+++ b/recipes/finetune_llm.py
@@ -64,6 +64,7 @@ def recipe(kwargs):
     )
 
     # ---- Initialize seed ---- #
+    # Fetch world size and rank after distributed process group initialization
    world_size, rank = get_world_size_and_rank()
     if kwargs["seed"] is not None:
         # Ensure that seed is different per rank (and its dataloader workers)
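
As context for the series above, a minimal standalone sketch (not part of the patches) of why the query order matters: the patched helper falls back to (1, 0) until the process group is initialized, so seeding before init_process_group would give every rank the same seed. The helper below only mirrors get_world_size_and_rank; names such as main and base_seed, the gloo backend, and the torchrun launch line are illustrative assumptions rather than anything prescribed by the patches.

import torch
import torch.distributed as dist


def get_world_size_and_rank():
    # Mirrors the patched utility: outside an initialized process group,
    # fall back to the single-process defaults.
    if not dist.is_available() or not dist.is_initialized():
        return 1, 0
    return dist.get_world_size(), dist.get_rank()


def main(base_seed: int = 1234) -> None:
    # Queried here, before init_process_group, the helper would return (1, 0)
    # on every worker and all ranks would derive the same seed.
    dist.init_process_group(backend="gloo")
    # Queried after init, each worker sees the true world size and its own rank.
    world_size, rank = get_world_size_and_rank()
    # Offset the seed per rank (and, by extension, its dataloader workers).
    torch.manual_seed(base_seed + rank)
    print(f"rank {rank}/{world_size} seeded with {base_seed + rank}")
    dist.destroy_process_group()


if __name__ == "__main__":
    # Example launch (hypothetical script name):
    #   torchrun --nproc_per_node=4 seed_order_sketch.py
    main()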