@@ -146,66 +146,6 @@ def forward_backward_step(
 
         return loss
 
-    def train_step(
-        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
-    ):
-        input_dict, labels = next(data_iterator)
-
-        self.optimizers.zero_grad()
-
-        # Keep these variables local to shorten the code, as these are
-        # the major variables used in the training loop.
-        model_parts = self.model_parts
-        assert len(self.model_parts) == 1
-        # Explicitly convert the flux model to bfloat16 whether or not FSDP is applied.
-        model = self.model_parts[0]
-
-        world_mesh = self.world_mesh
-        parallel_dims = self.parallel_dims
-
-        loss = self.forward_backward_step(input_dict, labels)
-
-        dist_utils.clip_grad_norm_(
-            [p for m in model_parts for p in m.parameters()],
-            self.job_config.training.max_norm,
-            foreach=True,
-            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
-        )
-        self.checkpointer.maybe_wait_for_staging()
-        self.optimizers.step()
-        self.lr_schedulers.step()
-
-        # log metrics
-        if not self.metrics_processor.should_log(self.step):
-            return
-
-        if (
-            parallel_dims.dp_replicate_enabled
-            or parallel_dims.dp_shard_enabled
-            or parallel_dims.cp_enabled
-        ):
-            loss = loss.detach()
-            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
-            global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
-            )
-        else:
-            global_avg_loss = global_max_loss = loss.item()
-
-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
-
-        # Evaluate the model during training
-        if (
-            self.step % self.job_config.eval.eval_freq == 0
-            or self.step == self.job_config.training.steps
-        ):
-            model.eval()
-            # We need to set reshard_after_forward before the last forward pass
-            # so the model weights are sharded the same way for checkpoint saving.
-            self.eval_step()
-            model.train()
-
     def eval_step(self, prompt: str = "A photo of a cat"):
         """
         Evaluate the Flux model.
@@ -247,6 +187,23 @@ def eval_step(self, prompt: str = "A photo of a cat"):
             if isinstance(module, FSDPModule):
                 module.reshard()
 
+    def train_step(
+        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+        super().train_step(data_iterator)
+
+        # Evaluate the model during training
+        if (
+            self.step % self.job_config.eval.eval_freq == 0
+            or self.step == self.job_config.training.steps
+        ):
+            model = self.model_parts[0]
+            model.eval()
+            # We need to set reshard_after_forward before the last forward pass
+            # so the model weights are sharded the same way for checkpoint saving.
+            self.eval_step()
+            model.train()
+
 
 if __name__ == "__main__":
     init_logger()