[Core] support saving and loading of sharded checkpoints #7830
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
BenjaminBossan
left a comment
Always delightful to deal with the from_pretrained code ;)
I don't really have any bigger comments, as this should hopefully work well since it's based on the transformers implementation. Only some smaller comments.
SunMarc
left a comment
Thanks for your work @sayakpaul! Left a suggestion (not a blocker, we can do it afterwards if needed)! No major comments since @BenjaminBossan did a very thorough review already!
Yeah. @yiyixuxu would be the final approver here :)
yiyixuxu
left a comment
thanks for the PR!!
I left some comments and questions :)
src/diffusers/configuration_utils.py (outdated)

```diff
 revision = kwargs.pop("revision", None)
 _ = kwargs.pop("mirror", None)
-subfolder = kwargs.pop("subfolder", None)
+subfolder = kwargs.pop("subfolder", None) or ""
```
why don't we handle it where it fails then, at the place where `subfolder` is passed? We would only need to change one place, no?
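For context (an editor's reading, not stated in the thread): one plausible reason to normalize `subfolder` to `""` early is that downstream `os.path.join`-style calls accept an empty string but fail on `None`. A quick illustration:

```python
import os

# An empty-string subfolder joins cleanly (at most a trailing separator),
# whereas None raises.
print(os.path.join("repo", "subdir"))  # repo/subdir
print(os.path.join("repo", ""))        # repo/ (still a usable path)
# os.path.join("repo", None)           # TypeError
```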
src/diffusers/utils/hub_utils.py (outdated)

```python
raise EnvironmentError(
    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
# This should correspond to a shard index file.
```
why do we need to return something different when we can't find the shard index file?

can we do

```python
try:
    model_file = _get_model_file(...)
    ...
except ...:
    model_file = None
```
@yiyixuxu do the recent changes work for you? (I have run the tests)
yiyixuxu
left a comment
thanks!
I have one question! The rest looks good to me.
src/diffusers/utils/hub_utils.py (outdated)

```python
raise EnvironmentError(
    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
# This should correspond to a shard index file.
```

I guess this is the question I still have: why do we need to return None when we can't find a shard index file, whereas for any other file we can't find we raise an error? Where in the code is this needed?
yiyixuxu
left a comment
Nevermind - I got confused in my last review!
Good to merge!
Yay! Great job @sayakpaul! 🎉
* feat: support saving a model in sharded checkpoints.
* feat: make loading of sharded checkpoints work.
* add tests
* cleanse the loading logic a bit more.
* more resilience while loading from the Hub.
* parallelize shard downloads by using snapshot_download().
* default to a shard size.
* more fix
* Empty-Commit
* debug
* fix
* quality
* more debugging
* fix more
* initial comments from Benjamin
* move certain methods to loading_utils
* add test to check if the correct number of shards are present.
* add a test to check if loading of sharded checkpoints from the Hub is okay
* clarify the unit when passed as an int.
* use hf_hub for sharding.
* remove unnecessary code
* remove unnecessary function
* lucain's comments.
* fixes
* address high-level comments.
* fix test
* subfolder shenanigans.
* Update src/diffusers/utils/hub_utils.py (Co-authored-by: Lucain <[email protected]>)
* Apply suggestions from code review (Co-authored-by: Lucain <[email protected]>)
* remove _huggingface_hub_version as not needed.
* address more feedback.
* add a test for local_files_only=True.
* need hf hub to be at least 0.23.2
* style
* final comment.
* clean up subfolder.
* deal with suffixes in code.
* _add_variant default.
* use weights_name_pattern
* remove add_suffix_keyword
* clean up downloading of sharded ckpts.
* don't return something special when using index.json
* fix more
* don't use bare except
* remove comments and catch the errors better
* fix a couple of things when using is_file()
* empty

---------

Co-authored-by: Lucain <[email protected]>
What does this PR do?
Follow-up of #6396.
This PR adds support for saving a big model's state dict into multiple shards for efficient portability and loading. Adds support for loading the sharded checkpoints, too.
This is much akin to how big models like T5XXL are handled.
Also added a test to ensure that models with `_no_split_modules` specified can be sharded and loaded back to perform inference, with numerical assertions.

Here's a real use case. Consider this `Transformer2DModel` checkpoint: https://huggingface.co/sayakpaul/actual_bigger_transformer/. It was serialized like so:
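(The original snippet was lost in extraction; here is a minimal sketch of the serialization step, assuming the `max_shard_size` argument this PR adds to `save_pretrained`. The small shard size is just to force sharding for illustration.)

```python
from diffusers import Transformer2DModel

# Instantiate a model (default config here, purely for illustration).
model = Transformer2DModel()

# Passing max_shard_size splits the state dict into multiple shard files
# plus an index JSON that maps each weight to its shard.
model.save_pretrained("actual_bigger_transformer", max_shard_size="10MB")
```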
As we can see from the Hub repo, its state dict is sharded. To perform inference with the model, all we have to do is this:
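(Again a minimal sketch in place of the lost snippet, assuming `from_pretrained` resolves the shard index file and fetches all shards transparently, as this PR implements.)

```python
from diffusers import Transformer2DModel

# from_pretrained detects the shard index in the Hub repo, downloads the
# shards, and reassembles the full state dict.
model = Transformer2DModel.from_pretrained("sayakpaul/actual_bigger_transformer")

# Sanity check: total parameter count of the reassembled model.
print(sum(p.numel() for p in model.parameters()))
```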
I have purposefully not added documentation because all of this will become useful once we use this in the context of a full-fledged pipeline execution (up next) :)