Skip to content

convert ckpt script docstring fixes #2293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import re
import tempfile
from typing import Optional

import requests
import torch
Expand Down Expand Up @@ -787,8 +788,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
model_type: str = None,
extract_ema: bool = False,
scheduler_type: str = "pndm",
num_in_channels: int = None,
upcast_attention: bool = None,
num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None,
device: str = None,
from_safetensors: bool = False,
) -> StableDiffusionPipeline:
Expand All @@ -800,28 +801,36 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
recommended that you override the default values and/or supply an `original_config_file` wherever possible.

:param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
corresponding to the original architecture. If `None`, will be
automatically inferred by looking for a key that only exists in SD2.0 models.
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable
Diffusion v2
Args:
checkpoint_path (`str`): Path to `.ckpt` file.
original_config_file (`str`):
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
inferred by looking for a key that only exists in SD2.0 models.
image_size (`int`, *optional*, defaults to 512):
The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
Base. Use 768 for Stable Diffusion v2.
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
v1.X and Stable
prediction_type (`str`, *optional*):
The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
inferred. :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running
stable diffusion 2.1.
:param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
`checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
num_in_channels (`int`, *optional*, defaults to None):
The number of input channels. If `None`, it will be automatically inferred.
scheduler_type (`str`, *optional*, defaults to 'pndm'):
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
"ddim"]`.
model_type (`str`, *optional*, defaults to `None`):
The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
"FrozenCLIPEmbedder", "PaintByExample"]`.
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
inference. Non-EMA weights are usually better to continue fine-tuning.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted. This is necessary when running stable
diffusion 2.1.
device (`str`, *optional*, defaults to `None`):
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""
if prediction_type == "v-prediction":
prediction_type = "v_prediction"
Expand Down