Support gradient accumulation #1238

Conversation
Maybe it would also make sense to rename …
fegin left a comment
Thanks for the PR. I suggest that we don't let train_step() be aware of data_iterator. Please see the detail comments.
Also, this PR doesn't change the parallelization, which is not correct. We will have to call set_requires_gradient_sync if FSDP is applied. We can raise an exception if DDP is used and accumulation_steps > 1 for now.
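For context, the pattern fegin refers to, skipping gradient synchronization on all but the last microbatch, can be sketched as follows. `set_requires_gradient_sync` is FSDP2's toggle; the stub module and helper below are purely illustrative so the control flow runs standalone:

```python
class FakeFSDPModule:
    """Stand-in for an FSDP2-wrapped module; records the sync toggles."""
    def __init__(self):
        self.sync_history = []

    def set_requires_gradient_sync(self, requires_sync):
        self.sync_history.append(requires_sync)

def run_accumulated_step(model_parts, accumulation_steps):
    for micro_step in range(accumulation_steps):
        # Only the final microbatch should trigger the gradient
        # reduce-scatter/all-reduce communication.
        is_last = micro_step == accumulation_steps - 1
        for part in model_parts:
            part.set_requires_gradient_sync(is_last)
        # ... forward_backward_step(inputs, labels) would go here ...

model = FakeFSDPModule()
run_accumulated_step([model], accumulation_steps=4)
print(model.sync_history)  # [False, False, False, True]
```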
torchtitan/train.py (outdated):

    unwrapped_loss_fn = self.loss_fn

    @functools.wraps(unwrapped_loss_fn)
    def accumulated_loss_fn(*args, **kwargs):
We should just modify build_loss_fn to take accumulation_steps to let the loss function decide the usage.
I'm OK either way.
I think being more explicit about grad accumulation handling doesn't look bad.
Also if we go with explicit global_batch_size and implicit grad_accu_steps, then we'll need to do another check & computation in the loss function.
I moved the wrapping functionality to torchtitan.components.loss, called it rescale_accumulated_loss. Not quite like what you wanted, but that way we can re-use the Trainer.gradient_accumulation_step value more easily.
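A minimal sketch of what such a rescaling wrapper might look like (illustrative only; the actual `rescale_accumulated_loss` in `torchtitan.components.loss` may differ in signature and details):

```python
import functools

def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps):
    """Divide a mean-reduced loss by the number of gradient accumulation
    steps, so that summing the microbatch gradients reproduces the
    gradient of a mean over the full (global) batch."""
    @functools.wraps(unwrapped_loss_fn)
    def accumulated_loss_fn(*args, **kwargs):
        return unwrapped_loss_fn(*args, **kwargs) / accumulation_steps
    return accumulated_loss_fn

# Plain-number stand-in for a mean-reduced loss function:
mse = lambda pred, target: (pred - target) ** 2
scaled = rescale_accumulated_loss(mse, 4)
print(scaled(3.0, 1.0))  # 1.0  (= (3 - 1)^2 / 4)
```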
tianyu-l left a comment
Thank you for adding this feature!
I left several comments. Please see if they make sense.
torchtitan/config_manager.py (outdated):

        loaded from this path instead of downloaded.
        """

    batch_size: int = 8
yeah let's call it local_batch_size
Did a rename across the codebase wherever JobConfig.training.batch_size or --training.batch_size was used. Not sure how you'd like me to handle the compatibility breakage this introduces.
torchtitan/train.py (outdated):

    if job_config.training.global_batch_size < 0:
        job_config.training.global_batch_size = (
            job_config.training.batch_size * dp_degree
        )
    assert job_config.training.global_batch_size > 0
    assert (
        job_config.training.global_batch_size
        % (job_config.training.batch_size * dp_degree)
        == 0
    ), (
        f"global batch size must be multiple of local batch size times "
        f"data-parallel degree ({job_config.training.global_batch_size} "
        f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
    )

    self.gradient_accumulation_steps = job_config.training.global_batch_size // (
        job_config.training.batch_size * dp_degree
    )
    assert self.gradient_accumulation_steps > 0
nit comment
Suggested change, replacing:

    if job_config.training.global_batch_size < 0:
        job_config.training.global_batch_size = (
            job_config.training.batch_size * dp_degree
        )
    assert job_config.training.global_batch_size > 0
    assert (
        job_config.training.global_batch_size
        % (job_config.training.batch_size * dp_degree)
        == 0
    ), (
        f"global batch size must be multiple of local batch size times "
        f"data-parallel degree ({job_config.training.global_batch_size} "
        f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
    )
    self.gradient_accumulation_steps = job_config.training.global_batch_size // (
        job_config.training.batch_size * dp_degree
    )
    assert self.gradient_accumulation_steps > 0

with:

    global_batch_size = job_config.training.global_batch_size
    if global_batch_size < 0:
        global_batch_size = job_config.training.batch_size * dp_degree
        self.gradient_accumulation_steps = 1
    else:
        assert global_batch_size > (job_config.training.batch_size * dp_degree)
        assert (
            job_config.training.global_batch_size
            % (job_config.training.batch_size * dp_degree)
            == 0
        ), (
            f"global batch size must be multiple of local batch size times "
            f"data-parallel degree ({global_batch_size} "
            f"% ({job_config.training.batch_size} * {dp_degree}) != 0)"
        )
        self.gradient_accumulation_steps = global_batch_size // (
            job_config.training.batch_size * dp_degree
        )
Don't really agree with not re-using the code that would become the else case here, but I can still change it to your recommendation. For now, I put the addition of the global_batch_size variable into its own commit, which probably already has the readability improvements you'd like. Also added a comment in the if case noting that this global batch size results in 1 gradient accumulation step.
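Distilled into a standalone function (plain Python; the function name is hypothetical), the derivation under discussion is just:

```python
def derive_grad_accum_steps(global_batch_size, local_batch_size, dp_degree):
    # A negative global batch size means "infer it from local settings",
    # which yields exactly one accumulation step (i.e., no accumulation).
    if global_batch_size < 0:
        global_batch_size = local_batch_size * dp_degree
    assert global_batch_size % (local_batch_size * dp_degree) == 0, (
        "global batch size must be a multiple of local batch size "
        "times data-parallel degree"
    )
    return global_batch_size // (local_batch_size * dp_degree)

print(derive_grad_accum_steps(-1, 8, 2))  # 1
print(derive_grad_accum_steps(32, 8, 2))  # 2
```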
torchtitan/train.py (outdated):

    self.loss_fn = self.train_spec.build_loss_fn(job_config)

    unwrapped_loss_fn = self.loss_fn
Let's put the self.gradient_accumulation_steps derivation code right before here, to group gradient accum logic together as much as possible.
I understand that it is desirable to fail early on infeasible global batch size, even before parallelism and other heavy things are applied. But I'd suggest we prioritize readability. What do you think?
Sounds fair! :)
Moved this.
torchtitan/train.py (outdated):

    # Keep these variables local to shorten the code as these are
    # the major variables that are used in the training loop.
    def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
can we call it forward_backward_step?
Done. By the way, if you'd prefer me to squash these changes into the previous commits, I'd be happy to clean up the commit chain.
torchtitan/train.py (outdated):

    model_parts = self.model_parts
    world_mesh = self.world_mesh
similarly, maybe not worth keeping these two
Done.
torchtitan/components/metrics.py (outdated):

    )
    self.ntokens_since_last_log = 0
    self.data_loading_times = []
    self.accumulated_losses = []
Since it represents a core training concept, rather than directly used for metrics logging, let's put this in Trainer, instead of MetricsProcessor.
Done. Also added the gradient_accumulation_steps attribute to the Trainer's dataclass attributes.
torchtitan/train.py (outdated):

    except StopIteration:
        # If data runs out during gradient accumulation, that
        # entire step will not be executed.
        return True
Instead of explicit return True, can we just call next and let the StopIteration exception propagate to train_step and catch over there?
I initially had it implemented this way, but thought the try block would encapsulate too much code. If anything else raised a StopIteration, it would make debugging much more difficult; hence the minimized try scope.
I would prefer to directly raise StopIteration and let the outer loop catch it. As mentioned in the discussion above, the original design keeps train_step() simple, without data dependency, so there is no other StopIteration afaik. If other places actually raise StopIteration, we should figure that out.
If we really want to avoid ambiguity, we can have a customized next(), like next_batch(), which raises a customized DataDepleteException().
That's considerate. I think it's quite unlikely other places would also raise StopIteration? Maybe microbatching in pipeline parallel? But over there the number of microbatches should be fixed ahead of time.
Anyways, if you think we need to deal with this explicitly, we should catch the StopIteration exception, and raise a customized DataloaderStopIteration exception to be caught by caller, instead of return True.
Went with a combination of these suggestions; a Trainer.next_batch method basically just calls next(data_iterator), but catches and re-raises its StopIteration as a new DataloaderStopIteration.
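The combination described here can be sketched like so (simplified; the real signatures and class placement in `torchtitan/train.py` may differ):

```python
class DataloaderStopIteration(StopIteration):
    """Indicates dataloader exhaustion, distinguishable from any other
    StopIteration that might surface during a training step."""

def next_batch(data_iterator):
    # Thin wrapper around next() that translates exhaustion into the
    # dedicated exception, keeping the try scope as small as possible.
    try:
        return next(data_iterator)
    except StopIteration as e:
        raise DataloaderStopIteration from e

batches = iter([("batch0", "labels0"), ("batch1", "labels1")])
print(next_batch(batches)[0])  # batch0
```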
torchtitan/train.py (outdated):

    self.step += 1
    self.gc_handler.run(self.step)
    - self.train_step(inputs, labels)
    + data_ran_out = self.train_step(data_iterator)
we can catch the StopIteration here and do different treatment on self.checkpointer.save in try vs. catch.
See above.
Has been changed, but we now simply break in case of the DataloaderStopIteration to avoid changing the checkpointing logic.
This does change the general logic (e.g., torch_profiler and memory_profiler won't be stepped anymore) compared to the previous code, but it is a bit nicer to read than adding an extra variable check in the while condition, IMO.
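The resulting control flow, reduced to its skeleton (stubbed so it runs standalone; the real loop also steps the gc handler, profilers, checkpointing, etc.):

```python
class DataloaderStopIteration(StopIteration):
    """Dataloader exhaustion signal raised by the training step."""

def train_step(data_iterator):
    # Stub: consume one batch, translating exhaustion into the signal.
    try:
        next(data_iterator)
    except StopIteration as e:
        raise DataloaderStopIteration from e

def run_loop(batches, total_steps):
    data_iterator = iter(batches)
    completed = 0
    step = 0
    while step < total_steps:
        step += 1
        try:
            train_step(data_iterator)
        except DataloaderStopIteration:
            break  # bail out instead of touching checkpointing logic
        completed += 1
    return completed

print(run_loop(range(3), 10))  # 3
```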
tianyu-l left a comment
> Also, this PR doesn't change the parallelization, which is not correct. We will have to call set_requires_gradient_sync if FSDP is applied.
@fegin For background please see #292 (comment)
I think for us we don't want the potential memory overhead and code complexity, although it can save some communications which could've been hidden anyway.
torchtitan/train.py (outdated):

    def train_step(
        self,
        data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
    ) -> bool | None:
We should just return bool and change all the other returns to return False to keep the semantics consistent. This should be changed if we still keep the return value as the design option, but I prefer try/catch. See the response below.
Reverted this/refactored to try-catch solution as per other discussions. Return type is back to implicit None.
The review order looks pretty confusing, lol. The summary of some big discussions: …
cc. @tianyu-l
hey @janEbert how about we work a bit more on the PR. Sorry for the confusion in the reviews. I think we have agreed on the direction.
Please also add a test case in https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests.py
@fegin said:
> TorchTitan currently doesn't perform force checkpoint if data is depleted. We can fix this but I suggest that we don't do this in this PR. (See pytorch#1238 (comment).)
I believe I have incorporated all the feedback. Let me know how you like the changes. FYI, I'm currently at a conference and on vacation from Friday, so it would be great to get this done before Friday, even if I may only sporadically find time. :)
Rebased because of …
tianyu-l left a comment
Looks almost good! Please address final comments.
Also, the addition of forward_backward_step breaks the FLUX model training.
Could you help refactor the train_step to forward_backward_step over there? Probably just:
- remove the optimizer.zero_grad
- remove https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L152-L180
- return loss

For the eval step https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train.py#L182: it should be done in Trainer.train(), but since we are not using grad accumulation in FLUX training, it is OK to leave it in forward_backward_step to accelerate landing of this PR, as long as CI tests pass. @wwwjn and I will work together on fixing it later.
tests/integration_tests.py (outdated):

    OverrideDefinitions(
        [
            [
                # Default local batch size = 8, and `ngpu=2`, so
Let's explicitly specify the local batch size as well, in case some future PR changes the default without changing the test here.
Done.
torchtitan/config_manager.py (outdated):

    """
    The size of each pipeline parallel microbatch (default 1).
    - This value is used to compute the total number of microbatches by dividing batch_size with
    + This value is used to compute the total number of microbatches by dividing local batch_size with
Suggested change:

    - This value is used to compute the total number of microbatches by dividing local batch_size with
    + This value is used to compute the total number of microbatches by dividing local_batch_size with
Great catch! I didn't see the underscore on my dirty screen lol
torchtitan/train.py (outdated):

    class DataloaderStopIteration(StopIteration):
        """An exception that indicates dataloader exhaustion."""

        pass
Done.
torchtitan/train.py (outdated):

    try:
        self.train_step(data_iterator)
    except DataloaderStopIteration:
        logger.info("Ran out of data; last step was canceled.")
Suggested change:

    - logger.info("Ran out of data; last step was canceled.")
    + logger.warning("Ran out of data; last step was canceled.")
Done.
torchtitan/train.py (outdated):

    # Keep these variables local to shorten the code as these are
    # the major variables that are used in the training loop.
    def next_batch(
This function sounds less necessary, especially when we already have dataloader and batch_generator. Given how short it is, it seems not too bad just running the try-catch in train_step?
To me, it makes the train_step look cleaner and it was nice to have it re-usable for the FLUX refactor. Does that change your mind? :)
I was thinking to patch the data iterator's __next__ method on-the-fly, to ensure the DataloaderStopIteration is raised, but didn't want to put too much black magic. It would require modifying the ParallelAwareDataloader.__iter__ method to apply the patch to the returned iterator. What do you think of that option?
I would suggest to keep the current implementation. Monkey patching is usually not a good idea. Also agree this function makes train_step cleaner.
Some future benefit, we may want to do data loader pipelining, which overlaps the to("cuda") with the computation. This function gives us a good place to implement it.
        job_config.training.local_batch_size * dp_degree
    )
    assert self.gradient_accumulation_steps > 0
    self.loss_fn = rescale_accumulated_loss(
This is a comment, not a suggestion:
The code sounds to me like it assumes the loss function must perform a "mean" reduction, instead of the "sum" reduction also available in e.g. cross entropy loss.
But I believe this assumption is also made in PyTorch DDP, FSDP, PP, and is universally accepted as the default now. So I think it's OK.
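A quick numeric sanity check of that assumption (plain Python, no torch): with mean-reduced microbatch losses, dividing each by the number of accumulation steps makes the accumulated total match a single full-batch mean; with a "sum" reduction, the same rescaling would not recover the full-batch sum:

```python
full_batch = [1.0, 3.0, 5.0, 7.0]
micro_batches = [full_batch[:2], full_batch[2:]]
steps = len(micro_batches)

# Mean reduction: rescaled microbatch means accumulate to the full mean.
full_mean = sum(full_batch) / len(full_batch)
accumulated = sum(sum(mb) / len(mb) / steps for mb in micro_batches)
assert abs(full_mean - accumulated) < 1e-12  # both 4.0

# Sum reduction: the division by `steps` has nothing to cancel it.
full_sum = sum(full_batch)                                       # 16.0
accumulated_sum = sum(sum(mb) / steps for mb in micro_batches)   # 8.0
assert accumulated_sum != full_sum
```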
Good point. I added a docstring to the function to explicitly mention this.
Yes, CP also assumes mean. A docstring will be nice, thanks!
PTAL.

Agree! The current change on the FLUX side looks good to me. In the future I will also test grad accumulation w/ FLUX. Ideally in the future I will move …
fegin left a comment
LGTM. There are some typing nits, but overall the implementation is clean.
torchtitan/train.py (outdated):

    def forward_backward_step(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ):
Can we type the return value?
Done.
    def train_step(
        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
    ):
ditto, can we type the return value?
This returns an implicit None, so probably this one should remain as it is (i.e., not add the `-> None`)?
ye, it's minor, but I think it is generally better to explicitly type, even for None: https://peps.python.org/pep-0484/#using-none
I can't run CI because "This branch has conflicts that must be resolved".
Would you please rebase?
Had one more comment on the next_batch function. See if you agree.
torchtitan/train.py (outdated):

    # Keep these variables local to shorten the code as these are
    # the major variables that are used in the training loop.
    def next_batch(
I still think it's not necessary to create next_batch.
For the purpose of transforming the StopIteration exception, can we just do it in batch_generator? E.g., not doing a for loop, but while True and try-catch.

> it was nice to have it re-usable for the FLUX refactor.

I think FLUX can reuse all of train_step. For correctness right now, we can overload FluxTrainer.train_step() by calling super().train_step() and then doing eval.

> Some future benefit, we may want to do data loader pipelining, which overlaps the to("cuda") with the computation. This function gives us a good place to implement it.

For future benefits, let's add them only when the future comes.
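tianyu-l's suggestion (a while True/try-except inside batch_generator) could look roughly like this. One subtlety worth flagging: in this sketch the exception deliberately does not subclass StopIteration, because under PEP 479 a StopIteration escaping a generator body is replaced by a RuntimeError, which would mask the signal:

```python
class DataloaderStopIteration(Exception):
    """Dataloader exhaustion. Not a StopIteration subclass in this
    sketch: PEP 479 turns any StopIteration escaping a generator frame
    into RuntimeError, hiding the signal from the caller."""

def batch_generator(data_iterable):
    # while True + try/except instead of a plain for loop, so iterator
    # exhaustion can be translated into the dedicated exception.
    data_iterator = iter(data_iterable)
    while True:
        try:
            batch = next(data_iterator)
        except StopIteration as e:
            raise DataloaderStopIteration from e
        yield batch

gen = batch_generator([("inputs0", "labels0")])
print(next(gen)[0])  # inputs0
```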
Awesome idea!
PTAL.

One other thing I noticed and changed just now, which was an artifact from earlier versions: we don't need to keep the …
Previously `int | None`. Makes it possible to obtain the automatic calculation of it when it has already been set in a TOML config.
@fegin said:
> TorchTitan currently doesn't perform force checkpoint if data is depleted. We can fix this but I suggest that we don't do this in this PR. (See pytorch#1238 (comment).)
I.e., a new `DataloaderStopIteration` that inherits from `StopIteration`. Accordingly, we no longer return an optional `bool` to indicate depletion, and the remainder of the code is adapted to catch the new exception instead.
This concerns only renaming:
- `--training.batch_size` to `--training.local_batch_size`
- `job_config.training.batch_size` to `job_config.training.local_batch_size`
I.e., the method in `Trainer`.
Instead use a new helper variable `global_batch_size` for all logic. Improves readability.
Improve readability.
These were only used in 1 or 2 locations each.
... from `MetricsProcessor`.
... toward `forward_backward_step` design.
We now raise the `DataloaderStopIteration` from inside the `batch_generator` generator method. `next_batch` can thus be removed, as its only remaining purpose was raising the custom exception upon iterator exhaustion.
Move from dataclass attributes to method-local variable.
wwwjn left a comment
Thanks for doing such great work!
@janEbert Thank you very much for the elegant work!!!
Very kind words, thank you! Thank you also for all the patience and great reviews!
After #1238 landed, we could consolidate FLUX train_step() to reuse the main trainer's `train_step` function by removing the `eval_step()`. We will replace eval_step() with a `Validator` in the future to perform various validation methods.
First, the batched backward calculation is refactored into its own function. Then, gradient accumulation is implemented by moving the data iterator inside the `train_step` method and consuming data from it as necessary. I added some extra handling for non-infinite data iterators, but if you dislike that additional complexity, I can remove it to simplify the code.

The feature is enabled by giving an additional `--training.global_batch_size`, which has a sensible default of 1 gradient accumulation step (i.e., no actual accumulation).

@tianyu-l thanks for the ping.
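As a usage illustration (the concrete values are hypothetical), the feature would be driven from the training config, e.g.:

```toml
[training]
local_batch_size = 8     # per-rank microbatch size
global_batch_size = 32   # with dp_degree = 2: 32 / (8 * 2) = 2 accumulation steps
```

or, equivalently, via the `--training.local_batch_size` and `--training.global_batch_size` command-line overrides.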