10 changes: 10 additions & 0 deletions torchtitan/experiments/compiler_toolkit/README.md
@@ -12,6 +12,11 @@ Joint Graph based Training Prototype:

## DeepSeek v3

**SimpleFSDP + EP**
```shell
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
```

**SimpleFSDP + TP + EP**
```shell
NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none
@@ -24,6 +29,11 @@ NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.tom

## llama3

**SimpleFSDP**
```shell
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=8
```

**SimpleFSDP + TP**
```shell
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4
15 changes: 10 additions & 5 deletions torchtitan/experiments/compiler_toolkit/common_utils.py
@@ -25,12 +25,17 @@ def disable_compile(job_config: JobConfig):


 def parallelize_inputs(world_mesh, args, kwargs):
-    def to_dtensor(tensor):
-        if isinstance(tensor, torch.Tensor):
-            return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
-        return tensor
+    if "tp" in world_mesh.mesh_dim_names:
 
-    dt_args = tree_map(to_dtensor, args)
+        def to_dtensor(tensor):
+            if isinstance(tensor, torch.Tensor):
+                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
+            return tensor
+
+        dt_args = tree_map(to_dtensor, args)
+    else:
+        # TODO: When there is no TP (SimpleFSDP only), it currently only supports plain tensor inputs
+        dt_args = args
 
     # TODO: When using flex_attention, BlockMask would show up in kwargs,
     # and it's unclear how to convert it to DTensor. If I use to_dtensor,
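To make the final state easier to read, here is a sketch of `parallelize_inputs` as it reads after this change, reconstructed from the hunk above. The import paths and the trailing kwargs/return handling are not shown in the diff, so treat those parts as assumptions rather than verbatim repository code.

```python
import torch
from torch.distributed.tensor import DTensor, Replicate
from torch.utils._pytree import tree_map  # import path assumed; not shown in the diff


def parallelize_inputs(world_mesh, args, kwargs):
    if "tp" in world_mesh.mesh_dim_names:
        # With a TP mesh dimension present, wrap each plain tensor input as a
        # DTensor replicated over the "tp" sub-mesh.
        def to_dtensor(tensor):
            if isinstance(tensor, torch.Tensor):
                return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()])
            return tensor

        dt_args = tree_map(to_dtensor, args)
    else:
        # SimpleFSDP-only: inputs stay plain tensors for now (see the TODO in the diff).
        dt_args = args

    # The kwargs handling (BlockMask / flex_attention TODO) and the return
    # statement are unchanged by this PR and elided here.
    ...
```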
7 changes: 4 additions & 3 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -15,7 +15,6 @@
     JointWithDescriptors,
 )
 from torch._guards import tracing, TracingContext
-from torch.distributed.tensor import DTensor
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger

@@ -93,8 +92,10 @@ def joint_graph_builder(
         joint_custom_pass: Optional custom pass to run on the joint graph
     """
     assert isinstance(model_args, tuple)
-    for arg in model_args:
-        assert isinstance(arg, DTensor)
+
+    # TODO: Enable this when we have full-DTensorize inputs support of SimpleFSDP
+    # for arg in model_args:
+    #     assert isinstance(arg, DTensor)
 
     # get joint graph
     (
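With the per-argument assertion commented out, `joint_graph_builder` now accepts either plain `torch.Tensor` inputs (SimpleFSDP-only runs) or `DTensor` inputs (runs with TP). If a softer guard is wanted until inputs are fully DTensorized, a hypothetical helper (not part of this PR, shown only as a sketch) could warn instead of asserting:

```python
import torch
from torch.distributed.tensor import DTensor

from torchtitan.tools.logging import logger  # same logger imported in graph_utils.py


def warn_on_plain_tensor_args(model_args: tuple) -> None:
    """Hypothetical helper: log a warning for non-DTensor inputs instead of asserting."""
    assert isinstance(model_args, tuple)
    for i, arg in enumerate(model_args):
        if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
            logger.warning(
                "model_args[%d] is a plain Tensor; it will not be treated as a DTensor", i
            )
```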