diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 7fc5098800..8eec355b69 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -340,6 +340,11 @@ def build_optimizers_with_moe_load_balancing( ft_manager=ft_manager, ) + def should_manual_allreduce(tokens_per_expert_by_layer): + return not isinstance( + tokens_per_expert_by_layer, torch.distributed.tensor.DTensor + ) + def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: for transformer_block in model_part.layers.values(): @@ -380,11 +385,14 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) if dp_cp_mesh is not None: - # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() - torch.distributed.all_reduce( - tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM - ) + if should_manual_allreduce(tokens_per_expert_by_layer): + # Perform single all-reduce to get global statistics across all processes + pg = dp_cp_mesh.get_group() + torch.distributed.all_reduce( + tokens_per_expert_by_layer, + group=pg, + op=torch.distributed.ReduceOp.SUM, + ) moe_layer_idx = 0 with torch.no_grad(): diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 95588d2c3b..5918d2918d 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -865,6 +865,11 @@ class Experimental: needs to ensure that the path can be imported. """ + # "aten" (default), "inductor", "none" + comms_bucket_reorder_strategy: str = "aten" + + autop_force_bf16: bool = False + @dataclass class Validation: diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index db3a44a824..68c717717d 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -13,5 +13,7 @@ "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", "transformers_backend", + "auto_parallel.llama3", + "auto_parallel.deepseek_v3", ] ) diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 0000000000..7e112329b9 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,11 @@ +## Auto Parallel + +requires installing git@github.com:pytorch-labs/autoparallel.git + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` + +Use simplefsdp's autobucketing pass: + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable` + +(or llama3-8b.toml) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py new file mode 100644 index 0000000000..7aa7f98f9e --- /dev/null +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
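+
+# The train spec below reuses the stock DeepSeek-V3 model, tokenizer,
+# dataloader, and MoE-aware optimizer builder from torchtitan; only
+# parallelize_fn is swapped for the AutoParallel implementation in
+# parallelize_deepseekv3.py.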
+ +import copy + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model +from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs +from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( + DeepSeekV3StateDictAdapter, +) +from torchtitan.protocols.train_spec import TrainSpec + +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +def get_train_spec() -> TrainSpec: + model_args = copy.deepcopy(deepseekv3_args) + + default_args = DeepSeekV3ModelArgs() + for config, args in model_args.items(): + if "flex_attn" in config: + continue + + use_flex_attn = (default_args.use_flex_attn,) + attn_mask_type = (default_args.attn_mask_type,) + + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=model_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py new file mode 100644 index 0000000000..89092dec64 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/deepseek_v3/parallelize_deepseekv3.py @@ -0,0 +1,432 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time +import types +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.tensor.placement_types import Replicate, Shard +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.models.moe.moe import _run_experts_grouped_mm + +from torchtitan.tools.logging import logger + + +def create_functional_router_forward( + self: nn.Module, +) -> Callable: # TokenChoiceTopKRouter + def functional_router_forward( + x: torch.Tensor, gate_weight: torch.nn.Parameter, expert_bias: torch.Tensor + ): + # scores shape (bs*slen, num_experts) + scores = F.linear(x, gate_weight) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.score_func == "sigmoid": + scores = torch.sigmoid(scores.to(torch.float32)) + elif self.score_func == "softmax": + scores = F.softmax(scores.to(torch.float32), dim=1) + else: + raise NotImplementedError(f"Unknown score function {self.score_func}") + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. 
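+        # Illustrative example (numbers not from any config): with top_k=1,
+        # scores=[0.2, 0.7, 0.1] and expert_bias=[1.0, 0.0, 0.0], the biased
+        # top-k below selects expert 0, while the gathered gating value stays
+        # at the unbiased scores[0] == 0.2.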
+ if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) + + # debug override: balanced round-robin routing + if self._debug_force_load_balance: + ( + selected_experts_indices, + top_scores, + ) = self._debug_force_load_balance_routing(scores) + + if self.route_norm: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + top_scores = top_scores * self.route_scale + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + return top_scores, selected_experts_indices, num_tokens_per_expert + + return functional_router_forward + + +def _moe_forward( + x: torch.Tensor, + router_gate_weight: torch.nn.Parameter, + expert_bias: Optional[torch.Tensor], + experts_w1: torch.Tensor, + experts_w3: torch.Tensor, + experts_w2: torch.Tensor, + shared_w1_weight: torch.Tensor, + shared_w3_weight: torch.Tensor, + shared_w2_weight: torch.Tensor, + functional_router_forward: Callable, + reorderer: nn.Module, # TokenReorderer +): + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + selected_experts_indices, + num_tokens_per_expert, + ) = functional_router_forward(x, router_gate_weight, expert_bias) + num_tokens_per_expert_update = num_tokens_per_expert + + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. 
+ ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + token_indices_experts_sorted = token_indices_experts_sorted.reshape(-1, 1).expand( + -1, dim + ) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + + # DSv3 score_before_experts is always False + # if score_before_experts: + # routed_input = ( + # routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + # ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + # routed_output = experts(routed_input, num_tokens_per_expert) + routed_output = _run_experts_grouped_mm( + experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert + ) + + # shared expert + # Note: we execute the shared expert before scoring the output of the routed expert + # to "implicitly" overlap the shared expert compute with token combine communication + # if shared_experts is not None: + # out = shared_experts(x) + _h1 = F.linear(x, shared_w1_weight) + _h3 = F.linear(x, shared_w3_weight) + out = F.linear(F.silu(_h1) * _h3, shared_w2_weight) + # else: + # out = torch.zeros_like(x) + + # DSv3 score_before_experts is False + # if not score_before_experts: + routed_output = ( + routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + out = out.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output) + out = out.reshape(bs, slen, dim) + return out, num_tokens_per_expert_update + + +def moe_forward(self, x: torch.Tensor) -> torch.Tensor: + functional_router_forward = create_functional_router_forward(self.router) + out, num_tokens_per_expert = _moe_forward( + x, + self.router.gate.weight, + self.expert_bias, + self.experts.w1, + self.experts.w3, + self.experts.w2, + self.shared_experts.w1.weight, + self.shared_experts.w3.weight, + self.shared_experts.w2.weight, + functional_router_forward, + self.reorderer, + ) + # HOPs don't support buffer mutations, keep this outside + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + return out + + +def monkey_patch_checks(moe): + # causes data-dependent issue, hardcoded into monkey patch + assert not moe.score_before_experts + assert moe.router.gate.bias is None + assert moe.experts.use_grouped_mm + assert moe.shared_experts is not None + assert moe.shared_experts.w1.bias is None + assert moe.shared_experts.w2.bias is None + assert moe.shared_experts.w3.bias is None + assert not list(moe.reorderer.parameters()) + assert not list(moe.reorderer.buffers()) + + +def monkey_patch_local_map_moe(model, world_mesh): + """ + TODO: fix HOPs not restoring the original signature. 
+ TODO: fix tracing with local shapes so that we can use Shard placements + + Current HOP signature we get: + """ + from torch.distributed._tensor.experimental import local_map + + # from torchtitan.models.moe import moe + global _moe_forward + _moe_forward = local_map( + _moe_forward, + out_placements=( + (Replicate(),), # out: torch.Tensor + (Replicate(),), # num_tokens_per_expert_update: torch.Tensor + ), + in_placements=( + (Replicate(),), # x: torch.Tensor, + (Replicate(),), # router_gate_weight: torch.nn.Parameter, + (Replicate(),), # expert_bias: Optional[torch.Tensor], + (Replicate(),), # experts_w1: torch.Tensor, + (Replicate(),), # experts_w3: torch.Tensor, + (Replicate(),), # experts_w2: torch.Tensor, + (Replicate(),), # shared_w1: torch.Tensor, + (Replicate(),), # shared_w3: torch.Tensor, + (Replicate(),), # shared_w2: torch.Tensor, + None, # functional_router_forward: Callable, + None, # reorderer: TokenReorderer, + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=world_mesh, + ) + + for block in model.layers.children(): + if not block.moe_enabled: + continue + block.moe.forward = types.MethodType(moe_forward, block.moe) + monkey_patch_checks(block.moe) + + +# TODO: Autoparallel should transparently wrap the original nn.Module +# but I don't know how to do that. +def set_torchtitan_fields(orig, new): + assert isinstance(new.layers, torch.nn.ModuleDict) + for block in new.layers.values(): + block.moe_enabled = hasattr(block, "moe") + + +# Run workflow with: +# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel +def parallelize_deepseekv3( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply Autoparallel to the model + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + world_mesh = parallel_dims.world_mesh + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. 
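+            # For example (illustrative numbers): local_batch_size=8,
+            # dp_replicate=1 and dp_shard=4 give a fallback global batch size
+            # of 8 * 4 = 32, i.e. exactly one gradient accumulation step per
+            # optimizer step.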
+ dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + model.model_args.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # apply local_map to MoE + monkey_patch_local_map_moe(model, world_mesh) + + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + # lambda bucket_idx: 500 / parallel_dims.tp + # ) + # torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + # lambda bucket_idx: 1000 / parallel_dims.tp + # ) + + # if job_config.experimental.autop_force_bf16: + # logger.info("Forcing bf16 on model") + # model = model.bfloat16() + + # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + mp_policy = None + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + set_torchtitan_fields(model, parallel_mod) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. 
+ # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + _preserve_moe_attributes(model, parallel_mod) + + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, "layers"): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = ( + model.layers.children() if hasattr(model.layers, "children") else [] + ) + + for block in blocks: + if ( + hasattr(block, "moe_enabled") + and block.moe_enabled + and hasattr(block, "moe") + ): + moe_modules.append(block.moe) + elif hasattr(block, "moe"): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, "moe_enabled"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff + + # Copy load_balance_coeff + if hasattr(orig_moe, "load_balance_coeff"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff diff --git a/torchtitan/experiments/auto_parallel/llama3/__init__.py b/torchtitan/experiments/auto_parallel/llama3/__init__.py new file mode 100644 index 0000000000..f9e61ddd7e --- /dev/null +++ b/torchtitan/experiments/auto_parallel/llama3/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
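+
+# Apart from parallelize_fn, which points at the AutoParallel implementation
+# in parallelize_llama.py, this train spec reuses the standard llama3 model
+# and training components.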
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.llama3 import llama3_args, Transformer +from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter +from torchtitan.protocols.train_spec import TrainSpec + +from .parallelize_llama import parallelize_llama + + +# CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml +# ./run_train.sh --model.name auto_parallel.llama3 --parallelism.tensor_parallel_degree 4 + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=Transformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) diff --git a/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py new file mode 100644 index 0000000000..1d2bee4351 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/llama3/parallelize_llama.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + world_mesh = parallel_dims.world_mesh + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. 
+ dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + lambda bucket_idx: 500 / parallel_dims.tp + ) + torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + lambda bucket_idx: 1000 / parallel_dims.tp + ) + + # bail out + # model = model_fn() + # return model + if job_config.experimental.autop_force_bf16: + logger.info("Forcing bf16 on model") + model = model.bfloat16() + + param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. 
+ # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + return parallel_mod diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index f398dba9b5..7edfe66979 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -16,6 +16,9 @@ # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 +PERFETTO_UI_ROOT_URL = ( + "https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html" +) @contextlib.contextmanager @@ -47,12 +50,25 @@ def trace_handler(prof): logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() - output_file = os.path.join(curr_trace_dir, f"rank{rank}_trace.json") + prof.export_chrome_trace(output_file) - logger.info( - f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" - ) + log_str = f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" + # not directly landable on upstream titan, + # but conveniently prints the internal url for perfetto on manifold for mast jobs + manifold_mount_prefix = "/mnt/mffuse/" + if output_file.find(manifold_mount_prefix) == 0: + manifold_path = os.path.join( + "torchtrain_datasets/tree", + output_file.split(manifold_mount_prefix)[1], + ) + perfetto_url = ( + PERFETTO_UI_ROOT_URL + + "#!/?url=https://interncache-all.fbcdn.net/manifold/" + + manifold_path + ) + log_str += f": {perfetto_url}" + logger.info(log_str) logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
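
A note on the `_return_as_dtensor_for_loss_parallel` hooks above: they re-wrap the parallelized module's plain-tensor output as a DTensor sharded on the vocab dimension of the 1-D `tp` mesh, which is what PyTorch's loss parallel expects. The sketch below shows one way a loss computed from that output could look; it is not part of this patch, `sharded_loss` and `tp_mesh` are illustrative names, and an initialized process group with a 1-D `tp` device mesh is assumed.

```python
import torch
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import loss_parallel
from torch.distributed.tensor.placement_types import Shard


def sharded_loss(local_logits: torch.Tensor, labels: torch.Tensor, tp_mesh) -> torch.Tensor:
    # local_logits: this rank's (bs, slen, vocab // tp) shard of the logits;
    # labels: replicated (bs, slen) token ids.
    logits = DTensor.from_local(local_logits, tp_mesh, (Shard(2),))
    with loss_parallel():
        # cross_entropy on a vocab-sharded DTensor computes the loss without
        # gathering the full vocab dimension on every rank.
        return F.cross_entropy(logits.flatten(0, 1).float(), labels.flatten(0, 1))
```

Because the hook hands back a correctly placed DTensor, a flatten-and-cross-entropy loss of this shape can run under `loss_parallel()` unchanged, which is why the workaround does not require moving the loss inside the model.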