Skip to content

Commit 09f0c94

Browse files
committed
reuse pipeline from torchtitan
1 parent 0d80f62 commit 09f0c94

File tree

2 files changed

+25
-262
lines changed

2 files changed

+25
-262
lines changed

torchtitan/distributed/pipeline_parallel.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def generate_llm_fqn_per_model_part(
228228
num_layers: int,
229229
input_weight: int = 1,
230230
output_weight: int = 1,
231+
include_rotary_emb: bool = False,
231232
) -> list[list[str]]:
232233
"""
233234
Programmatically generates module names per model part, focused on LLM models.
@@ -237,6 +238,7 @@ def generate_llm_fqn_per_model_part(
237238
num_layers: Total number of transformer layers in the model
238239
input_weight: Weight for input modules (tok_embeddings) in layer calculation
239240
output_weight: Weight for output modules (norm + output) in layer calculation
241+
include_rotary_emb: Whether to include rotary_emb in each model part
240242
241243
Returns:
242244
List of lists containing module names for each model part
@@ -251,7 +253,10 @@ def generate_llm_fqn_per_model_part(
251253
if num_stages == 1:
252254
# Single stage gets everything
253255
layer_names = [f"layers.{i}" for i in range(num_layers)]
254-
return [["tok_embeddings"] + layer_names + ["norm", "output"]]
256+
result = [["tok_embeddings"] + layer_names + ["norm", "output"]]
257+
if include_rotary_emb:
258+
result[0].append("rotary_emb")
259+
return result
255260

256261
# Calculate effective layers including weights
257262
num_effective_layers = num_layers + input_weight + output_weight
@@ -329,6 +334,8 @@ def generate_llm_fqn_per_model_part(
329334
stage_modules.append(f"layers.{current_layer}")
330335
current_layer += 1
331336

337+
if include_rotary_emb:
338+
stage_modules.append("rotary_emb")
332339
module_names_per_stage.append(stage_modules)
333340

334341
return module_names_per_stage
@@ -340,6 +347,7 @@ def pipeline_module_split(
340347
pp_schedule: str,
341348
device: torch.device,
342349
module_names_per_stage: list[list[str]],
350+
use_identity_for_missing_modules: bool = False,
343351
) -> tuple[list[PipelineStage], list[nn.Module]]:
344352
"""
345353
This API creates pipeline stages based on specified module names for each stage.
@@ -361,6 +369,8 @@ def pipeline_module_split(
361369
- "layers.0", "layers.1" for specific transformer layers
362370
- "norm" for the final normalization layer
363371
- "output" for the output projection layer
372+
use_identity_for_missing_modules: If True, replace missing modules with nn.Identity(),
373+
otherwise replace with None
364374
365375
Returns:
366376
Tuple of (stages, models) where stages are PipelineStage objects and models are the
@@ -417,8 +427,9 @@ def _build_stage_from_modules(
417427
setattr(model, module_name, nn.ModuleList())
418428
# Handle simple module attributes (e.g., "linear", "norm")
419429
elif module_name not in modules_to_keep:
420-
# Replace with None
421-
setattr(model, module_name, None)
430+
# Replace with Identity or None based on configuration
431+
replacement = nn.Identity() if use_identity_for_missing_modules else None
432+
setattr(model, module_name, replacement)
422433

423434
stage = PipelineStage(
424435
model,

torchtitan/experiments/transformers_backend/infra/pipeline.py

Lines changed: 11 additions & 259 deletions
Original file line numberDiff line numberDiff line change
@@ -3,280 +3,27 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6-
import copy
76
import math
87

98
import torch
109
import torch.nn as nn
11-
from torch.distributed.device_mesh import DeviceMesh
12-
from torch.distributed.pipelining import PipelineStage
1310
from torch.distributed.pipelining.schedules import (
1411
_PipelineSchedule,
1512
get_schedule_class,
1613
PipelineScheduleSingle,
17-
ScheduleDualPipeV,
18-
ScheduleZBVZeroBubble,
1914
)
2015

2116
from torchtitan.components.loss import LossFunction
2217
from torchtitan.experiments.transformers_backend.job_config import JobConfig
2318
from torchtitan.distributed import ParallelDims
24-
from torchtitan.distributed.pipeline_parallel import build_pipeline_schedule
19+
from torchtitan.distributed.pipeline_parallel import (
20+
build_pipeline_schedule,
21+
generate_llm_fqn_per_model_part,
22+
pipeline_module_split,
23+
)
2524
from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
2625
from torchtitan.tools.logging import logger
2726

28-
# NOTE(3outeille): the only modifications come from replacing None with nn.Identity and adding rotary_emb to each model_part
29-
30-
31-
def generate_llm_fqn_per_model_part(
32-
num_stages: int,
33-
num_layers: int,
34-
input_weight: int = 1,
35-
output_weight: int = 1,
36-
) -> list[list[str]]:
37-
"""
38-
Programmatically generates module names per model part, focused on LLM models.
39-
Args:
40-
num_stages: Number of pipeline stages
41-
num_layers: Total number of transformer layers in the model
42-
input_weight: Weight for input modules (tok_embeddings) in layer calculation
43-
output_weight: Weight for output modules (norm + output) in layer calculation
44-
Returns:
45-
List of lists containing module names for each model part
46-
Example:
47-
generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2)
48-
treats embeddings as 2 layers and norm+output as 2 layers for distribution
49-
"""
50-
if num_stages < 1:
51-
raise ValueError("Number of stages must be at least 1")
52-
53-
if num_stages == 1:
54-
# Single stage gets everything
55-
layer_names = [f"layers.{i}" for i in range(num_layers)]
56-
return [["tok_embeddings"] + layer_names + ["norm", "output", "rotary_emb"]]
57-
58-
# Calculate effective layers including weights
59-
num_effective_layers = num_layers + input_weight + output_weight
60-
61-
if num_stages > num_effective_layers:
62-
raise ValueError(
63-
f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
64-
)
65-
66-
# Calculate layers per stage (distribute evenly)
67-
layers_per_stage = num_effective_layers // num_stages
68-
extra_layers = num_effective_layers % num_stages
69-
70-
# Feasibility check: Ensure at least 1 layer in each PP stage
71-
if layers_per_stage == 0:
72-
raise ValueError(
73-
f"Configuration would result in empty stages. "
74-
f"With {num_stages} stages and {num_effective_layers} effective layers "
75-
f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), "
76-
f"each stage would get {layers_per_stage} layers on average. "
77-
f"Reduce num_stages or increase num_layers/weights."
78-
)
79-
80-
# Balance check: Ensure weights don't exceed minimum layers per stage
81-
if input_weight > layers_per_stage:
82-
raise ValueError(
83-
f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})."
84-
)
85-
if output_weight > layers_per_stage:
86-
raise ValueError(
87-
f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})."
88-
)
89-
90-
module_names_per_stage = []
91-
current_layer = 0
92-
93-
for stage_idx in range(num_stages):
94-
stage_modules = []
95-
96-
# Calculate effective layers for this stage
97-
effective_layers_for_stage = layers_per_stage
98-
if stage_idx < extra_layers:
99-
effective_layers_for_stage += 1
100-
101-
# First stage: handle input modules with weighting
102-
if stage_idx == 0:
103-
stage_modules.append("tok_embeddings")
104-
# Account for input weight in layer distribution
105-
remaining_layers_for_stage = effective_layers_for_stage - input_weight
106-
107-
# Add transformer layers
108-
for _ in range(remaining_layers_for_stage):
109-
if current_layer < num_layers:
110-
stage_modules.append(f"layers.{current_layer}")
111-
current_layer += 1
112-
113-
# Last stage: handle output modules with weighting
114-
elif stage_idx == num_stages - 1:
115-
# Account for output weight in layer distribution
116-
remaining_layers_for_stage = effective_layers_for_stage - output_weight
117-
118-
# Add transformer layers
119-
for _ in range(remaining_layers_for_stage):
120-
if current_layer < num_layers:
121-
stage_modules.append(f"layers.{current_layer}")
122-
current_layer += 1
123-
124-
# Add output modules
125-
stage_modules.extend(["norm", "output"])
126-
127-
# Middle stages: only transformer layers
128-
else:
129-
for _ in range(effective_layers_for_stage):
130-
if current_layer < num_layers:
131-
stage_modules.append(f"layers.{current_layer}")
132-
current_layer += 1
133-
134-
stage_modules.append("rotary_emb")
135-
module_names_per_stage.append(stage_modules)
136-
137-
return module_names_per_stage
138-
139-
140-
def pipeline_module_split(
141-
whole_model: nn.Module,
142-
pp_mesh: DeviceMesh,
143-
pp_schedule: str,
144-
device: torch.device,
145-
module_names_per_stage: list[list[str]],
146-
) -> tuple[list[PipelineStage], list[nn.Module]]:
147-
"""
148-
This API creates pipeline stages based on specified module names for each stage.
149-
150-
Some model restrictions include:
151-
- forward() method should tolerate deleted layers
152-
- weight initialization methods should tolerate deleted layers
153-
- Does not support nested moduledict and modulelist structures
154-
155-
Args:
156-
whole_model: The complete model to be split
157-
pp_mesh: Pipeline parallel device mesh
158-
pp_schedule: Name of pipeline parallelism schedule
159-
device: Device
160-
module_names_per_stage: List of lists, where each inner list contains the module names
161-
that should be included in that stage. Module names should be
162-
dot-separated paths. Examples:
163-
- "tok_embeddings" for token embeddings
164-
- "layers.0", "layers.1" for specific transformer layers
165-
- "norm" for the final normalization layer
166-
- "output" for the output projection layer
167-
168-
Returns:
169-
Tuple of (stages, models) where stages are PipelineStage objects and models are the
170-
corresponding model chunks
171-
172-
Example usage:
173-
module_names_per_stage = [
174-
["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer
175-
["layers.1", "layers.2"], # Stage 1: middle layers
176-
["norm", "output"] # Stage 2: final norm + output
177-
]
178-
"""
179-
pp_rank = pp_mesh.get_local_rank()
180-
pp_degree = pp_mesh.size()
181-
182-
def _build_stage_from_modules(
183-
stage_idx: int, module_names: list[str], num_stages: int
184-
) -> tuple[PipelineStage, nn.Module]:
185-
model = copy.deepcopy(whole_model)
186-
187-
# Create a set of modules to keep for faster lookup
188-
modules_to_keep = set(module_names)
189-
for module_name, module_value in model.named_children():
190-
# Handle layer-like structures (e.g., "layers.0", "layers.1")
191-
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
192-
layers_to_keep = {
193-
name.split(".", 1)[1]
194-
for name in modules_to_keep
195-
if name.startswith(f"{module_name}.")
196-
}
197-
if layers_to_keep:
198-
# Keep only specified layers
199-
if isinstance(module_value, nn.ModuleDict):
200-
for layer_name in list(module_value.keys()):
201-
if layer_name not in layers_to_keep:
202-
del module_value[layer_name]
203-
elif isinstance(module_value, nn.ModuleList):
204-
indices_to_keep = {
205-
int(idx) for idx in layers_to_keep if idx.isdigit()
206-
}
207-
new_layers = nn.ModuleList(
208-
[
209-
layer
210-
for i, layer in enumerate(module_value)
211-
if i in indices_to_keep
212-
]
213-
)
214-
setattr(model, module_name, new_layers)
215-
else:
216-
# No layers from this structure needed, set to empty structure
217-
if isinstance(module_value, nn.ModuleDict):
218-
setattr(model, module_name, nn.ModuleDict())
219-
elif isinstance(module_value, nn.ModuleList):
220-
setattr(model, module_name, nn.ModuleList())
221-
# Handle simple module attributes (e.g., "linear", "norm")
222-
elif module_name not in modules_to_keep:
223-
# Replace with Identity
224-
setattr(model, module_name, nn.Identity())
225-
226-
stage = PipelineStage(
227-
model,
228-
stage_idx,
229-
num_stages,
230-
device,
231-
group=pp_mesh.get_group("pp"),
232-
)
233-
return stage, model
234-
235-
num_stages = len(module_names_per_stage)
236-
stages = []
237-
models = []
238-
239-
schedule_class = get_schedule_class(pp_schedule)
240-
style = (
241-
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
242-
)
243-
244-
def _get_stage_indices() -> tuple[int]:
245-
"""
246-
Compute the stage ids for the stages that will run on this pp rank
247-
for either a looped or V style schedule
248-
"""
249-
assert (
250-
num_stages % pp_degree == 0
251-
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
252-
stages_per_rank = num_stages // pp_degree
253-
if style == "loop":
254-
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
255-
elif style == "v":
256-
assert (
257-
stages_per_rank == 2
258-
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
259-
stage_v_pairs = list(
260-
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
261-
)
262-
return stage_v_pairs[pp_rank]
263-
264-
for stage_idx in _get_stage_indices():
265-
module_names = module_names_per_stage[stage_idx]
266-
stage, model_chunk = _build_stage_from_modules(
267-
stage_idx,
268-
module_names,
269-
num_stages,
270-
)
271-
logger.info(
272-
f"PP rank {pp_rank} is building stage_idx {stage_idx} "
273-
f"with modules {module_names}"
274-
)
275-
stages.append(stage)
276-
models.append(model_chunk)
277-
278-
return stages, models
279-
28027

28128
def pipeline_hf_transformers(
28229
model: nn.Module,
@@ -355,7 +102,11 @@ def pipeline_hf_transformers(
355102
module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
356103
if module_names_per_stage is None:
357104
module_names_per_stage = generate_llm_fqn_per_model_part(
358-
num_virtual_stages, num_layers, input_weight, output_weight
105+
num_virtual_stages,
106+
num_layers,
107+
input_weight,
108+
output_weight,
109+
include_rotary_emb=True,
359110
)
360111
for i, stage_ms in enumerate(module_names_per_stage):
361112
logger.debug(f"Stage {i}: {stage_ms}")
@@ -366,6 +117,7 @@ def pipeline_hf_transformers(
366117
job_config.parallelism.pipeline_parallel_schedule,
367118
device,
368119
module_names_per_stage,
120+
use_identity_for_missing_modules=True,
369121
)
370122

371123
# For PP with looped schedules, each item in model_parts is one stage-model-chunk.

0 commit comments

Comments
 (0)