[Local Tensor] Replace dry_run.py with fake mode implementation #2057
base: gh/fegin/44/base
This file was deleted (its contents are not shown; the diffs below are from the modified files).
```diff
@@ -258,12 +258,56 @@ def maybe_enable_amp(
     )
 
 
+def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"):
+    """Initialize fake backend
+
+    Args:
+        world_size: The number of GPUs to simulate
+        comm_mode: Communication mode ("fake_backend" or "local_tensor")
+
+    Returns:
+        The world size
+    """
+    torch.distributed.init_process_group(
+        "fake",
+        rank=0,
+        world_size=world_size,
+    )
+
+    # If local_tensor mode is enabled, initialize LocalTensorMode context
+    if comm_mode == "local_tensor":
+        from torch.distributed import _local_tensor
+
+        lm = _local_tensor.LocalTensorMode(world_size)
+        lm.__enter__()
+
+        # TODO: remove this once the root cause is figured out
+        torch.manual_seed(42)
+
+    return world_size
+
+
 def init_distributed(
     comm_config: CommConfig,
     enable_cpu_backend: bool = False,
     base_folder: str = "",
     ranks: list[int] | None = None,
-):
+) -> int:
+    if comm_config.comm_mode in ("fake_backend", "local_tensor"):
+        ngpu_str = os.environ.get("NGPU")
+        if ngpu_str is None:
+            raise ValueError(
+                f"NGPU environment variable must be set when using comm_mode={comm_config.comm_mode}"
+            )
+        try:
+            world_size = int(ngpu_str)
+        except ValueError as e:
+            raise ValueError(
+                f"NGPU environment variable must be a valid integer, got: {ngpu_str}"
+            ) from e
+        init_fake_mode(world_size, comm_config.comm_mode)
+        return world_size
+
+
 def _warn_overwrite_env(env, val):
     if env in os.environ:
         logger.warning(
```

> **Review comment** on `return world_size` (marked resolved by fegin)
> Contributor: no need to […]
```diff
@@ -309,6 +353,8 @@ def _get_distributed_backend(enable_cpu_backend):
         _ranks=ranks if ranks is not None else [],
     )
 
+    return torch.distributed.get_world_size()
+
 
 def set_pg_timeouts(timeout, world_mesh):
     """
```
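For context on what the new path above enables, the sketch below exercises the fake backend as a single plain Python process. It mirrors what `init_fake_mode` does rather than importing it (the module path is not shown in this diff), and the statement that collectives complete without real communication is an assumption about PyTorch's fake process group, not something stated in the PR.

```python
# Standalone sketch of the fake-mode path added above, replicating
# init_fake_mode() inline. Assumes a PyTorch build where the "fake" backend is
# registered out of the box (the same assumption the PR's code makes).
import os

import torch
import torch.distributed as dist

os.environ.setdefault("NGPU", "8")  # what init_distributed() reads in fake/local_tensor mode
world_size = int(os.environ["NGPU"])

# A single process pretends to be rank 0 of a simulated world_size-rank job.
dist.init_process_group("fake", rank=0, world_size=world_size)

assert dist.get_world_size() == world_size  # the simulated topology is visible

# Collectives are expected to complete immediately with no inter-process
# traffic, which is what makes single-host validation of parallelism configs
# possible.
t = torch.ones(4)
dist.all_reduce(t)

dist.destroy_process_group()
```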
```diff
@@ -360,15 +360,13 @@ def __init__(self, job_config: JobConfig):
 
     def init_distributed(self) -> ParallelDims:
         job_config = self.job_config
-        dist_utils.init_distributed(
+        world_size = dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
         )
 
-        world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
 
         return ParallelDims(
             dp_shard=parallelism_config.data_parallel_shard_degree,
             dp_replicate=parallelism_config.data_parallel_replicate_degree,
```
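The trainer-side change above replaces `int(os.environ["WORLD_SIZE"])` with the value returned by `dist_utils.init_distributed`, presumably because in fake/local_tensor mode the job runs as an ordinary Python process without a distributed launcher, so `WORLD_SIZE` may be missing or wrong. A minimal, hypothetical illustration of the failure mode this avoids:

```python
# Hypothetical illustration (not from the PR): without torchrun, WORLD_SIZE is
# never populated, so the old code path would raise KeyError, while the new
# path takes the value computed by init_distributed() from NGPU.
import os

os.environ.pop("WORLD_SIZE", None)  # simulate a plain `python ...` launch
os.environ["NGPU"] = "8"            # what fake mode actually consumes

try:
    world_size = int(os.environ["WORLD_SIZE"])  # old behavior: fails here
except KeyError:
    world_size = int(os.environ["NGPU"])        # what init_distributed() now returns
print(world_size)  # 8
```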
```diff
@@ -718,6 +716,13 @@ def main(trainer_class: type[Trainer]) -> None:
     try:
         trainer = trainer_class(config)
 
+        # TODO(local_tensor): Remove this special case once LocalTensor supports
+        # init_weights() and foreach_allgather. In local tensor mode, skip
+        # training/checkpointing as the model is not fully initialized
+        if config.comm.comm_mode == "local_tensor":
+            logger.info("Local tensor mode enabled - skipping training execution")
+            return
+
         if config.checkpoint.create_seed_checkpoint:
             assert (
                 int(os.environ["WORLD_SIZE"]) == 1
```

> **Review comment** on the `TODO(local_tensor)` note
> Contributor: similarly, can we remove this now?
> Author: There are still some gaps. I updated the comment.
> **Review comment (nit):** `comm.comm_mode` sounds redundant, I would propose […]. WDYT?
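Closing note on the LocalTensorMode lifecycle: `init_fake_mode` calls `__enter__()` without a matching exit, so the mode stays active for the remainder of the process. The sketch below contrasts that with a scoped variant; it assumes only what this diff shows about the private, experimental `torch.distributed._local_tensor` API (a constructor taking `world_size` and the context-manager protocol).

```python
# Sketch of two lifecycles for LocalTensorMode. Private, experimental API:
# only the constructor and context-manager protocol visible in this PR are
# assumed here.
from torch.distributed import _local_tensor


def enter_for_process_lifetime(world_size: int) -> None:
    # What init_fake_mode() does: enter and never exit, so everything that
    # runs afterwards stays under the simulated multi-rank mode.
    lm = _local_tensor.LocalTensorMode(world_size)
    lm.__enter__()


def run_scoped(world_size: int, fn):
    # A bounded alternative: the mode is torn down when fn() returns.
    with _local_tensor.LocalTensorMode(world_size):
        return fn()
```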