
Commit 3c84ce0

Refactor PP splitting (#1416)
This refactors the PP splitting logic to consolidate around specifying the module FQNs for each model chunk. For example:

```
[
    ['tok_embeddings', 'layers.0'],  # stage0
    ['layers.1', 'layers.2'],        # stage1
    ['layers.3', 'layers.4'],        # stage2
    ...                              # and so on
]
```

This is better because it applies generally to all models, and the code can be reused for cases that don't explicitly require pipelined execution (for example, streaming DiLoCo needs to communicate model chunks).

Changes:
- Refactor deepseekv3 and llama to share the same pipeline util functions
- Add the `module_fqns_per_model_part` config option, deprecate `pipeline_parallel_split_points`

TODO (follow-up PRs):
- `pipeline_module_split` will be upstreamed to PyTorch as a `torch.distributed.pipelining` utility, since it contains no model-specific code.
- Additional changes are needed to make this work for torchft streaming DiLoCo, including updating the training loop to not execute if the pipeline schedule isn't set, and making sure the pipelining_fn returns the correct model chunks.

cc @tushar00jain
1 parent f1c8c2c commit 3c84ce0
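To make the new format concrete, here is a minimal, self-contained sketch of how per-part FQN lists can be used to carve a model into pipeline chunks. It is not the `pipeline_module_split` utility referenced above; the helper name `split_model_by_fqns`, the toy model layout, and the approach of replacing out-of-part modules with `None` are illustrative assumptions.

```python
import copy

import torch.nn as nn


def split_model_by_fqns(
    model: nn.Module, module_fqns_per_part: list[list[str]]
) -> list[nn.Module]:
    """Return one copy of the model per part, keeping only the modules listed for that part."""
    parts = []
    for keep_fqns in module_fqns_per_part:
        part = copy.deepcopy(model)  # wasteful, but keeps the sketch simple
        keep = set(keep_fqns)
        dropped: set[str] = set()
        for fqn, _ in model.named_modules():
            if not fqn:
                continue  # skip the root module itself
            if any(fqn == d or fqn.startswith(d + ".") for d in dropped):
                continue  # an ancestor was already removed
            # Keep a module if it is listed, lives inside a listed module, or contains one.
            if any(fqn == k or fqn.startswith(k + ".") or k.startswith(fqn + ".") for k in keep):
                continue
            parent_fqn, _, attr = fqn.rpartition(".")
            parent = part.get_submodule(parent_fqn) if parent_fqn else part
            setattr(parent, attr, None)  # this module belongs to another part
            dropped.add(fqn)
        parts.append(part)
    return parts


# Toy model shaped like the FQNs used in this PR's examples.
toy = nn.ModuleDict(
    {
        "tok_embeddings": nn.Embedding(16, 8),
        "layers": nn.ModuleList(nn.Linear(8, 8) for _ in range(4)),
        "norm": nn.LayerNorm(8),
        "output": nn.Linear(8, 16),
    }
)
stage0, stage1 = split_model_by_fqns(
    toy,
    [["tok_embeddings", "layers.0", "layers.1"], ["layers.2", "layers.3", "norm", "output"]],
)
```

Because the split is driven purely by FQN lists, the same routine works for any model and can also produce chunks for non-pipelined uses such as the streaming DiLoCo case mentioned above.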

6 files changed: +400 -550 lines


tests/unit_tests/test_job_config.py

Lines changed: 39 additions & 34 deletions
```diff
@@ -52,73 +52,78 @@ def test_job_config_file_cmd_overrides(self):
         )
         assert config.job.dump_folder == "/tmp/test_tt/"
 
-    def test_parse_pp_split_points(self):
-        toml_splits = ["layers.2", "layers.4", "layers.6"]
-        cmdline_splits = ["layers.1", "layers.3", "layers.5"]
-        # no split points specified
-        config_manager = ConfigManager()
-        config = config_manager.parse_args(
-            [
-                "--job.config_file",
-                "./torchtitan/models/llama3/train_configs/debug_model.toml",
-            ]
-        )
-        assert config.parallelism.pipeline_parallel_split_points == []
+    def test_parse_module_fqns_per_model_part(self):
+        toml_chunks = [
+            ["tok_embeddings", "layers.0"],
+            ["layers.1", "layers.2"],
+            ["layers.3", "norm", "output"],
+        ]
+        cmdline_chunks = [
+            ["tok_embeddings", "layers.0", "layers.1"],
+            ["layers.2", "layers.3", "norm", "output"],
+        ]
 
-        # toml has no split points, but cmdline splits are specified
+        # no module names specified
         config_manager = ConfigManager()
         config = config_manager.parse_args(
             [
                 "--job.config_file",
                 "./torchtitan/models/llama3/train_configs/debug_model.toml",
-                "--parallelism.pipeline_parallel_split_points",
-                ",".join(cmdline_splits),
             ]
         )
-        assert (
-            config.parallelism.pipeline_parallel_split_points == cmdline_splits
-        ), config.parallelism.pipeline_parallel_split_points
+        assert config.parallelism.module_fqns_per_model_part is None
 
-        # toml has split points, cmdline does not
+        # toml has module names, cmdline does not
         with tempfile.NamedTemporaryFile() as fp:
             with open(fp.name, "wb") as f:
                 tomli_w.dump(
                     {
                         "parallelism": {
-                            "pipeline_parallel_split_points": toml_splits,
+                            "module_fqns_per_model_part": toml_chunks,
                         }
                     },
                     f,
                 )
             config_manager = ConfigManager()
             config = config_manager.parse_args(["--job.config_file", fp.name])
             assert (
-                config.parallelism.pipeline_parallel_split_points == toml_splits
-            ), config.parallelism.pipeline_parallel_split_points
+                config.parallelism.module_fqns_per_model_part == toml_chunks
+            ), config.parallelism.module_fqns_per_model_part
 
-        # toml has split points, cmdline overrides them
+        # test that the field accepts list of lists structure
         with tempfile.NamedTemporaryFile() as fp:
             with open(fp.name, "wb") as f:
                 tomli_w.dump(
                     {
                         "parallelism": {
-                            "pipeline_parallel_split_points": toml_splits,
+                            "module_fqns_per_model_part": cmdline_chunks,
                         }
                     },
                     f,
                 )
             config_manager = ConfigManager()
-            config = config_manager.parse_args(
-                [
-                    "--job.config_file",
-                    fp.name,
-                    "--parallelism.pipeline_parallel_split_points",
-                    ",".join(cmdline_splits),
-                ]
-            )
+            config = config_manager.parse_args(["--job.config_file", fp.name])
+            assert (
+                config.parallelism.module_fqns_per_model_part == cmdline_chunks
+            ), config.parallelism.module_fqns_per_model_part
+
+        # test empty chunks are handled correctly
+        empty_chunks = [[], ["tok_embeddings"], []]
+        with tempfile.NamedTemporaryFile() as fp:
+            with open(fp.name, "wb") as f:
+                tomli_w.dump(
+                    {
+                        "parallelism": {
+                            "module_fqns_per_model_part": empty_chunks,
+                        }
+                    },
+                    f,
+                )
+            config_manager = ConfigManager()
+            config = config_manager.parse_args(["--job.config_file", fp.name])
             assert (
-                config.parallelism.pipeline_parallel_split_points == cmdline_splits
-            ), config.parallelism.pipeline_parallel_split_points
+                config.parallelism.module_fqns_per_model_part == empty_chunks
+            ), config.parallelism.module_fqns_per_model_part
 
     def test_parse_exclude_from_loading(self):
         toml_splits = ["optimizer", "dataloader"]
```

torchtitan/config/job_config.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -290,6 +290,7 @@ class Parallelism:
 
     pipeline_parallel_split_points: list[str] = field(default_factory=list)
     """
+    DEPRECATED: Use module_fqns_per_model_part instead.
     Specify comma-separated names of modules to use as the beginning of a split point.
     e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
     the first containing all the layers up to layers.0,
@@ -299,9 +300,31 @@ class Parallelism:
     but currently the split points must be specified manually.
     """
 
+    module_fqns_per_model_part: list[list[str]] | None = None
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
+    Each inner list represents one model chunk and contains the module names that belong to that chunk.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    This provides more explicit control over which modules belong to each chunk compared to split points.
+    """
+
+    pipeline_parallel_first_stage_less_layers: int = 1
+    """
+    The number of layers to reduce in the first stage of pipeline parallelism. This is because
+    the first stage has the extra overhead of the embedding layer, which is not present in the other stages.
+    """
+
+    pipeline_parallel_last_stage_less_layers: int = 1
+    """
+    The number of layers to reduce in the last stage of pipeline parallelism. This is because
+    the last stage has the extra overhead of the output layer, which is not present in the other stages.
+    """
+
     pipeline_parallel_layers_per_stage: int | None = None
     """
-    The number of layers per (virtual) pipeline stage. If specified, the split points will be
+    The number of layers per (virtual) pipeline stage. If specified, the module_fqns_per_model_part will be
     calculated from the number of layers and pipeline_parallel_degree. If not specified, the
     layers per stage will be inferred from the model, schedule, and pipeline_parallel_degree.
     """
```
