[Variant] Add "variant" as input kwarg so to have better UX when downloading no_ema or fp16 weights #2305

Conversation
The documentation is not available anymore as the PR was closed or merged.
```diff
@@ -89,12 +89,12 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
         return first_tuple[1].dtype


-def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
```
We could directly check the extension of `checkpoint_file` to know whether or not to load with safetensors. That way we wouldn't have to pass in `variant`.
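A minimal sketch of this suggestion, dispatching on the file extension rather than on `variant` (the helper name `pick_weights_loader` is hypothetical, not diffusers API):

```python
import os

# Sketch of the reviewer's suggestion: decide the loading backend from the
# checkpoint's file extension instead of threading `variant` through.
def pick_weights_loader(checkpoint_file: str) -> str:
    """Return which backend should load this checkpoint file."""
    ext = os.path.splitext(checkpoint_file)[1]
    # ".safetensors" files go through the safetensors library,
    # everything else (".bin", ".pt", ...) through torch.load
    return "safetensors" if ext == ".safetensors" else "torch"
```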
Some doc nits and typos.
Now all model components of the pipeline are stored in half-precision dtype. We can now save the
pipeline under a `"fp16"` variant as follows:
Not sure if we should standardize / recommend `fp16` (that's the name of the old revision branch), or `float16` (the torch type).
Good question! I think it'll be difficult to nudge the community towards `float16` given that we advertised `revision="fp16"` everywhere; deprecating it would also be difficult.
@williamberman @patil-suraj wdyt?
Agree with Patrick here, I think using `fp16` is pretty standard now. And I think it's okay to have the same name for revision and variant here, as essentially variant is a better alternative to revision.
+1, I think we've standardized on `fp16`.
I think if we also want to have sets of equivalent variations, that's cool too, i.e. if the user passes `"float16"` as the variant, we use `fp16` anyway and log a message letting them know.
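That aliasing idea could look roughly like this (hypothetical helper and alias table, just to illustrate; not actual diffusers code):

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical table of equivalent variant spellings
_VARIANT_ALIASES = {"float16": "fp16", "bfloat16": "bf16"}

def normalize_variant(variant):
    """Map equivalent variant names onto the canonical one and log the swap."""
    if variant in _VARIANT_ALIASES:
        canonical = _VARIANT_ALIASES[variant]
        logger.info("Interpreting variant '%s' as '%s'.", variant, canonical)
        return canonical
    return variant
```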
throws an Exception:
```
OSError: Error no file named diffusion_pytorch_model.bin found in directory ./stable-diffusion-v1-45/vae
```
since we **only** stored the model
We should throw a better error here, something like:
"Error: You loaded the pipeline without a variant. Only found variants: `fp16`, `bf16` in the repository."
And similarly, when loading with a variant that isn't present:
"Error: You loaded the pipeline with variant: `fp16`. Only found variants: `bf16`, `no_ema` in the repository."
We could do this as a follow up PR?
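A sketch of how such a message could be built (hypothetical helper, not part of this PR):

```python
def variant_error_message(requested_variant, available_variants):
    """Build the kind of error text suggested above."""
    available = ", ".join(sorted(available_variants))
    if requested_variant is None:
        return (
            "Error: You loaded the pipeline without a variant. "
            f"Only found variants: {available} in the repository."
        )
    return (
        f"Error: You loaded the pipeline with variant: {requested_variant}. "
        f"Only found variants: {available} in the repository."
    )
```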
```python
onnx_variant_filenames = set([f for f in variant_filenames if f.endswith(".onnx")])
onnx_model_filenames = set([f for f in model_filenames if f.endswith(".onnx")])
if len(onnx_variant_filenames) > 0 and onnx_model_filenames != onnx_variant_filenames:
    logger.warn(
        f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(onnx_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(onnx_model_filenames - onnx_variant_filenames)}]\nIf this behavior is not expected, please check your folder structure."
    )
```
I still think this should deal with the `.safetensors` extension, not onnx.
yes, 100%
```python
        f"You are loading the variant {variant} from {pretrained_model_name_or_path} via `revision='{variant}'` even though you can load it via `variant='{variant}'`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{variant}'` instead. For more information, please have a look at: ",
        FutureWarning,
    )
except:  # noqa: E722
```
Shouldn't we only catch `EntryNotFoundError` in this case?
I think a pure `except` is a bit safer here, just to be sure we don't miss a potential other type of error. The code will be removed somewhat soonish anyway, as it's just there to maintain (future) deprecated behavior.
```python
    revision=revision,
)
warnings.warn(
    f"You are loading the variant {variant} from {pretrained_model_name_or_path} via `revision='{variant}'` even though you can load it via `variant='{variant}'`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{variant}'` instead. For more information, please have a look at: ",
```
Is this log message incomplete? "please have a look at: " reads like something should go at the end here?
```python
if revision in DEPRECATED_REVISION_ARGS and version.parse(
    version.parse(__version__).base_version
) >= version.parse("0.15.0"):
    variant = _add_variant(weights_name, revision)

    try:
        model_file = hf_hub_download(
            pretrained_model_name_or_path,
            filename=weights_name,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            user_agent=user_agent,
            subfolder=subfolder,
            revision=revision,
        )
        warnings.warn(
            f"You are loading the variant {variant} from {pretrained_model_name_or_path} via `revision='{variant}'` even though you can load it via `variant='{variant}'`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{variant}'` instead. For more information, please have a look at: ",
            FutureWarning,
        )
    except:  # noqa: E722
        warnings.warn(
            f"You are loading the variant {variant} from {pretrained_model_name_or_path} via `revision='{variant}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{variant}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name)}' so that the correct variant file can be added.",
            FutureWarning,
        )
        model_file = None
else:
    # Load from URL or cache if already cached
    model_file = hf_hub_download(
        pretrained_model_name_or_path,
        filename=weights_name,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        user_agent=user_agent,
        subfolder=subfolder,
        revision=revision,
    )
return model_file
```
I'm a bit confused by the logic here. The second error message, about "main" not having a file, seems to indicate we're checking the main branch, but we still use the passed-in revision to check a branch? Maybe I'm missing something.
Yes, so essentially we should deprecate the behavior of doing `revision="fp16"`. Now we might have cases where this works because the branch still exists, but we do want to tell the user that this behavior is deprecated, so we throw a warning.
Now there are two possibilities:
- 1.) There is already a "fp16" variant file on the main branch -> in this case the user can be directly guided to the new usage.
- 2.) There is not yet a "fp16" variant file -> in this case we probably need to add it manually, therefore I'm trying to have the user open an issue.
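Those two cases could be sketched as follows (hypothetical helper and filename convention, only to illustrate the branching described above):

```python
def deprecation_guidance(revision, files_on_main):
    """Decide what to tell a user who passed e.g. revision='fp16'.

    `files_on_main` lists the filenames available on the 'main' branch.
    """
    variant_file = f"diffusion_pytorch_model.{revision}.bin"  # assumed naming scheme
    if variant_file in files_on_main:
        # case 1: variant already on main -> guide the user to the new kwarg
        return f"deprecated: please pass variant='{revision}' instead of revision='{revision}'"
    # case 2: variant file missing on main -> ask the user to open an issue
    return "deprecated: this repo has no variant file on 'main' yet, please open an issue"
```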
You're 100% right @williamberman - the code was bad here. Refactored it, should be better now I hope
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Will Berman <[email protected]>
…loading no_ema or fp16 weights (huggingface#2305) * [Variant] Add variant loading mechanism * clean * improve further * up * add tests * add some first tests * up * up * use path splittetx * add deprecate * deprecation warnings * improve docs * up * up * up * fix tests * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * correct code format * fix warning * finish * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * Update docs/source/en/using-diffusers/loading.mdx Co-authored-by: Suraj Patil <[email protected]> * Apply suggestions from code review Co-authored-by: Will Berman <[email protected]> Co-authored-by: Suraj Patil <[email protected]> * correct loading docs * finish --------- Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Will Berman <[email protected]>
This PR adds a "variant" keyword argument so that model variations can be better stored on the "main" branch.

Important: see also the discussion here: #1764

It's the mirror of huggingface/transformers#21332 for `diffusers`. Make sure you're using `transformers` on "main" when trying out the following. These commands will load the respective variants from: https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main

Important: `local_files_only=True`

To achieve this, the ignore and allow patterns logic is refactored and made simpler and more precise.
Please have a look at the tests for more details.
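The refactored allow/ignore pattern logic presumably behaves like standard glob filtering; a self-contained sketch under that assumption (function name and exact semantics are not taken from the PR):

```python
import fnmatch

def filter_repo_files(filenames, allow_patterns=None, ignore_patterns=None):
    """Keep files matching any allow pattern, then drop those matching any ignore pattern."""
    kept = list(filenames)
    if allow_patterns is not None:
        kept = [f for f in kept if any(fnmatch.fnmatch(f, p) for p in allow_patterns)]
    if ignore_patterns is not None:
        kept = [f for f in kept if not any(fnmatch.fnmatch(f, p) for p in ignore_patterns)]
    return kept
```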
Deprecation Cycle
Note that while we should merge this PR, we probably need to wait 1-2 releases until we add the model variants to popular repos such as `stable-diffusion-v1-4`, `v1-5`, `v2-0` and `v2-1`, as `diffusers < 0.13.0dev0` would otherwise download GBs of unused models. That's why there are some "in-the-future" deprecation warnings in this PR.

Stats about model repos having variants
We only have ~55 public models (out of 3800) that use branches as variations, so I think we could relatively easily transition all of those in ~1 month. Here is the list:
Final TODOs
- Deprecate `revision="fp16"` in the smoothest way possible