Commit 3ac950c

Move eval to main train loop and consolidate FLUX train_step (#1266)
After #1238 landed, we can consolidate FLUX's `train_step()` to reuse the main trainer's `train_step()` function by removing the embedded `eval_step()` call. We will replace `eval_step()` with a `Validator` in the future to support various validation methods.
1 parent 8570415 commit 3ac950c
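
The resulting structure is a plain subclassing pattern: the FLUX trainer overrides `train_step()`, delegates to the parent trainer's implementation, and then runs its periodic evaluation. Below is a minimal sketch of that pattern; the `Trainer`/`FluxTrainer` class names, constructor arguments, and placeholder bodies are illustrative stand-ins, not torchtitan's actual API.

```python
# Minimal sketch of the override pattern applied in this commit.
# Class names, constructor arguments, and bodies are illustrative stand-ins.
from typing import Iterable


class Trainer:
    """Stand-in for the main trainer that owns the generic train_step()."""

    def __init__(self) -> None:
        self.step = 0

    def train_step(self, data_iterator: Iterable) -> None:
        # forward/backward, grad clipping, optimizer step, metrics logging, ...
        self.step += 1


class FluxTrainer(Trainer):
    """FLUX trainer: reuse the parent step, then evaluate on a schedule."""

    def __init__(self, eval_freq: int, total_steps: int) -> None:
        super().__init__()
        self.eval_freq = eval_freq
        self.total_steps = total_steps

    def eval_step(self, prompt: str = "A photo of a cat") -> None:
        # image generation / validation would live here
        print(f"eval at step {self.step}: {prompt!r}")

    def train_step(self, data_iterator: Iterable) -> None:
        super().train_step(data_iterator)  # all generic training logic
        if self.step % self.eval_freq == 0 or self.step == self.total_steps:
            self.eval_step()
```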

2 files changed: +23 −65 lines

README.md

Lines changed: 6 additions & 5 deletions

@@ -60,11 +60,12 @@ To accelerate contributions to and innovations around torchtitan, we are hosting
 7. DDP and HSDP
 8. [TorchFT](https://github.com/pytorch/torchft) integration
 9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
-10. Flexible learning rate scheduler (warmup-stable-decay)
-11. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
-12. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
-13. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/)
-14. [Helper scripts](scripts/) to
+10. Gradient accumulation, enabled by giving an additional `--training.global_batch_size` argument in configuration
+11. Flexible learning rate scheduler (warmup-stable-decay)
+12. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
+13. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
+14. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/)
+15. [Helper scripts](scripts/) to
    - download tokenizers from Hugging Face
    - convert original Llama 3 checkpoints into the expected DCP format
    - estimate FSDP/HSDP memory usage without materializing the model
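
On the gradient accumulation item above: the usual arithmetic is that the number of accumulation micro-steps per optimizer step is derived from the requested global batch size, the per-rank (local) batch size, and the data-parallel degree. The sketch below illustrates that relationship and the accumulate-then-step loop; the variable names and exact division rule are assumptions for illustration, not a description of torchtitan's implementation.

```python
# Illustrative sketch of gradient accumulation driven by a global batch size.
# Names (local_batch_size, dp_degree, ...) are assumptions, not torchtitan's API.


def accumulation_steps(global_batch_size: int, local_batch_size: int, dp_degree: int) -> int:
    """Micro-steps each rank runs before one optimizer step."""
    per_optimizer_step = local_batch_size * dp_degree
    assert global_batch_size % per_optimizer_step == 0
    return global_batch_size // per_optimizer_step


def train_one_global_batch(model, optimizer, microbatches, loss_fn, n_accum: int) -> None:
    optimizer.zero_grad()
    for inputs, labels in microbatches:                  # n_accum micro-batches per rank
        loss = loss_fn(model(inputs), labels) / n_accum  # average over micro-steps
        loss.backward()                                  # gradients accumulate in .grad
    optimizer.step()


# Example: global batch 256, local batch 8, 8 data-parallel ranks -> 4 micro-steps
print(accumulation_steps(256, 8, 8))  # 4
```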

torchtitan/experiments/flux/train.py

Lines changed: 17 additions & 60 deletions

@@ -146,66 +146,6 @@ def forward_backward_step(
 
         return loss
 
-    def train_step(
-        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
-    ):
-        input_dict, labels = next(data_iterator)
-
-        self.optimizers.zero_grad()
-
-        # Keep these variables local to shorten the code as these are
-        # the major variables that are used in the training loop.
-        model_parts = self.model_parts
-        assert len(self.model_parts) == 1
-        # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not
-        model = self.model_parts[0]
-
-        world_mesh = self.world_mesh
-        parallel_dims = self.parallel_dims
-
-        loss = self.forward_backward_step(input_dict, labels)
-
-        dist_utils.clip_grad_norm_(
-            [p for m in model_parts for p in m.parameters()],
-            self.job_config.training.max_norm,
-            foreach=True,
-            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
-        )
-        self.checkpointer.maybe_wait_for_staging()
-        self.optimizers.step()
-        self.lr_schedulers.step()
-
-        # log metrics
-        if not self.metrics_processor.should_log(self.step):
-            return
-
-        if (
-            parallel_dims.dp_replicate_enabled
-            or parallel_dims.dp_shard_enabled
-            or parallel_dims.cp_enabled
-        ):
-            loss = loss.detach()
-            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
-            global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
-            )
-        else:
-            global_avg_loss = global_max_loss = loss.item()
-
-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
-
-        # Evaluate the model during training
-        if (
-            self.step % self.job_config.eval.eval_freq == 0
-            or self.step == self.job_config.training.steps
-        ):
-            model.eval()
-            # We need to set reshard_after_forward before last forward pass.
-            # So the model wieghts are sharded the same way for checkpoint saving.
-            self.eval_step()
-            model.train()
-
     def eval_step(self, prompt: str = "A photo of a cat"):
         """
         Evaluate the Flux model.
@@ -247,6 +187,23 @@ def eval_step(self, prompt: str = "A photo of a cat"):
             if isinstance(module, FSDPModule):
                 module.reshard()
 
+    def train_step(
+        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+        super().train_step(data_iterator)
+
+        # Evaluate the model during training
+        if (
+            self.step % self.job_config.eval.eval_freq == 0
+            or self.step == self.job_config.training.steps
+        ):
+            model = self.model_parts[0]
+            model.eval()
+            # We need to set reshard_after_forward before last forward pass.
+            # So the model wieghts are sharded the same way for checkpoint saving.
+            self.eval_step()
+            model.train()
+
 
 if __name__ == "__main__":
     init_logger()
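
The commit message mentions replacing `eval_step()` with a `Validator` later. No such component exists in this commit; purely as a hypothetical sketch of that direction, a validator hook could look like the following, where the protocol name, method signatures, and example class are invented for illustration.

```python
# Hypothetical sketch only: no Validator exists in this commit.
# The protocol name, method signatures, and example class are invented for illustration.
from typing import Protocol


class Validator(Protocol):
    def should_run(self, step: int, total_steps: int) -> bool: ...
    def run(self, step: int) -> None: ...


class PeriodicPromptValidator:
    """Example validator: run a fixed-prompt evaluation every `freq` steps."""

    def __init__(self, freq: int, prompt: str = "A photo of a cat") -> None:
        self.freq = freq
        self.prompt = prompt

    def should_run(self, step: int, total_steps: int) -> bool:
        return step % self.freq == 0 or step == total_steps

    def run(self, step: int) -> None:
        # would switch the model to eval mode, generate/score an image, then restore it
        print(f"[step {step}] validating with prompt {self.prompt!r}")
```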
