[Scheduler] Add support for cosine and wsd scheduler #938
Conversation
Here are running examples for the three schedulers.

bash run_train.sh --optimizer.scheduler=linear \
    --training.steps=50 \
    --training.warmup_steps=5 \
    --optimizer.min_lr_ratio=0.1

bash run_train.sh --optimizer.scheduler=cosine \
    --training.steps=50 \
    --training.warmup_steps=5 \
    --optimizer.min_lr_ratio=0.1

bash run_train.sh --optimizer.scheduler=wsd \
    --training.steps=50 \
    --training.warmup_steps=5 \
    --optimizer.min_lr_ratio=0.1
tianyu-l
left a comment
In fact, I think the three can be unified.
We only need to define warmup_ratio, decay_ratio, lr_decay_type (one of linear, sqrt, cosine), and maybe lr_min (calling it a ratio again would sound confusing; we can explain in the helper message that it's a ratio) to achieve everything.
We can explain in the helper message, or in a doc, how to use them to achieve various combinations, including the three you explicitly wrote today.
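For illustration, here is a hedged sketch of how those four knobs could reproduce the three schedulers run above. The option names (warmup_ratio, decay_ratio, lr_decay_type, lr_min) follow the proposal in this comment, not necessarily the merged API.

```python
# Hypothetical mapping only: names and semantics are assumptions based on this comment.
unified_configs = {
    # linear: decay over the entire post-warmup run
    "linear": dict(warmup_ratio=0.1, decay_ratio=0.9, lr_decay_type="linear", lr_min=0.1),
    # cosine: same shape, but with cosine decay
    "cosine": dict(warmup_ratio=0.1, decay_ratio=0.9, lr_decay_type="cosine", lr_min=0.1),
    # wsd: hold the LR stable, then decay only over the last 20% of steps
    "wsd": dict(warmup_ratio=0.1, decay_ratio=0.2, lr_decay_type="sqrt", lr_min=0.1),
}
```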
Hi @yzhangcs - thanks for the PR. That said, I believe the WSD scheduler may be useful in some cases (MSFT had a paper on this a long time ago, and I used it back in AI competitions), and the 1-sqrt decay seems to be newer work than the above. Also, we used to display the lr but it was removed b/c most people didn't need it in the display... thus I would recommend that this display be optional/configurable.
@tianyu-l Hi, just updated the things you mentioned.
@lessw2020 Thank you for your comments. I'm wondering if it's OK to add some hints and the paper link https://arxiv.org/abs/2310.07831 in config_manager?
Agreed. I just added the display to make sure the decay is right.
I left some inline comments. Below are some general comments.
Regarding file length/complexity:
I think we should separate things into a new file called lr_scheduler.py. This can be in a separate PR.
Regarding warmup steps:
Since warmup behavior is closer to lr_scheduler, should we move training.warmup_steps also to the optimizer section (or even consider creating an lr_scheduler section)?
Regarding logging the lr:
I think we should still log the lr to TensorBoard / WandB, maybe after #945.
It can be called from a new get_lr function of LRSchedulersContainer, so that people can modify/inherit it to adapt to desired behaviors.
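As a rough sketch of the suggested hook (the container internals and method name here are assumptions based on this thread, not the merged torchtitan code):

```python
# Hypothetical sketch of a get_lr hook on the schedulers container so the
# metrics logger can query current learning rates; attribute names are assumed.
class LRSchedulersContainer:
    def __init__(self, schedulers):
        # schedulers: list of torch.optim.lr_scheduler.LambdaLR (or similar)
        self.schedulers = schedulers

    def get_lr(self) -> list[float]:
        # get_last_lr() returns one LR per param group; flatten across schedulers.
        # Subclasses can override this to adapt the reporting to their needs.
        return [lr for sched in self.schedulers for lr in sched.get_last_lr()]
```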
@tianyu-l Hello, I've just updated this PR based on your suggestions. Could you please review it again to see if there's anything I might have missed?
Checks for:

bash run_train.sh --scheduler.warmup_steps=4 --scheduler.decay_ratio=0.9 --scheduler.decay_type=linear --training.steps=40
bash run_train.sh --scheduler.warmup_steps=4 --scheduler.decay_ratio=0.9 --scheduler.decay_type=cosine --training.steps=40
bash run_train.sh --scheduler.warmup_steps=4 --scheduler.decay_ratio=0.9 --scheduler.decay_type=sqrt --training.steps=40

Checks for:

bash run_train.sh --scheduler.warmup_steps=4 --scheduler.decay_ratio=0.2 --scheduler.decay_type=linear --training.steps=40
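For orientation, assuming the decay phase covers the last `decay_ratio` fraction of training (an inference from these flags, not necessarily the exact merged formula), the phase lengths work out roughly as follows:

```python
# Assumed phase arithmetic; treat as a reading aid for the commands above.
steps, warmup_steps = 40, 4
for decay_ratio in (0.9, 0.2):
    warmup_stable_steps = round(steps * (1 - decay_ratio))  # decay assumed to start here
    print(decay_ratio, warmup_stable_steps, steps - warmup_stable_steps)
# decay_ratio=0.9 -> stable phase ends at step 4 (decay over the last 36 steps)
# decay_ratio=0.2 -> stable phase ends at step 32 (decay over the last 8 steps)
```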
tianyu-l
left a comment
Looks good to me! I suggest renaming scheduler to lr_scheduler. Also, please do a final rebase.
@tianyu-l Hi, thank you for the feedback. Just fixed the issues.
tianyu-l
left a comment
Thank you!
Could you rebase onto the latest torchtitan main? I'm seeing changes from the already-merged #945 as part of this PR (e.g., train.py).
…ts (#936) Two very minor changes required by Meta legal as part of adding two new datasets: 1) license verbiage update in the readme; 2) copyright header change in BSD-License.
…940) * People ask about the FSDP2 equivalent of no_sync; that's `set_requires_gradient_sync`. * ignored_params was recently implemented and people have started using it already; update the doc.
This is similar in spirit to [PR_944](#944) (cc @lkhphuc) but takes a slightly different approach. Problem - users that turn on PP training by default will get -1 for their loss. This is b/c by default, rank 0 is the only one logged. However, for *most* PP schedules, the loss is output on the last rank. Thus, users see -1 for loss and it's a bad/confusing experience. This PR adds a check to review the current PP schedule (b/c for VBlocks, loss is returned on rank 0) and, if it is a last-rank-loss schedule, it then checks that the first rank of the last stage is visible in the LOG_RANK environment variable. If not, it warns the user, using red for the warning if color is enabled, and highlights the rank they should add in yellow: <img width="1236" alt="Screenshot 2025-03-07 at 11 51 46 AM" src="https://github.com/user-attachments/assets/02b18870-90bb-4cfb-89c1-3e92d2fb9bfb" /> Note that I attempted to then modify LOG_RANK to add the missing last rank... but it has no effect. This is b/c the --log_rank_filter passed into torchrun is fixed, and thus the env has no effect. We can fix this by moving to our own filtering via Python log filtering (thanks to @d4l3k for this idea), and then it would auto-update. The tradeoff is that we have to init distributed first (to understand the ranks), meaning that at launch there's a bit of delay before the first logging. From there, NCCL warnings are not suppressed b/c they are emitted from a .cpp file whereas torchrun filtering controls that... so we get some additional console spam. This PR thus sticks to a simple warning with a red highlight (assuming color is on) and tells the user how to fix it.
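A minimal, hedged sketch of the kind of check described above (the rank layout and function name are assumptions; the real implementation lives in torchtitan's training entry point):

```python
import logging
import os

logger = logging.getLogger(__name__)

def warn_if_loss_rank_not_logged(world_size: int, pp_degree: int) -> None:
    # Assumption: PP is the outermost mesh dim, so the last stage starts at this rank.
    first_rank_of_last_stage = world_size - world_size // pp_degree
    logged = {int(r) for r in os.environ.get("LOG_RANK", "0").split(",") if r}
    if first_rank_of_last_stage not in logged:
        logger.warning(
            "Loss is reported on rank %d, which is not in LOG_RANK=%s; "
            "add it to see the real loss instead of -1.",
            first_rank_of_last_stage,
            sorted(logged),
        )
```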
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #945 MetricsLogger should be a component, as its role is similar to CheckpointManager, which provides some functions and has its own states. More importantly, users may want to customize the metrics. Make it a component that can be customized through TrainSpec. Change the name of `MetricsLogger` to `MetricsProcessor`, as it not only logs metrics but also processes them.
@tianyu-l Hey, just wanted to check if everything looks good on your end. I'm still getting familiar with rebase, so I want to make sure I did it correctly.
@yzhangcs Could you help resolve it, e.g. by setting … Also, the warning is called on every single LR scheduler step. Let's move the check outside of this function and do it only once when setting up the scheduler function.
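A minimal sketch of the suggested restructuring, with hypothetical names (the actual check and scheduler builder live in torchtitan's lr_scheduler code):

```python
import warnings
from torch.optim.lr_scheduler import LambdaLR

def build_lr_scheduler(optimizer, warmup_steps, warmup_stable_steps, total_steps):
    # One-time sanity check at setup, instead of inside the per-step lambda,
    # so any warning fires once rather than on every scheduler step.
    if warmup_stable_steps <= warmup_steps:
        warnings.warn("warmup_stable_steps should be greater than warmup_steps")

    def lr_lambda(step):  # called every step; keep it free of checks/warnings
        if step < warmup_steps:
            return (step + 1) / max(warmup_steps, 1)
        if step < warmup_stable_steps:
            return 1.0
        progress = (step - warmup_stable_steps) / max(total_steps - warmup_stable_steps, 1)
        return max(1.0 - progress, 0.0)  # placeholder linear decay

    return LambdaLR(optimizer, lr_lambda)
```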
@tianyu-l Sure, I will create a new PR to fix it.
### What does this PR do?
Fix some minor issues in PR #938:
1. Fix the `decay_ratio` in [debug_model.toml](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/train_configs/debug_model.toml), ensuring that `warmup_stable_steps` > `warmup_steps`.
2. Make sure `warmup_stable_steps` is rounded to an integer.
3. Move the lr check into `JobConfig`.
### What does this PR do?
This PR introduces support for cosine and WSD schedulers. The cosine scheduler is widely used in LLM training (e.g., Pythia, OLMo, Llama), while the WSD scheduler, introduced by [MiniCPM](https://arxiv.org/abs/2404.06395), features a three-stage learning rate schedule: warmup, stable, and decay. The stable stage keeps the learning rate constant, beneficial for continual pretraining and flexible training budget adjustments.

### Why this PR is necessary
The cosine scheduler is a standard for LLM training, and the WSD scheduler addresses specific needs like stable learning rates during training. This PR also supports three decay types: linear (MiniCPM's original approach), cosine (used in [hf transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L428)), and 1-sqrt (found optimal in [this paper](https://arxiv.org/html/2408.11029v1)). These additions provide flexibility and improved performance for diverse training scenarios.

---------

Co-authored-by: Less Wright <[email protected]>
Co-authored-by: Wei (Will) Feng <[email protected]>
Co-authored-by: Chien-Chin Huang <[email protected]>
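For readers, here is a rough, self-contained sketch of the schedule shape described above. This is a hedged illustration in plain Python; parameter names and exact boundary handling are assumptions, not the merged torchtitan implementation.

```python
import math

def wsd_lr_factor(step, warmup_steps, warmup_stable_steps, total_steps,
                  decay_type="linear", min_lr=0.0):
    """Multiplicative LR factor for a warmup-stable-decay schedule."""
    if step < warmup_steps:                       # 1) linear warmup
        return (step + 1) / max(warmup_steps, 1)
    if step < warmup_stable_steps:                # 2) stable plateau at peak LR
        return 1.0
    # 3) decay from 1.0 toward min_lr over the remaining steps
    progress = min((step - warmup_stable_steps)
                   / max(total_steps - warmup_stable_steps, 1), 1.0)
    if decay_type == "linear":                    # MiniCPM's original choice
        factor = 1.0 - progress
    elif decay_type == "cosine":                  # the cosine decay variant
        factor = 0.5 * (1.0 + math.cos(math.pi * progress))
    elif decay_type == "sqrt":                    # the "1-sqrt" decay
        factor = 1.0 - math.sqrt(progress)
    else:
        raise ValueError(f"unknown decay_type: {decay_type}")
    return min_lr + (1.0 - min_lr) * factor
```

Setting warmup_stable_steps equal to warmup_steps recovers a plain warmup-then-decay schedule, which is how the cosine and linear variants can be expressed in the same shape.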
This PR adds learning rate logging. There was a previous attempt to implement this in an [earlier PR](#937), but that one was ultimately **closed**. This version ensures that LR logging works properly; I verified it using the WSD scheduler that was recently added in [another PR](#938). <img width="1842" height="730" alt="image" src="https://github.com/user-attachments/assets/8f23674a-d689-4cc2-9d9b-30bff4e63f3b" />
One design consideration here is that torchtitan supports multiple optimizers and learning rate schedules, each potentially having its own LR. However, in practice, I believe that 99.9999% of use cases will use a single LR. Given that, the logging works as follows:
- If there is only one learning rate, it gets logged directly under the main charts as `lr`.
- If there are multiple learning rates, they are logged under a separate section, each with its corresponding label.
Alternatively, we could have ignored the multi-LR case and always logged a single LR, but I prefer this approach since it handles both scenarios robustly with minimal extra code. Happy to adjust if others have a strong preference for simplicity over robustness.
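A hedged sketch of the branching described above (the metric keys and section naming are assumptions, not the merged code):

```python
def lr_metrics(lrs: list[float]) -> dict[str, float]:
    # Single LR: log it under the main charts as "lr".
    if len(lrs) == 1:
        return {"lr": lrs[0]}
    # Multiple LRs: log each under a separate, labeled entry.
    return {f"lr/group_{i}": lr for i, lr in enumerate(lrs)}
```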







