[Design Discussion] allowing from_pretrained() to also load single file checkpoints #6461

@sayakpaul

Description

Since we were considering adding an option like single_file_format to save_pretrained() of DiffusionPipeline, it makes sense to add something similar to from_pretrained() for better feature parity.

We currently support loading single file checkpoints in DiffusionPipeline via from_single_file(). Some examples below:

import torch

from diffusers import StableDiffusionPipeline

# Download pipeline from huggingface.co and cache.
pipeline = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
)

# Load the pipeline from a local file
# (e.g., a checkpoint saved at ./v1-5-pruned-emaonly.ckpt)
pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")

# Enable float16 and move to GPU
pipeline = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")

(Taken from the docs here)

Proposed API design

Calling from_pretrained() on a DiffusionPipeline requires users to pass pretrained_model_name_or_path, which can be a repo id on the Hub or a local directory containing a checkpoint in the diffusers format.

(Docs)

Now, if we want to add support for loading a compatible single file checkpoint in from_pretrained(), we could have an API like so:

from diffusers import DiffusionPipeline

repo_id = "WarriorMama777/OrangeMixs"
pipe = DiffusionPipeline.from_pretrained(repo_id, weight_name="Models/AbyssOrangeMix/AbyssOrangeMix.safetensors")

  • Like before, repo_id could either be an actual repo id on the Hub or a local directory.
  • weight_name can be either just the filename of the single-file checkpoint to be loaded or its relative path within the underlying repo / directory.
  • When weight_name is provided in from_pretrained():
    • We immediately check whether the file exists in the repository or the directory and raise an error early if it's not found.
    • Once the check passes, we hit the same codepath that from_single_file() uses today. That logic should be factored out into a standalone utility and kept out of from_pretrained() itself; from_pretrained() would simply call the utility.
    • How can we detect errors here as early as possible? What if the checkpoint is not compatible or doesn't have all the components we need (what if the vae or any other component is missing)? Is there any robust way?
  • Once this support is foolproof, we can start deprecating the use of from_single_file().
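The flow above could be sketched roughly as follows. Everything here is hypothetical: `_resolve_single_file`, the return values, and the error handling are illustrative stand-ins, not actual diffusers internals, and only the local-directory branch is fleshed out.

```python
import os


def _resolve_single_file(pretrained_model_name_or_path, weight_name):
    """Hypothetical early check: verify the single-file checkpoint exists
    before handing off to the from_single_file() codepath."""
    if os.path.isdir(pretrained_model_name_or_path):
        # Local directory: weight_name is a filename or a relative path.
        candidate = os.path.join(pretrained_model_name_or_path, weight_name)
        if not os.path.isfile(candidate):
            raise FileNotFoundError(
                f"{weight_name} not found in {pretrained_model_name_or_path}"
            )
        return candidate
    # Otherwise treat it as a Hub repo id; a real implementation would
    # query the Hub file listing (e.g. via huggingface_hub) here.
    raise NotImplementedError("Hub lookup not sketched here")


def from_pretrained(pretrained_model_name_or_path, weight_name=None, **kwargs):
    if weight_name is not None:
        # Fail fast if the file is missing, then call the factored-out
        # single-file loading utility.
        checkpoint_path = _resolve_single_file(pretrained_model_name_or_path, weight_name)
        return ("single_file", checkpoint_path)
    # Fall back to the existing diffusers-format loading logic.
    return ("diffusers_format", pretrained_model_name_or_path)
```

The point of the sketch is the ordering: the existence check happens before any weights are read, so users get a clear error immediately rather than a failure deep inside checkpoint conversion.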

Some thoughts

  • I don't think this is a very new design. Users are already familiar with weight_name and how it is used, through load_lora_weights() (which is quite popular at this point, IMO).
  • I think we must require users to pass weight_name. Too much intelligent guessing here would complicate the code, and I am not sure the benefits would be worth it.
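On the compatibility question raised above (what if the vae or another component is missing?): one cheap early check could inspect the key prefixes of the checkpoint's state dict, since single-file Stable Diffusion checkpoints store all components in one flat dict. A minimal sketch, assuming the original SD 1.x key layout; the prefixes and the helper name are illustrative, not diffusers API.

```python
# Illustrative component -> key-prefix map for SD 1.x-style checkpoints.
EXPECTED_PREFIXES = {
    "unet": "model.diffusion_model.",
    "vae": "first_stage_model.",
    "text_encoder": "cond_stage_model.",
}


def find_missing_components(state_dict):
    """Return the names of expected components with no keys in state_dict."""
    return sorted(
        name
        for name, prefix in EXPECTED_PREFIXES.items()
        if not any(key.startswith(prefix) for key in state_dict)
    )
```

A check like this runs on key names alone, so it can reject an incomplete checkpoint before any tensors are converted, which is about as early as the error can surface.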
