Skip to content
133 changes: 113 additions & 20 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
from io import BytesIO
from typing import Optional
from pathlib import Path

import requests
import torch
Expand All @@ -31,32 +32,36 @@
CLIPVisionModelWithProjection,
)

from diffusers import (
from huggingface_hub import hf_hub_download

from ...models import (
AutoencoderKL,
UNet2DConditionModel,
PriorTransformer,
ControlNetModel,
)

from ...schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
PriorTransformer,
StableDiffusionControlNetPipeline,
StableDiffusionPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
UnCLIPScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

from ...utils import is_omegaconf_available, is_safetensors_available, logging
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from .safety_checker import StableDiffusionSafetyChecker

from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

from ..pipeline_utils import DiffusionPipeline

from ...utils import is_omegaconf_available, is_safetensors_available, logging, DIFFUSERS_CACHE, HF_HUB_OFFLINE
from ...utils.import_utils import BACKENDS_MAPPING


Expand Down Expand Up @@ -990,7 +995,8 @@ def download_from_original_stable_diffusion_ckpt(
clip_stats_path: Optional[str] = None,
controlnet: Optional[bool] = None,
load_safety_checker: bool = True,
) -> StableDiffusionPipeline:
pipeline_class: DiffusionPipeline = None,
) -> DiffusionPipeline:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
config file.
Expand Down Expand Up @@ -1026,12 +1032,29 @@ def download_from_original_stable_diffusion_ckpt(
Whether the attention computation should always be upcasted. This is necessary when running stable
diffusion 2.1.
device (`str`, *optional*, defaults to `None`):
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
The device to use. Pass `None` to determine automatically.
from_safetensors (`str`, *optional*, defaults to `False`):
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not. Defaults to `True`.
pipeline_class (`str`, *optional*, defaults to `None`):
The pipeline class to use. Pass `None` to determine automatically.
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""

# import pipelines here to avoid circular import error when using from_ckpt method
from diffusers import (
StableDiffusionControlNetPipeline,
StableDiffusionPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
LDMTextToImagePipeline,
PaintByExamplePipeline,
)

if pipeline_class is None:
pipeline_class = StableDiffusionPipeline

if prediction_type == "v-prediction":
prediction_type = "v_prediction"

Expand Down Expand Up @@ -1193,7 +1216,7 @@ def download_from_original_stable_diffusion_ckpt(
requires_safety_checker=False,
)
else:
pipe = StableDiffusionPipeline(
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
Expand Down Expand Up @@ -1293,7 +1316,7 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor,
)
else:
pipe = StableDiffusionPipeline(
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
Expand All @@ -1320,7 +1343,7 @@ def download_controlnet_from_original_ckpt(
upcast_attention: Optional[bool] = None,
device: str = None,
from_safetensors: bool = False,
) -> StableDiffusionPipeline:
) -> DiffusionPipeline:
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

Expand Down Expand Up @@ -1361,3 +1384,73 @@ def download_controlnet_from_original_ckpt(
)

return controlnet_model


class FromCkptMixin:
@classmethod
def from_ckpt(cls, model_path_or_checkpoint, **kwargs):
pipeline_name = cls.__name__

file_extension = model_path_or_checkpoint.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors"

stable_unclip = None
controlnet = False
if pipeline_name == "StableDiffusionControlNetPipeline":
model_type = "FrozenCLIPEmbedder"
controlnet = True

elif "StableDiffusion" in pipeline_name:
model_type = "FrozenCLIPEmbedder"

elif pipeline_name == "StableUnCLIPPipeline":
model_type == "FrozenOpenCLIPEmbedder"
stable_unclip = "txt2img"

elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
model_type == "FrozenOpenCLIPEmbedder"
stable_unclip = "img2img"

elif pipeline_name == "PaintByExamplePipeline":
model_type == "PaintByExample"

elif pipeline_name == "LDMTextToImagePipeline":
model_type == "LDMTextToImage"

else:
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(model_path_or_checkpoint)
if not ckpt_path.is_file():
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
file_path = str(Path().joinpath(*ckpt_path.parts[2:]))

cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
revision = kwargs.pop("revision", None)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)

model_path_or_checkpoint = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
)

return download_from_original_stable_diffusion_ckpt(
model_path_or_checkpoint,
pipeline_class=cls,
model_type=model_type,
stable_unclip=stable_unclip,
controlnet=controlnet,
from_safetensors=from_safetensors,
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

from .convert_from_ckpt import FromCkptMixin

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand All @@ -52,7 +53,7 @@
"""


class StableDiffusionPipeline(DiffusionPipeline):
class StableDiffusionPipeline(DiffusionPipeline, FromCkptMixin):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice!

r"""
Pipeline for text-to-image generation using Stable Diffusion.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

from .convert_from_ckpt import FromCkptMixin

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -91,7 +92,7 @@ def preprocess(image):
return image


class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, FromCkptMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

from .convert_from_ckpt import FromCkptMixin

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -81,7 +82,7 @@ def preprocess_mask(mask, scale_factor=8):
return mask


class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, FromCkptMixin):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

Expand Down
21 changes: 21 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,27 @@ def test_download_broken_variant(self):

diffusers.utils.import_utils._safetensors_available = True

def test_download_from_ckpt(self):
with tempfile.TemporaryDirectory() as tmpdirname:
ckpt_paths = [
"runwayml/stable-diffusion-v1-5/v1-5-pruned-emaonly.ckpt",
"WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix_base.ckpt",
]

for ckpt_path in ckpt_paths:
StableDiffusionPipeline.from_ckpt(ckpt_path, cache_dir=tmpdirname)

ckpt_names = [os.path.basename(ckpt_path) for ckpt_path in ckpt_paths]

files = []
for cache in os.listdir(tmpdirname):
snapshots = os.path.join(tmpdirname, cache, "snapshots")
all_root_files = [t[-1] for t in os.walk(snapshots)]
files += [item for sublist in all_root_files for item in sublist]

# check that downloaded filenames match checkpoint filenames
assert set(ckpt_names) == set(files)


class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
Expand Down