17
17
import os
18
18
import re
19
19
import tempfile
20
+ from typing import Optional
20
21
21
22
import requests
22
23
import torch
@@ -787,8 +788,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
787
788
model_type : str = None ,
788
789
extract_ema : bool = False ,
789
790
scheduler_type : str = "pndm" ,
790
- num_in_channels : int = None ,
791
- upcast_attention : bool = None ,
791
+ num_in_channels : Optional [ int ] = None ,
792
+ upcast_attention : Optional [ bool ] = None ,
792
793
device : str = None ,
793
794
from_safetensors : bool = False ,
794
795
) -> StableDiffusionPipeline :
@@ -800,28 +801,36 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
800
801
global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
801
802
recommended that you override the default values and/or supply an `original_config_file` wherever possible.
802
803
803
- :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
804
- corresponding to the original architecture. If `None`, will be
805
- automatically inferred by looking for a key that only exists in SD2.0 models.
806
- :param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable
807
- Diffusion v2
804
+ Args:
805
+ checkpoint_path (`str`): Path to `.ckpt` file.
806
+ original_config_file (`str`):
807
+ Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
808
+ inferred by looking for a key that only exists in SD2.0 models.
809
+ image_size (`int`, *optional*, defaults to 512):
810
+ The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
808
811
Base. Use 768 for Stable Diffusion v2.
809
- :param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
810
- v1.X and Stable
812
+ prediction_type (`str`, *optional*):
813
+ The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
811
814
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
812
- :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
813
- inferred. :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
814
- "euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
815
- `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
816
- checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
817
- or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
818
- quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
819
- :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
820
- running
821
- stable diffusion 2.1.
822
- :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
823
- `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
824
- StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
815
+ num_in_channels (`int`, *optional*, defaults to `None`):
816
+ The number of input channels. If `None`, it will be automatically inferred.
817
+ scheduler_type (`str`, *optional*, defaults to 'pndm'):
818
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
819
+ "ddim"]`.
820
+ model_type (`str`, *optional*, defaults to `None`):
821
+ The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
822
+ "FrozenCLIPEmbedder", "PaintByExample"]`.
823
+ extract_ema (`bool`, *optional*, defaults to `False`):
824
+ Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
825
+ or not. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
826
+ inference. Non-EMA weights are usually better to continue fine-tuning.
827
+ upcast_attention (`bool`, *optional*, defaults to `None`):
828
+ Whether the attention computation should always be upcasted. This is necessary when running stable
829
+ diffusion 2.1.
830
+ device (`str`, *optional*, defaults to `None`):
832
+ The device to use. Pass `None` to determine automatically.
+ from_safetensors (`bool`, *optional*, defaults to `False`):
833
+ If `checkpoint_path` is in `safetensors` format, load the checkpoint with safetensors instead of PyTorch.
+ Returns:
834
+ A `StableDiffusionPipeline` object representing the passed-in `.ckpt`/`.safetensors` file.
825
834
"""
826
835
if prediction_type == "v-prediction" :
827
836
prediction_type = "v_prediction"
0 commit comments