Skip to content

Commit a8ce413

Browse files
committed
init
1 parent fd5c3c0 commit a8ce413

File tree

1 file changed

+31
-22
lines changed

1 file changed

+31
-22
lines changed

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import re
1919
import tempfile
20+
from typing import Optional
2021

2122
import requests
2223
import torch
@@ -787,8 +788,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
787788
model_type: str = None,
788789
extract_ema: bool = False,
789790
scheduler_type: str = "pndm",
790-
num_in_channels: int = None,
791-
upcast_attention: bool = None,
791+
num_in_channels: Optional[int] = None,
792+
upcast_attention: Optional[bool] = None,
792793
device: str = None,
793794
from_safetensors: bool = False,
794795
) -> StableDiffusionPipeline:
@@ -800,28 +801,36 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
800801
global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
801802
recommended that you override the default values and/or supply an `original_config_file` wherever possible.
802803
803-
:param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
804-
corresponding to the original architecture. If `None`, will be
805-
automatically inferred by looking for a key that only exists in SD2.0 models.
806-
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable
807-
Diffusion v2
804+
Args:
805+
checkpoint_path (`str`): Path to `.ckpt` file.
806+
original_config_file (`str`):
807+
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
808+
inferred by looking for a key that only exists in SD2.0 models.
809+
image_size (`int`, optional*, defaults to 512):
810+
The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
808811
Base. Use 768 for Stable Diffusion v2.
809-
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
810-
v1.X and Stable
812+
prediction_type (`str`, *optional*):
813+
The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
811814
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
812-
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
813-
inferred. :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
814-
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
815-
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
816-
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
817-
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
818-
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
819-
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
820-
running
821-
stable diffusion 2.1.
822-
:param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
823-
`checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
824-
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
815+
num_in_channels (`int`, *optional*, defaults to None):
816+
The number of input channels. If `None` number of input channels will be automatically inferred.
817+
scheduler_type (`str`, *optional*, defaults to 'pndm'):
818+
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
819+
"ddim"]`.
820+
model_type (`str`, *optional*, defaults to `None`):
821+
The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
822+
"FrozenCLIPEmbedder", "PaintByExample"]`.
823+
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
824+
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
825+
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
826+
inference. Non-EMA weights are usually better to continue fine-tuning.
827+
upcast_attention (`bool`, *optional*, defaults to `None`):
828+
Whether the attention computation should always be upcasted. This is necessary when running stable
829+
diffusion 2.1.
830+
device (`str`, *optional*, defaults to `None`):
831+
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
832+
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
833+
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
825834
"""
826835
if prediction_type == "v-prediction":
827836
prediction_type = "v_prediction"

0 commit comments

Comments
 (0)