Generate: replace breaks by a loop condition #29662
Conversation
```diff
-        if this_peer_finished and not synced_gpus:
-            break
         this_peer_finished = unfinished_sequences.max() == 0
```
The previous version is also a data-dependent control flow, so this change is for torch.compile readiness :)
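(For context, a minimal sketch, not from this PR, of why a value-dependent `break` is a problem for `torch.compile(..., fullgraph=True)`; the function below is purely illustrative.)

```python
import torch

@torch.compile(fullgraph=True)
def countdown(x):
    for _ in range(10):
        x = x - 1
        # Branching on tensor *values* is data-dependent control flow: Dynamo
        # would need a graph break here, and fullgraph=True turns that into an
        # error instead of silently falling back to eager execution.
        if x.max() == 0:
            break
    return x

# countdown(torch.tensor([3]))  # raises under fullgraph=True because of the branch above
```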
FYI @zucchini-nlp (the stopping criteria solution did not preserve ZeRO stage 3 support)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for working on this - so much cleaner 🤩
        
          
src/transformers/generation/utils.py (Outdated)
```python
else:
    if this_peer_finished:
        return False
return True
```
Or actually, we can just do
```diff
-else:
-    if this_peer_finished:
-        return False
-return True
+return not this_peer_finished
```
This solution can return False when synced_gpus is True and this_peer_finished is True, which is not intended -- this_peer_finished has to be True in all distributed devices when synced_gpus is True 🤗
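For readers following along, a minimal sketch of what such a loop-condition helper can look like, assuming the `synced_gpus` semantics described above; the name `_has_unfinished_sequences` and the exact signature are illustrative rather than a quote of the merged code:

```python
import torch
import torch.distributed as dist

def _has_unfinished_sequences(this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
    """Return True while the generation loop should keep running (illustrative sketch)."""
    if synced_gpus:
        # Under ZeRO stage 3 every peer must run the same number of forward passes,
        # otherwise the collective ops deadlock: send 0.0 if this peer finished,
        # 1.0 otherwise, and only stop once the sum over all peers is 0.0.
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        if this_peer_finished_flag.item() == 0.0:
            return False
    elif this_peer_finished:
        return False
    return True
```

Collapsing this to `return not this_peer_finished` drops the `synced_gpus` branch, which is exactly the case the comment above is pointing out.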
Co-authored-by: amyeroberts <[email protected]>
What does this PR do?
Pulled from the `torch.compile(..., fullgraph=True)` draft PR: #29374

It replaces the `break`s that exit the endless generation loop with an equivalent function that returns `False` when it should stop generating, while preserving ZeRO stage 3 support. It is not only an improvement in terms of code reuse, but also a hard requirement to enable `torch.compile(..., fullgraph=True)`: `break` statements and other data-dependent control flow are not supported.
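A self-contained toy sketch of the refactor the description outlines (all names and the fake stopping logic are illustrative; the real loop runs the model and stopping criteria):

```python
import torch

def has_unfinished(this_peer_finished: bool, synced_gpus: bool) -> bool:
    # Single-process stand-in for the loop-condition helper: with synced_gpus,
    # the real helper also all-reduces the flag across peers before stopping.
    return synced_gpus or not this_peer_finished

def decode_with_break(max_steps: int = 5) -> int:
    """Old shape: endless loop, exits via a data-dependent break."""
    unfinished = torch.ones(2, dtype=torch.long)
    steps, this_peer_finished, synced_gpus = 0, False, False
    while True:
        steps += 1  # stands in for one forward pass / decoding step
        unfinished = unfinished & torch.tensor(steps < max_steps, dtype=torch.long)
        this_peer_finished = bool(unfinished.max() == 0)
        if this_peer_finished and not synced_gpus:
            break
    return steps

def decode_with_condition(max_steps: int = 5) -> int:
    """New shape: the exit condition lives in the while test, no break."""
    unfinished = torch.ones(2, dtype=torch.long)
    steps, this_peer_finished, synced_gpus = 0, False, False
    while has_unfinished(this_peer_finished, synced_gpus):
        steps += 1
        unfinished = unfinished & torch.tensor(steps < max_steps, dtype=torch.long)
        this_peer_finished = bool(unfinished.max() == 0)
    return steps

# Both loop shapes run the same number of decoding steps.
assert decode_with_break() == decode_with_condition() == 5
```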