@@ -228,7 +228,6 @@ def generate_llm_fqn_per_model_part(
     num_layers: int,
     input_weight: int = 1,
     output_weight: int = 1,
-    include_rotary_emb: bool = False,
 ) -> list[list[str]]:
     """
     Programmatically generates module names per model part, focused on LLM models.
@@ -238,7 +237,6 @@ def generate_llm_fqn_per_model_part(
         num_layers: Total number of transformer layers in the model
         input_weight: Weight for input modules (tok_embeddings) in layer calculation
         output_weight: Weight for output modules (norm + output) in layer calculation
-        include_rotary_emb: Whether to include rotary_emb in each model part
 
     Returns:
         List of lists containing module names for each model part
@@ -253,10 +251,7 @@ def generate_llm_fqn_per_model_part(
     if num_stages == 1:
         # Single stage gets everything
         layer_names = [f"layers.{i}" for i in range(num_layers)]
-        result = [["tok_embeddings"] + layer_names + ["norm", "output"]]
-        if include_rotary_emb:
-            result[0].append("rotary_emb")
-        return result
+        return [["tok_embeddings"] + layer_names + ["norm", "output"]]
 
     # Calculate effective layers including weights
     num_effective_layers = num_layers + input_weight + output_weight
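
A quick sketch, with hypothetical numbers, of what the weighting above works out to:

```python
# Hypothetical numbers, purely to make the weighting concrete.
num_layers, input_weight, output_weight, num_stages = 12, 1, 1, 4

num_effective_layers = num_layers + input_weight + output_weight  # 12 + 1 + 1 = 14
per_stage = num_effective_layers / num_stages                     # 3.5 effective layers per stage

# tok_embeddings contributes input_weight to the first stage's share and
# norm + output contribute output_weight to the last stage's share, so the
# distribution logic (not shown in this hunk) can give those stages fewer
# transformer layers than the middle stages.
```
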
@@ -334,8 +329,6 @@ def generate_llm_fqn_per_model_part(
             stage_modules.append(f"layers.{current_layer}")
             current_layer += 1
 
-        if include_rotary_emb:
-            stage_modules.append("rotary_emb")
         module_names_per_stage.append(stage_modules)
 
     return module_names_per_stage
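
A minimal usage sketch of the simplified function. The single-stage result follows directly from the early return above; the exact multi-stage layer split depends on the distribution logic elided from this diff, and the `num_stages` parameter name is assumed from its use in the function body:

```python
# Single stage: one part owns everything (matches the early return above).
single = generate_llm_fqn_per_model_part(num_stages=1, num_layers=2)
assert single == [["tok_embeddings", "layers.0", "layers.1", "norm", "output"]]

# Multiple stages: tok_embeddings should land in the first part and norm/output
# in the last; per-stage layer counts come from the weighting shown above.
for part in generate_llm_fqn_per_model_part(num_stages=2, num_layers=4):
    print(part)
```
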
@@ -347,7 +340,6 @@ def pipeline_module_split(
     pp_schedule: str,
     device: torch.device,
     module_names_per_stage: list[list[str]],
-    use_identity_for_missing_modules: bool = False,
 ) -> tuple[list[PipelineStage], list[nn.Module]]:
     """
     This API creates pipeline stages based on specified module names for each stage.
@@ -369,8 +361,6 @@ def pipeline_module_split(
             - "layers.0", "layers.1" for specific transformer layers
             - "norm" for the final normalization layer
             - "output" for the output projection layer
-        use_identity_for_missing_modules: If True, replace missing modules with nn.Identity(),
-            otherwise replace with None
 
     Returns:
         Tuple of (stages, models) where stages are PipelineStage objects and models are the
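
The naming convention in the docstring is easy to exercise by hand; a small sketch of a manually written `module_names_per_stage`, plus a sanity check that no module is assigned to more than one stage:

```python
# A hand-written 2-stage split of a 4-layer model, following the documented names.
module_names_per_stage = [
    ["tok_embeddings", "layers.0", "layers.1"],
    ["layers.2", "layers.3", "norm", "output"],
]

# Sanity check: every module name appears in exactly one stage.
flat = [name for stage in module_names_per_stage for name in stage]
assert len(flat) == len(set(flat)), "a module is assigned to more than one stage"
```
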
@@ -427,11 +417,8 @@ def _build_stage_from_modules(
                     setattr(model, module_name, nn.ModuleList())
             # Handle simple module attributes (e.g., "linear", "norm")
             elif module_name not in modules_to_keep:
-                # Replace with Identity or None based on configuration
-                replacement = (
-                    nn.Identity() if use_identity_for_missing_modules else None
-                )
-                setattr(model, module_name, replacement)
+                # Replace with None
+                setattr(model, module_name, None)
 
         stage = PipelineStage(
             model,
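
Replacing pruned submodules with `None` relies on the stage model's forward skipping attributes it no longer owns. A toy sketch of that pattern (not the real model class):

```python
import torch
import torch.nn as nn

class TinyStageModel(nn.Module):
    """Toy model illustrating the None-pruning pattern used by the stage builder."""

    def __init__(self, dim: int = 16, vocab: int = 32):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(2))
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)

    def forward(self, x):
        # Any of these may have been set to None by the stage builder,
        # in which case this stage simply skips that computation.
        if self.tok_embeddings is not None:
            x = self.tok_embeddings(x)
        for layer in self.layers:
            x = layer(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.output is not None:
            x = self.output(x)
        return x

# Emulate what the stage builder does for a stage that only owns the layers:
model = TinyStageModel()
model.tok_embeddings = None
model.norm = None
model.output = None
out = model(torch.randn(4, 16))  # input is activations from the previous stage
```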