diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md
index a75b3a17b2..49e0590cbd 100644
--- a/torchtitan/experiments/compiler_toolkit/README.md
+++ b/torchtitan/experiments/compiler_toolkit/README.md
@@ -12,6 +12,11 @@ Joint Graph based Training Prototype:
 
 ## DeepSeek v3
 
+**SimpleFSDP + EP**
+```shell
+NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
+```
+
 **SimpleFSDP + TP + EP**
 ```shell
 NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
@@ -24,6 +29,11 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom
 
 ## llama3
 
+**SimpleFSDP**
+```shell
+NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=8
+```
+
 **SimpleFSDP + TP**
 ```shell
 NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py
index 965e027bdb..fd1f9622e2 100644
--- a/torchtitan/experiments/compiler_toolkit/common_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -25,12 +25,17 @@ def disable_compile(job_config: JobConfig):
 
 
 def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
+    if "tp" in world_mesh.mesh_dim_names:
 
-    dt_args = tree_map(to_dtensor, args)
+        def to_dtensor(tensor):
+            if isinstance(tensor, torch.Tensor):
+                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
+            return tensor
+
+        dt_args = tree_map(to_dtensor, args)
+    else:
+        # TODO: When there is no TP (SimpleFSDP only), it currently only supports plain tensor inputs
+        dt_args = args
 
     # TODO: When using flex_attention, BlockMask would show up in kwargs,
     # and it's unclear how to convert it to DTensor. If I use to_dtensor,
diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index 4ff6c8187b..a246ac51df 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -15,7 +15,6 @@
     JointWithDescriptors,
 )
 from torch._guards import tracing, TracingContext
-from torch.distributed.tensor import DTensor
 
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger
@@ -93,8 +92,10 @@ def joint_graph_builder(
         joint_custom_pass: Optional custom pass to run on the joint graph
     """
     assert isinstance(model_args, tuple)
-    for arg in model_args:
-        assert isinstance(arg, DTensor)
+
+    # TODO: Enable this when we have full-DTensorize inputs support of SimpleFSDP
+    # for arg in model_args:
+    #     assert isinstance(arg, DTensor)
 
     # get joint graph
     (
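Note: for readability, here is a minimal sketch (plain Python, outside the diff) of the input-handling behavior `parallelize_inputs` has after this patch: args are wrapped as replicated DTensors over the `tp` submesh only when a TP mesh dimension exists, and pass through as plain tensors in the SimpleFSDP-only case. The standalone function name `parallelize_inputs_sketch` and the explicit return value are illustrative additions; the branching logic mirrors the hunk in `common_utils.py` above.

```python
import torch
from torch.distributed.tensor import DTensor, Replicate
from torch.utils._pytree import tree_map


def parallelize_inputs_sketch(world_mesh, args, kwargs):
    # With a "tp" mesh dimension, wrap plain tensor args as replicated DTensors
    # over the TP submesh; non-tensor leaves are passed through unchanged.
    if "tp" in world_mesh.mesh_dim_names:

        def to_dtensor(tensor):
            if isinstance(tensor, torch.Tensor):
                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
            return tensor

        dt_args = tree_map(to_dtensor, args)
    else:
        # SimpleFSDP-only case: only plain tensor inputs are supported for now.
        dt_args = args

    # kwargs (e.g. a BlockMask when flex_attention is used) are left untouched,
    # since it is unclear how to convert a BlockMask to DTensor.
    return dt_args, kwargs
```

The sketch returns `(dt_args, kwargs)` only to make it self-contained; how the real helper hands its results back is not shown in the hunk.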