[From pretrained] Speed-up loading from cache #2515
Conversation
The documentation is not available anymore as the PR was closed or merged.

Still need to add tests, but apart from this everything should be good!
```diff
@@ -129,3 +134,116 @@ def create_model_card(args, model_name):

     card_path = os.path.join(args.output_dir, "README.md")
     model_card.save(card_path)


+def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
```
All the added files are mostly copied from transformers, with some small changes.
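For context, the transformers helper being borrowed here works roughly like this (a simplified sketch of the idea, not the exact diffusers code): cached files live under a `snapshots/<commit-hash>/` directory in the Hub cache, so the 40-character hex sha can be read back out of a resolved file path.

```python
import re
from typing import Optional


def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None) -> Optional[str]:
    """Return the commit hash encoded in a cached file path, if any."""
    # If a commit hash was already provided (or there is no file), trust it.
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    # Hub cache layout: .../models--org--repo/snapshots/<40-hex sha>/<file>
    match = re.search(r"snapshots/([^/]+)/", resolved_file.replace("\\", "/"))
    if match is None:
        return None
    candidate = match.group(1)
    # Only accept a full 40-character hex sha as a commit hash.
    return candidate if re.fullmatch(r"[0-9a-f]{40}", candidate) else None
```

The hex check matters: `snapshots/` can also contain symlinked branch names in edge cases, and a non-sha value must not be treated as a pinned revision.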
It's hard to follow, but I think variants are not handled properly. I made a couple of other minor comments.
Another observation: if we load individual components separately (the …
Made a few comments on the implementation. I think it could be simplified (we do not need the same complexity as in transformers), and I also agree with @williamberman that `revision` should be used instead of having a duplicate `_commit_hash`.
src/diffusers/configuration_utils.py (outdated)

```python
commit_hash = extract_commit_hash(config_file)
config_dict["_commit_hash"] = commit_hash
```
Nice! I think in the future it would be nice if we had an option for the hub library to return additional metadata (including the commit hash) when we download a file. It looks like in the source, the commit hash is manually read out of the refs folder or the first metadata request from the hub. Relying on the structure of the local cache is ok, but if we already retrieve the information in the function, it would be nice to just explicitly return it.

Sorry if I already said something similar to this, can't remember if I did :P
Big PR 😅 Looks good, will run some additional tests looking for edge cases.
```python
if return_cached_folder:
    message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
```
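For reference, diffusers routes messages like this through an internal `deprecate` utility. A bare-bones stand-in (hypothetical name `deprecate_kwarg`, simplified signature) just emits a `FutureWarning` with migration guidance:

```python
import warnings


def deprecate_kwarg(name: str, removed_in: str, message: str) -> None:
    # Simplified stand-in for a deprecation helper: warn that `name` is
    # deprecated and will be removed in the given version, and tell the
    # caller what to use instead.
    warnings.warn(
        f"Passing `{name}` is deprecated and will be removed in `diffusers=={removed_in}`. {message}",
        FutureWarning,
        stacklevel=2,
    )
```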
Just a question, why is this not desired anymore?
Because if one wants the cached_folder one can just split:

```python
from_pretrained(...)
```

into

```python
cached_folder = download(...)
from_pretrained(cached_folder)
```
<Tip>

Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
This link does not exist. Should it point to https://huggingface.co/docs/diffusers/installation? There's no offline-mode section there, just a note about disabling telemetry.
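As background, offline behavior across the HF libraries is keyed off the `HF_HUB_OFFLINE` environment variable; a minimal check along those lines (a sketch mirroring, not reproducing, the huggingface_hub logic) looks like:

```python
import os


def is_offline_mode() -> bool:
    # When HF_HUB_OFFLINE is set to a truthy value, loaders should skip
    # all network calls and behave as if local_files_only=True.
    return os.environ.get("HF_HUB_OFFLINE", "").lower() in {"1", "on", "yes", "true"}
```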
```python
    revision=revision,
)
user_agent["pretrained_model_name"] = pretrained_model_name
send_telemetry("pipelines", library_name="diffusers", library_version=__version__, user_agent=user_agent)
```
I get this is async, which is cool. However, wouldn't it be possible to send the required data as part of the `model_info` call, which needs to download info anyway? (I guess it's not possible, I don't know how data collection works in the backend.)
cc @Wauplin
Linking to a comment I made in a conversation I didn't realize was resolved: #2515 (comment) /cc @Wauplin, do you think we should address that?
```python
# load config
config, unused_kwargs = cls.load_config(
    config_path,
    cache_dir=cache_dir,
    return_unused_kwargs=True,
    force_download=force_download,
    resume_download=resume_download,
    proxies=proxies,
    local_files_only=local_files_only,
    use_auth_token=use_auth_token,
    revision=revision,
    subfolder=subfolder,
    device_map=device_map,
    user_agent=user_agent,
    **kwargs,
)
commit_hash = config.pop("_commit_hash", None)
```
I don't think we ever expect the commit hash to be part of the pipeline config, so I think it should be an explicit return value
Suggested change:

```diff
 # load config
-config, unused_kwargs = cls.load_config(
+config, unused_kwargs, commit_hash = cls.load_config(
     config_path,
     cache_dir=cache_dir,
     return_unused_kwargs=True,
     force_download=force_download,
     resume_download=resume_download,
     proxies=proxies,
     local_files_only=local_files_only,
     use_auth_token=use_auth_token,
     revision=revision,
     subfolder=subfolder,
     device_map=device_map,
     user_agent=user_agent,
+    return_commit_hash=True,
     **kwargs,
 )
-commit_hash = config.pop("_commit_hash", None)
```
Fair!
```diff
@@ -829,6 +808,9 @@ def _get_model_file(
     and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
 ):
     try:
+        if commit_hash is not None and revision is None:
+            revision = commit_hash
```
Note that this will cause the warning logs in this function to print the commit hash instead of the revision
Good catch!
```python
# _commit_hash
config.pop("_commit_hash", None)
```
Yeah, if the commit hash becomes a return value instead of being added onto the config, we don't have to do this. The current approach could be a bit confusing for someone reading the code for the first time.
tests/test_modeling_common.py (outdated)

```python
def tearDown(self):
    # clean up the VRAM after each test
    super().tearDown()
    gc.collect()
    torch.cuda.empty_cache()

    import diffusers

    diffusers.utils.import_utils._safetensors_available = True
```
Suggested change:

```python
def tearDown(self):
    # clean up the VRAM after each test
    super().tearDown()
    gc.collect()
    torch.cuda.empty_cache()
    import diffusers
    diffusers.utils.import_utils._safetensors_available = True
```
Tied to https://github.com/huggingface/diffusers/pull/2515/files#r1131591497
I also don't think we need any of the gc/cuda caching code in the utils test?
Still want to make sure `diffusers.utils.import_utils._safetensors_available = True` is done in case the test fails.
tests/test_pipelines.py (outdated)

```python
def test_one_request_upon_cached(self):
    with tempfile.TemporaryDirectory() as tmpdirname:
        with requests_mock.mock(real_http=True) as m:
            DiffusionPipeline.download(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

        download_requests = [r.method for r in m.request_history]
        assert download_requests.count("HEAD") == 16, "15 calls to files + send_telemetry"
        assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
        assert (
            len(download_requests) == 33
        ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"

        with requests_mock.mock(real_http=True) as m:
            DiffusionPipeline.download(
                "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
            )

        cache_requests = [r.method for r in m.request_history]
        assert cache_requests.count("HEAD") == 1, "send_telemetry is only HEAD"
        assert cache_requests.count("GET") == 1, "model info is only GET"
        assert (
            len(cache_requests) == 2
        ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
```
beautiful
Few nits, but looks good to go. I wasn't super thorough going through the pipeline loading logic and am relying more on the tests still passing.

My understanding of the high level is that we need to pass a commit hash to `hf_hub_download` because a commit hash plus the expected file already being in the cache just early-returns the file path in the cache.
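That early-return behavior can be sketched as follows (hypothetical helper name; cache layout per the Hub cache convention): when the revision is a full commit hash, the file's on-disk location is fully determined, so a cache hit needs no network request at all.

```python
import os
from typing import Optional


def try_cached_file(cache_dir: str, repo_folder: str, commit_hash: str, filename: str) -> Optional[str]:
    # Cached snapshots live under <cache_dir>/<repo_folder>/snapshots/<commit_hash>/.
    # A hit means the exact file for that revision is already on disk,
    # so the caller can return without touching the network.
    candidate = os.path.join(cache_dir, repo_folder, "snapshots", commit_hash, filename)
    return candidate if os.path.isfile(candidate) else None
```

With a branch name like `"main"` instead of a commit hash, this shortcut is unsound: the branch may have moved, so a metadata request is still required to resolve it.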
Co-authored-by: Pedro Cuenca <[email protected]>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#notice-on-telemetry-logging) to use this method in a firewalled environment.
Suggested change:

> Activate the special ["offline-mode"](https://huggingface.co/docs/diffusers/installation#notice-on-telemetry-logging) to use this method in a firewalled environment.
URL was wrong.
But still, unless I'm not understanding this tip correctly the doc is not related to offline mode or being behind a firewall. Maybe replace with something like:
Suggested change:

> Use the `proxies` arg if you are in a firewalled environment, or `local_files_only` for full offline mode, which requires the pipeline to be cached locally. Please refer to [these notes](https://huggingface.co/docs/diffusers/installation#notice-on-telemetry-logging) to disable all telemetry logging.
Yeah, it's probably just a bad copy-paste from transformers that we should delete.
Squashed commit history:

* [From pretrained] Speed-up loading from cache
* up
* Fix more
* fix one more bug
* make style
* bigger refactor
* factor out function
* Improve more
* better
* deprecate return cache folder
* clean up
* improve tests
* up
* upload
* add nice tests
* simplify
* finish
* correct
* fix version
* rename
* Apply suggestions from code review (Co-authored-by: Lucain <[email protected]>)
* rename
* correct doc string
* correct more
* Apply suggestions from code review (Co-authored-by: Pedro Cuenca <[email protected]>)
* apply code suggestions
* finish

Co-authored-by: Lucain <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
As described in #2514 we currently ping the Hub too often when all files are already cached.
This PR makes sure that for every download the Hub is only pinged exactly once. All the functionality is copied from transformers (thanks @sgugger)
UPDATE:
New benchmark of #2514 (comment) reveals the speed-up:
Loading a single cached model:
Loading a single cached pipeline:
As mentioned by @pcuenca, in the previous version it wasn't possible to have certainty that the correct model variants were downloaded by just doing one HEAD call. So we change the HEAD call to a single GET call that retrieves all info about the pipeline and can then load the pipeline.
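The resulting flow can be sketched with toy callables (all names here are hypothetical, for illustration only): one metadata GET yields the commit sha and the file list, and every file is then addressed by that pinned sha, so cached files come back without any further requests.

```python
from typing import Callable, Dict, List, Optional


def download_pipeline_files(
    get_repo_info: Callable[[], Dict],                   # one GET: {"sha": ..., "files": [...]}
    fetch_file: Callable[[str, str], str],               # downloads a file at a pinned commit
    cache_lookup: Callable[[str, str], Optional[str]],   # local path if cached at that commit
) -> List[str]:
    info = get_repo_info()  # the single network round-trip
    commit_hash = info["sha"]
    paths = []
    for filename in info["files"]:
        cached = cache_lookup(commit_hash, filename)
        # Cache hit at this exact commit: no per-file HEAD request needed.
        paths.append(cached if cached is not None else fetch_file(commit_hash, filename))
    return paths
```

The design choice this illustrates: pinning every per-file request to a single resolved sha both guarantees a consistent snapshot and makes cache hits decidable locally.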
Note 1: This PR requires `huggingface_hub >= 0.13.0` as we're using the new telemetry function.

Note 2: `DiffusionPipeline.from_pretrained` is refactored in this PR and should now be much more readable.

🚨🚨 Make sure to update `huggingface_hub` to `0.13.0` 🚨🚨