File tree (expand/collapse) — 2 files changed: +10 −2
lines changed Original file line number Diff line number Diff line change 77import torch .nn as nn
88from torchtitan .components .ft .config import FaultTolerance as FTConfig
99from torchtitan .distributed .pipeline_parallel import generate_llm_fqn_per_model_part
10+ from torchtitan .tools .logging import logger
1011
1112
1213def module_split (
@@ -98,7 +99,9 @@ def _build_fragment_from_modules(
9899 fragment_idx ,
99100 module_names ,
100101 )
101- print (f"building fragment_idx { fragment_idx } " f"with modules { module_names } " )
102+ logger .info (
103+ f"building fragment_idx { fragment_idx } " f"with modules { module_names } "
104+ )
102105 model_fragments .append (model_fragment )
103106
104107 return model_fragments
@@ -118,13 +121,14 @@ def fragment_llm(
118121
119122 if module_fqns_per_model_fragment == []:
120123 if ft_config .num_fragments == 1 :
124+ logger.info("Created 1 model fragment")
121125 return [model ]
122126
123127 module_fqns_per_model_fragment = generate_llm_fqn_per_model_part (
124128 ft_config .num_fragments , n_layers , input_weight , output_weight
125129 )
126130
127131 model_fragments = module_split (model , module_fqns_per_model_fragment )
128- print (f"Created { len (model_fragments )} model fragments" )
132+ logger . info (f"Created { len (model_fragments )} model fragments" )
129133
130134 return model_fragments
Original file line number Diff line number Diff line change 1616from torch .distributed ._composable .fsdp .fully_shard import FSDPModule
1717from torch .distributed .distributed_c10d import ReduceOp
1818from torchtitan .components .ft .config import FaultTolerance as FTConfig
19+ from torchtitan .tools .logging import logger
1920
2021if importlib .util .find_spec ("torchft" ) is not None :
2122 import torchft as ft
@@ -125,6 +126,9 @@ def maybe_semi_sync_training(
125126 assert (
126127 ft_manager ._manager is not None
127128 ), "FTManager must be enabled to use semi-sync training."
129+ logger .info (
130+ f"using fragment function to split model: { fragment_fn is not None } "
131+ )
128132 if semi_sync_method .lower () == "diloco" :
129133 if fragment_fn :
130134 model_parts = fragment_fn (model , ft_config , n_layers )
You can’t perform that action at this time.
0 commit comments