33#
44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
6- import copy
76import math
87
98import torch
109import torch .nn as nn
11- from torch .distributed .device_mesh import DeviceMesh
12- from torch .distributed .pipelining import PipelineStage
1310from torch .distributed .pipelining .schedules import (
1411 _PipelineSchedule ,
1512 get_schedule_class ,
1613 PipelineScheduleSingle ,
17- ScheduleDualPipeV ,
18- ScheduleZBVZeroBubble ,
1914)
2015
2116from torchtitan .components .loss import LossFunction
2217from torchtitan .experiments .transformers_backend .job_config import JobConfig
2318from torchtitan .distributed import ParallelDims
24- from torchtitan .distributed .pipeline_parallel import build_pipeline_schedule
19+ from torchtitan .distributed .pipeline_parallel import (
20+ build_pipeline_schedule ,
21+ generate_llm_fqn_per_model_part ,
22+ pipeline_module_split ,
23+ )
2524from torchtitan .protocols .train_spec import BaseModelArgs , ParallelizeFunction
2625from torchtitan .tools .logging import logger
2726
28- # NOTE(3outeille): the only modifications comes from replacing None to nn.Identity and adding rotary_emb per model_part
29-
30-
31- def generate_llm_fqn_per_model_part (
32- num_stages : int ,
33- num_layers : int ,
34- input_weight : int = 1 ,
35- output_weight : int = 1 ,
36- ) -> list [list [str ]]:
37- """
38- Programmatically generates module names model part, focused on LLMs models.
39- Args:
40- num_stages: Number of pipeline stages
41- num_layers: Total number of transformer layers in the model
42- input_weight: Weight for input modules (embed_tokens) in layer calculation
43- output_weight: Weight for output modules (norm + output) in layer calculation
44- Returns:
45- List of lists containing module names for each model part
46- Example:
47- generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2)
48- treats embeddings as 2 layers and norm+output as 2 layers for distribution
49- """
50- if num_stages < 1 :
51- raise ValueError ("Number of stages must be at least 1" )
52-
53- if num_stages == 1 :
54- # Single stage gets everything
55- layer_names = [f"layers.{ i } " for i in range (num_layers )]
56- return [["tok_embeddings" ] + layer_names + ["norm" , "output" , "rotary_emb" ]]
57-
58- # Calculate effective layers including weights
59- num_effective_layers = num_layers + input_weight + output_weight
60-
61- if num_stages > num_effective_layers :
62- raise ValueError (
63- f"Number of stages ({ num_stages } ) cannot be greater than effective layers ({ num_effective_layers } )"
64- )
65-
66- # Calculate layers per stage (distribute evenly)
67- layers_per_stage = num_effective_layers // num_stages
68- extra_layers = num_effective_layers % num_stages
69-
70- # Feasibility check: Ensure at least 1 layer in each PP stage
71- if layers_per_stage == 0 :
72- raise ValueError (
73- f"Configuration would result in empty stages. "
74- f"With { num_stages } stages and { num_effective_layers } effective layers "
75- f"(num_layers={ num_layers } + input_weight={ input_weight } + output_weight={ output_weight } ), "
76- f"each stage would get { layers_per_stage } layers on average. "
77- f"Reduce num_stages or increase num_layers/weights."
78- )
79-
80- # Balance check: Ensure weights don't exceed minimum layers per stage
81- if input_weight > layers_per_stage :
82- raise ValueError (
83- f"input_weight ({ input_weight } ) exceeds minimum layers per stage ({ layers_per_stage } )."
84- )
85- if output_weight > layers_per_stage :
86- raise ValueError (
87- f"output_weight ({ output_weight } ) exceeds minimum layers per stage ({ layers_per_stage } )."
88- )
89-
90- module_names_per_stage = []
91- current_layer = 0
92-
93- for stage_idx in range (num_stages ):
94- stage_modules = []
95-
96- # Calculate effective layers for this stage
97- effective_layers_for_stage = layers_per_stage
98- if stage_idx < extra_layers :
99- effective_layers_for_stage += 1
100-
101- # First stage: handle input modules with weighting
102- if stage_idx == 0 :
103- stage_modules .append ("tok_embeddings" )
104- # Account for input weight in layer distribution
105- remaining_layers_for_stage = effective_layers_for_stage - input_weight
106-
107- # Add transformer layers
108- for _ in range (remaining_layers_for_stage ):
109- if current_layer < num_layers :
110- stage_modules .append (f"layers.{ current_layer } " )
111- current_layer += 1
112-
113- # Last stage: handle output modules with weighting
114- elif stage_idx == num_stages - 1 :
115- # Account for output weight in layer distribution
116- remaining_layers_for_stage = effective_layers_for_stage - output_weight
117-
118- # Add transformer layers
119- for _ in range (remaining_layers_for_stage ):
120- if current_layer < num_layers :
121- stage_modules .append (f"layers.{ current_layer } " )
122- current_layer += 1
123-
124- # Add output modules
125- stage_modules .extend (["norm" , "output" ])
126-
127- # Middle stages: only transformer layers
128- else :
129- for _ in range (effective_layers_for_stage ):
130- if current_layer < num_layers :
131- stage_modules .append (f"layers.{ current_layer } " )
132- current_layer += 1
133-
134- stage_modules .append ("rotary_emb" )
135- module_names_per_stage .append (stage_modules )
136-
137- return module_names_per_stage
138-
139-
140- def pipeline_module_split (
141- whole_model : nn .Module ,
142- pp_mesh : DeviceMesh ,
143- pp_schedule : str ,
144- device : torch .device ,
145- module_names_per_stage : list [list [str ]],
146- ) -> tuple [list [PipelineStage ], list [nn .Module ]]:
147- """
148- This API creates pipeline stages based on specified module names for each stage.
149-
150- Some model restrictions include:
151- - forward() method should tolerate deleted layers
152- - weight initialization methods should tolerate deleted layers
153- - Does not support nested moduledict and modulelist structures
154-
155- Args:
156- whole_model: The complete model to be split
157- pp_mesh: Pipeline parallel device mesh
158- pp_schedule: Name of pipeline parallelism schedule
159- device: Device
160- module_names_per_stage: List of lists, where each inner list contains the module names
161- that should be included in that stage. Module names should be
162- dot-separated paths. Examples:
163- - "tok_embeddings" for token embeddings
164- - "layers.0", "layers.1" for specific transformer layers
165- - "norm" for the final normalization layer
166- - "output" for the output projection layer
167-
168- Returns:
169- Tuple of (stages, models) where stages are PipelineStage objects and models are the
170- corresponding model chunks
171-
172- Example usage:
173- module_names_per_stage = [
174- ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer
175- ["layers.1", "layers.2"], # Stage 1: middle layers
176- ["norm", "output"] # Stage 2: final norm + output
177- ]
178- """
179- pp_rank = pp_mesh .get_local_rank ()
180- pp_degree = pp_mesh .size ()
181-
182- def _build_stage_from_modules (
183- stage_idx : int , module_names : list [str ], num_stages : int
184- ) -> tuple [PipelineStage , nn .Module ]:
185- model = copy .deepcopy (whole_model )
186-
187- # Create a set of modules to keep for faster lookup
188- modules_to_keep = set (module_names )
189- for module_name , module_value in model .named_children ():
190- # Handle layer-like structures (e.g., "layers.0", "layers.1")
191- if isinstance (module_value , (nn .ModuleDict , nn .ModuleList )):
192- layers_to_keep = {
193- name .split ("." , 1 )[1 ]
194- for name in modules_to_keep
195- if name .startswith (f"{ module_name } ." )
196- }
197- if layers_to_keep :
198- # Keep only specified layers
199- if isinstance (module_value , nn .ModuleDict ):
200- for layer_name in list (module_value .keys ()):
201- if layer_name not in layers_to_keep :
202- del module_value [layer_name ]
203- elif isinstance (module_value , nn .ModuleList ):
204- indices_to_keep = {
205- int (idx ) for idx in layers_to_keep if idx .isdigit ()
206- }
207- new_layers = nn .ModuleList (
208- [
209- layer
210- for i , layer in enumerate (module_value )
211- if i in indices_to_keep
212- ]
213- )
214- setattr (model , module_name , new_layers )
215- else :
216- # No layers from this structure needed, set to empty structure
217- if isinstance (module_value , nn .ModuleDict ):
218- setattr (model , module_name , nn .ModuleDict ())
219- elif isinstance (module_value , nn .ModuleList ):
220- setattr (model , module_name , nn .ModuleList ())
221- # Handle simple module attributes (e.g., "linear", "norm")
222- elif module_name not in modules_to_keep :
223- # Replace with Identity
224- setattr (model , module_name , nn .Identity ())
225-
226- stage = PipelineStage (
227- model ,
228- stage_idx ,
229- num_stages ,
230- device ,
231- group = pp_mesh .get_group ("pp" ),
232- )
233- return stage , model
234-
235- num_stages = len (module_names_per_stage )
236- stages = []
237- models = []
238-
239- schedule_class = get_schedule_class (pp_schedule )
240- style = (
241- "v" if schedule_class in (ScheduleZBVZeroBubble , ScheduleDualPipeV ) else "loop"
242- )
243-
244- def _get_stage_indices () -> tuple [int ]:
245- """
246- Compute the stage ids for the stages that will run on this pp rank
247- for either a looped or V style schedule
248- """
249- assert (
250- num_stages % pp_degree == 0
251- ), f"num_stages { num_stages } must be evenly divisible by pp_degree { pp_degree } "
252- stages_per_rank = num_stages // pp_degree
253- if style == "loop" :
254- return tuple (pp_rank + s * pp_degree for s in range (stages_per_rank ))
255- elif style == "v" :
256- assert (
257- stages_per_rank == 2
258- ), f"v schedules assume 2 stages per rank, got { stages_per_rank } "
259- stage_v_pairs = list (
260- zip (range (pp_degree ), range (num_stages - 1 , pp_degree - 1 , - 1 ))
261- )
262- return stage_v_pairs [pp_rank ]
263-
264- for stage_idx in _get_stage_indices ():
265- module_names = module_names_per_stage [stage_idx ]
266- stage , model_chunk = _build_stage_from_modules (
267- stage_idx ,
268- module_names ,
269- num_stages ,
270- )
271- logger .info (
272- f"PP rank { pp_rank } is building stage_idx { stage_idx } "
273- f"with modules { module_names } "
274- )
275- stages .append (stage )
276- models .append (model_chunk )
277-
278- return stages , models
279-
28027
28128def pipeline_hf_transformers (
28229 model : nn .Module ,
@@ -355,7 +102,11 @@ def pipeline_hf_transformers(
355102 module_names_per_stage = job_config .parallelism .module_fqns_per_model_part
356103 if module_names_per_stage is None :
357104 module_names_per_stage = generate_llm_fqn_per_model_part (
358- num_virtual_stages , num_layers , input_weight , output_weight
105+ num_virtual_stages ,
106+ num_layers ,
107+ input_weight ,
108+ output_weight ,
109+ include_rotary_emb = True ,
359110 )
360111 for i , stage_ms in enumerate (module_names_per_stage ):
361112 logger .debug (f"Stage { i } : { stage_ms } " )
@@ -366,6 +117,7 @@ def pipeline_hf_transformers(
366117 job_config .parallelism .pipeline_parallel_schedule ,
367118 device ,
368119 module_names_per_stage ,
120+ use_identity_for_missing_modules = True ,
369121 )
370122
371123 # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
0 commit comments