Skip to content

Commit 2926160

Browse files
committed
model fragments for diloco
Summary: - add a configuration option that lets users specify how they want to partition the model - if this option is provided, the model's train spec needs to be a `FaultTolerantTrainSpec` that defines the fragmentation function used to split the model based on the configuration - determine the model fragments in the training script to pass to the ft manager Test Plan: Ran llama3 with 8B parameters using 2 fragments and a 1-step delay; each fragment gets synced every 20 steps <img width="944" height="545" alt="image" src="https://github.com/user-attachments/assets/6d16f486-7260-49d6-8ba3-3e98cd331e58" />
1 parent f3e2a75 commit 2926160

File tree

6 files changed

+277
-5
lines changed

6 files changed

+277
-5
lines changed

torchtitan/config/job_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,17 @@ class FaultTolerance:
661661
This is only used when "semi_sync_method" is set.
662662
"""
663663

664+
module_names_per_model_chunk: list[list[str]] = field(default_factory=list)
665+
"""
666+
Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
667+
Each inner list represents one model chunk and contains the module names that belong to that chunk.
668+
e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
669+
will create 3 chunks: the first containing tok_embeddings and layers.0,
670+
the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
671+
This provides more explicit control over which modules belong to each chunk compared to split points.
672+
"""
673+
674+
num_fragments: int = 1
664675

665676
@dataclass
666677
class Experimental:

torchtitan/distributed/pipeline.py

Lines changed: 192 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import os
7+
import torch.nn as nn
78
from typing import Callable
89

910
from torch.distributed.pipelining.schedules import (
@@ -19,7 +20,7 @@
1920
from torchtitan.tools.logging import logger
2021

2122

22-
__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"]
23+
__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank", "generate_module_names_per_stage", "module_split"]
2324

2425

2526
# TODO: It's unclear if this API is general enough to be used by other models.
@@ -209,3 +210,193 @@ def stage_ids_this_rank(
209210
zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
210211
)
211212
return stage_v_pairs[pp_rank]
213+
214+
215+
def generate_module_names_per_stage(
    num_stages: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[list[str]]:
    """
    Build the per-stage module name lists for a weighted pipeline split.

    The embedding and the norm/output head are charged ``input_weight`` and
    ``output_weight`` layer-equivalents respectively, so the first and last
    stages receive correspondingly fewer transformer layers.

    Args:
        num_stages: How many pipeline stages to produce.
        num_layers: Number of transformer layers (named ``layers.<i>``).
        input_weight: Layer-equivalents charged to ``tok_embeddings``.
        output_weight: Layer-equivalents charged to ``norm`` + ``output``.

    Returns:
        One list of module names per stage, in stage order.

    Raises:
        ValueError: If ``num_stages`` is below 1, exceeds the effective layer
            count, or the per-stage budget is smaller than either weight.

    Example:
        generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2)
        returns [["tok_embeddings", "layers.0", "layers.1"],
                 ["layers.2", "norm", "output"]]
    """
    if num_stages < 1:
        raise ValueError("Number of stages must be at least 1")

    if num_stages == 1:
        # A single stage owns the whole model.
        every_layer = (f"layers.{i}" for i in range(num_layers))
        return [["tok_embeddings", *every_layer, "norm", "output"]]

    # Transformer layers plus the layer-equivalents of the input/output modules.
    num_effective_layers = num_layers + input_weight + output_weight
    if num_stages > num_effective_layers:
        raise ValueError(
            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
        )

    # Even share per stage; the first `extra_layers` stages get one more.
    layers_per_stage, extra_layers = divmod(num_effective_layers, num_stages)
    if layers_per_stage < max(input_weight, output_weight):
        raise ValueError(
            f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})"
        )

    next_layer = 0

    def _take(count: int) -> list[str]:
        # Hand out up to `count` not-yet-assigned layer names, never past num_layers.
        nonlocal next_layer
        end = max(next_layer, min(next_layer + max(count, 0), num_layers))
        names = [f"layers.{i}" for i in range(next_layer, end)]
        next_layer = end
        return names

    last_stage = num_stages - 1
    stages: list[list[str]] = []
    for stage_idx in range(num_stages):
        budget = layers_per_stage + (1 if stage_idx < extra_layers else 0)
        if stage_idx == 0:
            # First stage: embeddings consume `input_weight` of the budget.
            stages.append(["tok_embeddings", *_take(budget - input_weight)])
        elif stage_idx == last_stage:
            # Last stage: norm + output consume `output_weight` of the budget.
            stages.append([*_take(budget - output_weight), "norm", "output"])
        else:
            # Middle stages hold transformer layers only.
            stages.append(_take(budget))

    return stages
310+
311+
312+
def module_split(
    model: nn.Module,
    module_names_per_stage: list[list[str]],
) -> list[nn.Module]:
    """
    Build one container module per stage from the named children of ``model``.

    Each returned chunk is a fresh ``nn.Module`` that holds references to the
    selected submodules of ``model``. The submodules are shared with ``model``
    (nothing is copied) and ``model`` itself is not modified.

    Args:
        model: The complete model to be split.
        module_names_per_stage: List of lists, where each inner list contains the
            module names that should be included in that stage. Names are either
            a direct child of ``model`` or ``<container>.<key>`` for an entry of
            a child ``nn.ModuleList`` / ``nn.ModuleDict``. Examples:
            - "tok_embeddings" for token embeddings
            - "layers.0", "layers.1" for specific transformer layers
            - "norm" for the final normalization layer
            - "output" for the output projection layer

    Returns:
        List of model chunks, one per entry of ``module_names_per_stage``.
        Entries taken from an ``nn.ModuleList`` are placed in a new
        ``nn.ModuleList`` and are therefore re-indexed from 0 inside the chunk.

    Example usage:
        module_names_per_stage = [
            ["tok_embeddings", "layers.0"],  # Stage 0: embeddings + first layer
            ["layers.1", "layers.2"],        # Stage 1: middle layers
            ["norm", "output"],              # Stage 2: final norm + output
        ]
    """

    def _build_stage_from_modules(
        stage_idx: int, module_names: list[str]
    ) -> nn.Module:
        stage_model = nn.Module()
        # Set for O(1) membership tests while walking the children.
        modules_to_keep = set(module_names)
        # Requested names that were actually found on `model`.
        matched: set[str] = set()

        for module_name, module_value in model.named_children():
            # Handle layer-like structures (e.g., "layers.0", "layers.1")
            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
                layers_to_keep = {
                    name.split(".", 1)[1]
                    for name in modules_to_keep
                    if name.startswith(f"{module_name}.")
                }
                if not layers_to_keep:
                    continue

                if isinstance(module_value, nn.ModuleDict):
                    for layer_name, layer in module_value.items():
                        if layer_name not in layers_to_keep:
                            continue
                        # NOTE(review): the attribute name contains a dot, so the
                        # layer is registered directly on the chunk under the
                        # literal key "<container>.<key>" rather than inside a
                        # nested ModuleDict - confirm this is intended.
                        setattr(stage_model, f"{module_name}.{layer_name}", layer)
                        matched.add(f"{module_name}.{layer_name}")
                else:
                    # Only purely numeric suffixes can address a ModuleList entry.
                    numeric = {idx for idx in layers_to_keep if idx.isdigit()}
                    indices_to_keep = {int(idx) for idx in numeric}
                    new_layers = nn.ModuleList(
                        [
                            layer
                            for i, layer in enumerate(module_value)
                            if i in indices_to_keep
                        ]
                    )
                    setattr(stage_model, module_name, new_layers)
                    matched.update(
                        f"{module_name}.{idx}"
                        for idx in numeric
                        if int(idx) < len(module_value)
                    )
                continue

            # Handle simple module attributes (e.g., "linear", "norm")
            if module_name in modules_to_keep:
                setattr(stage_model, module_name, module_value)
                matched.add(module_name)

        # Surface typos / unsupported names instead of dropping them silently.
        unmatched = modules_to_keep - matched
        if unmatched:
            logger.warning(
                f"stage_idx {stage_idx}: no matching module found for "
                f"{sorted(unmatched)}; these names are ignored"
            )

        return stage_model

    models = []
    for stage_idx, module_names in enumerate(module_names_per_stage):
        model_chunk = _build_stage_from_modules(stage_idx, module_names)
        logger.info(
            f"building stage_idx {stage_idx} "
            f"with modules {module_names}"
        )
        models.append(model_chunk)

    return models

torchtitan/models/llama3/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
from torchtitan.components.tokenizer import build_hf_tokenizer
1111
from torchtitan.components.validate import build_validator
1212
from torchtitan.datasets.hf_datasets import build_hf_dataloader
13-
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
13+
from torchtitan.protocols.train_spec import FaultTolerantTrainSpec, register_train_spec
1414

1515
from .infra.parallelize import parallelize_llama
1616
from .infra.pipeline import pipeline_llama
17+
from .infra.fault_tolerance import fragment_llama
1718
from .model.args import TransformerModelArgs
1819
from .model.model import Transformer
1920
from .model.state_dict_adapter import Llama3StateDictAdapter
@@ -71,12 +72,13 @@
7172

7273

7374
register_train_spec(
74-
TrainSpec(
75+
FaultTolerantTrainSpec(
7576
name="llama3",
7677
model_cls=Transformer,
7778
model_args=llama3_configs,
7879
parallelize_fn=parallelize_llama,
7980
pipelining_fn=pipeline_llama,
81+
fragment_fn=fragment_llama,
8082
build_optimizers_fn=build_optimizers,
8183
build_lr_schedulers_fn=build_lr_schedulers,
8284
build_dataloader_fn=build_hf_dataloader,
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# This file is used to setup the model for fault tolerance
8+
9+
import torch.nn as nn
10+
11+
12+
from torchtitan.config_manager import JobConfig
13+
from torchtitan.distributed.pipeline import (
14+
generate_module_names_per_stage,
15+
module_split,
16+
)
17+
from torchtitan.tools.logging import logger
18+
19+
from ..model.args import TransformerModelArgs
20+
21+
def fragment_llama(
    model: nn.Module,
    job_config: JobConfig,
    model_config: TransformerModelArgs,
) -> list[nn.Module]:
    """
    Split a Llama model into fragments for fault-tolerant training.

    If ``fault_tolerance.module_names_per_model_chunk`` is non-empty it is used
    as the fragment layout as-is. Otherwise the layout is generated from
    ``fault_tolerance.num_fragments`` and ``model_config.n_layers``, with the
    embedding and the norm/output head each weighted as one layer.

    Args:
        model: The model to fragment.
        job_config: Job configuration; only ``fault_tolerance`` is read.
        model_config: Model arguments; only ``n_layers`` is read.

    Returns:
        The list of model fragments. When no explicit layout is configured and
        ``num_fragments`` is 1, this is ``[model]`` itself.

    Raises:
        ValueError: If ``fault_tolerance.num_fragments`` is not positive.
    """
    ft = job_config.fault_tolerance

    # Raise instead of assert so the check still runs under `python -O`.
    if ft.num_fragments <= 0:
        raise ValueError(
            f"fault_tolerance.num_fragments must be positive, got {ft.num_fragments}"
        )

    module_names_per_stage = ft.module_names_per_model_chunk

    if not module_names_per_stage:
        if ft.num_fragments == 1:
            # Nothing to split: the whole model is the single fragment.
            return [model]

        module_names_per_stage = generate_module_names_per_stage(
            ft.num_fragments,
            model_config.n_layers,
            1,  # input_weight: weight for tok_embeddings
            1,  # output_weight: weight for norm + output layers
        )

    model_fragments = module_split(model, module_names_per_stage)
    logger.info(f"Created {len(model_fragments)} model fragments")

    return model_fragments

torchtitan/protocols/train_spec.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ class TrainSpec:
5656
state_dict_adapter: type[StateDictAdapter] | None = None
5757

5858

59+
FragmentFunction: TypeAlias = Callable[
60+
..., list[nn.Module]
61+
]
62+
63+
64+
@dataclass
class FaultTolerantTrainSpec(TrainSpec):
    """A ``TrainSpec`` that can additionally carry a model fragmentation function."""

    # Optional callable that returns the list of ``nn.Module`` fragments;
    # defaults to ``None`` when no fragmentation function is supplied.
    fragment_fn: FragmentFunction | None = None
67+
68+
5969
_train_specs = {}
6070

6171

torchtitan/train.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import os
99
import time
1010
from datetime import timedelta
11-
from typing import Any, Generator, Iterable, Optional
11+
from typing import Any, Generator, Iterable, Optional, cast
1212

1313
import torch
1414
from torch.distributed.elastic.multiprocessing.errors import record
@@ -43,6 +43,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
4343
tokenizer: train_spec_module.BaseTokenizer | None
4444
dataloader: train_spec_module.BaseDataLoader
4545
model_parts: list[torch.nn.Module]
46+
ft_model_parts: list[torch.nn.Module]
4647
loss_fn: train_spec_module.LossFunction
4748
optimizers: train_spec_module.OptimizersContainer
4849
lr_schedulers: train_spec_module.LRSchedulersContainer
@@ -261,6 +262,16 @@ def __init__(self, job_config: JobConfig):
261262

262263
self.model_parts = [model]
263264

265+
ft = job_config.fault_tolerance

# Always define `ft_model_parts`: `train()` reads it unconditionally, so leaving
# it unset when fault tolerance is disabled would raise AttributeError there.
self.ft_model_parts = [model]

if ft.enable:
    # Not every registered spec is a FaultTolerantTrainSpec, so look the hook up
    # defensively instead of assuming the attribute exists.
    fragment_fn = getattr(self.train_spec, "fragment_fn", None)
    if fragment_fn:
        self.ft_model_parts = fragment_fn(model, job_config, model_args)
273+
274+
264275
self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)
265276

266277
# initialize device memory monitor and get peak flops for MFU calculation
@@ -524,7 +535,7 @@ def train(self):
524535
maybe_semi_sync_training(
525536
job_config.fault_tolerance,
526537
ft_manager=self.ft_manager,
527-
model_parts=self.model_parts,
538+
model_parts=self.ft_model_parts,
528539
optimizer=self.optimizers,
529540
),
530541
):

0 commit comments

Comments
 (0)