Description
Since we were considering adding an option like `single_file_format` to `save_pretrained()` of `DiffusionPipeline`, it makes sense to have something similar in `from_pretrained()` for better feature parity.

We currently support loading single-file checkpoints in `DiffusionPipeline` via `from_single_file()`. Some examples below:
```python
import torch
from diffusers import StableDiffusionPipeline

# Download pipeline from huggingface.co and cache.
pipeline = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
)

# Load pipeline from a local file
# (file was downloaded to ./v1-5-pruned-emaonly.ckpt).
pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")

# Enable float16 and move to GPU.
pipeline = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")
```
(Taken from the docs here)
Proposed API design
Calling `from_pretrained()` on a `DiffusionPipeline` requires users to pass `pretrained_model_name_or_path`, which can be a repo id on the Hub or a local directory containing checkpoints in the `diffusers` format. (Docs)

Now, if we want to add support for loading a compatible single-file checkpoint in `from_pretrained()`, we could have an API like so:
```python
from diffusers import DiffusionPipeline

repo_id = "WarriorMama777/OrangeMixs"
pipe = DiffusionPipeline.from_pretrained(repo_id, weight_name="Models/AbyssOrangeMix/AbyssOrangeMix.safetensors")
```
- Like before, `repo_id` could either be an actual repo id on the Hub or a local directory. `weight_name` can either be just the filename of the single-file checkpoint to be loaded or its path relative to the underlying repo / directory.
- When `weight_name` is provided in `from_pretrained()`:
  - We immediately check if the file exists in the repository or the directory and raise an error if it's not found.
  - Once that check passes, we hit the codepath currently used by `from_single_file()`. That logic should be factored out into a standalone utility rather than folded into `from_pretrained()`; `from_pretrained()` would simply call the utility.
- How can we detect errors here as early as possible? What if the checkpoint is not compatible or doesn't have all the components we need (e.g., what if the `vae` or any other component is missing)? Is there any robust way?
- Once this support is foolproof, we can start deprecating the use of `from_single_file()`.
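The two checks above (failing fast when the file is absent, and detecting an incomplete checkpoint) could be sketched roughly as below. The helper names are assumptions for illustration, and the component-prefix map reflects Stable Diffusion v1-style single-file checkpoints; other pipelines would need their own map.

```python
import posixpath

# Key prefixes a Stable Diffusion v1-style single-file checkpoint is
# expected to contain; purely illustrative, other pipelines differ.
EXPECTED_PREFIXES = {
    "unet": "model.diffusion_model.",
    "vae": "first_stage_model.",
    "text_encoder": "cond_stage_model.",
}


def resolve_weight_name(repo_files, weight_name):
    """Fail fast unless `weight_name` (a bare filename or a repo-relative
    path) matches exactly one file in the repo / directory listing."""
    if weight_name in repo_files:
        return weight_name
    matches = [f for f in repo_files if posixpath.basename(f) == weight_name]
    if len(matches) > 1:
        raise ValueError(f"`weight_name={weight_name}` is ambiguous: {matches}")
    if not matches:
        raise FileNotFoundError(
            f"`{weight_name}` was not found in the repository or directory."
        )
    return matches[0]


def find_missing_components(state_dict_keys, expected=EXPECTED_PREFIXES):
    """Return the components whose key prefix never appears in the
    checkpoint's state dict (e.g. a missing `vae`)."""
    return [
        name
        for name, prefix in expected.items()
        if not any(key.startswith(prefix) for key in state_dict_keys)
    ]
```

The second check only needs the state dict's keys, so it can run before any weights are materialized into modules, surfacing a missing `vae` as early as possible.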
Some thoughts
- I don't think this is a very new design. Users are already familiar with `weight_name` and how it's used with `load_lora_weights()` (which is quite popular at this point, IMO).
- I think we must force users to pass `weight_name`. Too much intelligent guessing here would lead to ugly consequences in the code, and I'm not sure it would be worth it.