Commit 62b3c9e

unCLIP variant (#2297)

* pipeline_variant
* Add docs for when clip_stats_path is specified
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
  Co-authored-by: Patrick von Platen <[email protected]>
* prepare_latents # Copied from re: @patrickvonplaten
* NoiseAugmentor->ImageNormalizer
* stable_unclip_prior default to None re: @patrickvonplaten
* prepare_prior_extra_step_kwargs
* prior denoising scale model input
* {DDIM,DDPM}Scheduler -> KarrasDiffusionSchedulers re: @patrickvonplaten
* docs
* Update docs/source/en/api/pipelines/stable_unclip.mdx
  Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>

1 parent e55687e · commit 62b3c9e

21 files changed: +2597 −19 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -154,6 +154,8 @@
       title: Stable Diffusion
     - local: api/pipelines/stable_diffusion_2
       title: Stable Diffusion 2
+    - local: api/pipelines/stable_unclip
+      title: Stable unCLIP
     - local: api/pipelines/stochastic_karras_ve
       title: Stochastic Karras VE
     - local: api/pipelines/unclip
```

docs/source/en/api/pipelines/overview.mdx

Lines changed: 2 additions & 0 deletions

```diff
@@ -64,6 +64,8 @@ available a colab notebook to directly try them out.
 | [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
 | [stable_diffusion_2](./stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
 | [stable_diffusion_safe](./stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
+| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation |
+| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation |
 | [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
 | [unclip](./unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation |
 | [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
```

docs/source/en/api/pipelines/stable_diffusion/text2img.mdx

Lines changed: 1 addition & 1 deletion

```diff
@@ -17,7 +17,7 @@ specific language governing permissions and limitations under the License.
 The Stable Diffusion model was created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [runway](https://github.com/runwayml), and [LAION](https://laion.ai/). The [`StableDiffusionPipeline`] is capable of generating photo-realistic images given any text input using Stable Diffusion.
 
 The original codebase can be found here:
-- *Stable Diffusion V1*: [CampVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
+- *Stable Diffusion V1*: [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
 - *Stable Diffusion v2*: [Stability-AI/stablediffusion](https://github.com/Stability-AI/stablediffusion)
 
 Available Checkpoints are:
```
docs/source/en/api/pipelines/stable_unclip.mdx (new file)

Lines changed: 97 additions & 0 deletions

````diff
@@ -0,0 +1,97 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Stable unCLIP
+
+Stable unCLIP checkpoints are finetuned from [stable diffusion 2.1](./stable_diffusion_2) checkpoints to condition on CLIP image embeddings.
+Stable unCLIP also still conditions on text embeddings. Given the two separate conditionings, stable unCLIP can be used
+for text guided image variation. When combined with an unCLIP prior, it can also be used for full text to image generation.
+
+## Tips
+
+Stable unCLIP takes a `noise_level` as input during inference. `noise_level` determines how much noise is added
+to the image embeddings. A higher `noise_level` increases variation in the final un-noised images. By default,
+we do not add any additional noise to the image embeddings i.e. `noise_level = 0`.
+
+### Available checkpoints:
+
+TODO
+
+### Text-to-Image Generation
+
+```python
+import torch
+from diffusers import StableUnCLIPPipeline
+
+pipe = StableUnCLIPPipeline.from_pretrained(
+    "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16
+)  # TODO update model path
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+images = pipe(prompt).images
+images[0].save("astronaut_horse.png")
+```
+
+
+### Text guided Image-to-Image Variation
+
+```python
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+from diffusers import StableUnCLIPImg2ImgPipeline
+
+pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+    "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
+)  # TODO update model path
+pipe = pipe.to("cuda")
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((768, 512))
+
+prompt = "A fantasy landscape, trending on artstation"
+
+images = pipe(prompt, init_image).images
+images[0].save("fantasy_landscape.png")
+```
+
+### StableUnCLIPPipeline
+
+[[autodoc]] StableUnCLIPPipeline
+    - all
+    - __call__
+    - enable_attention_slicing
+    - disable_attention_slicing
+    - enable_vae_slicing
+    - disable_vae_slicing
+    - enable_xformers_memory_efficient_attention
+    - disable_xformers_memory_efficient_attention
+
+
+### StableUnCLIPImg2ImgPipeline
+
+[[autodoc]] StableUnCLIPImg2ImgPipeline
+    - all
+    - __call__
+    - enable_attention_slicing
+    - disable_attention_slicing
+    - enable_vae_slicing
+    - disable_vae_slicing
+    - enable_xformers_memory_efficient_attention
+    - disable_xformers_memory_efficient_attention
+
````
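For illustration, here is a minimal sketch (not part of this commit) of passing the `noise_level` described in the new Tips section, assuming it is exposed as a `__call__` argument as the docs state; the checkpoint path is the same placeholder the docs use:

```python
import torch
from diffusers import StableUnCLIPPipeline

pipe = StableUnCLIPPipeline.from_pretrained(
    "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16  # placeholder path from the docs above
)
pipe = pipe.to("cuda")

# noise_level=0 (the default) adds no extra noise to the CLIP image embeddings;
# higher values increase variation in the final un-noised images.
images = pipe("a photo of an astronaut riding a horse on mars", noise_level=100).images
images[0].save("astronaut_horse_varied.png")
```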

docs/source/en/index.mdx

Lines changed: 2 additions & 0 deletions

```diff
@@ -54,6 +54,8 @@ available a colab notebook to directly try them out.
 | [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
 | [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
 | [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
+| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation |
+| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation |
 | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
 | [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
```

scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -100,6 +100,26 @@
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+    parser.add_argument(
+        "--stable_unclip",
+        type=str,
+        default=None,
+        required=False,
+        help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.",
+    )
+    parser.add_argument(
+        "--stable_unclip_prior",
+        type=str,
+        default=None,
+        required=False,
+        help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.",
+    )
+    parser.add_argument(
+        "--clip_stats_path",
+        type=str,
+        help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.",
+        required=False,
+    )
     args = parser.parse_args()
 
     pipe = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -114,5 +134,8 @@
         upcast_attention=args.upcast_attention,
         from_safetensors=args.from_safetensors,
         device=args.device,
+        stable_unclip=args.stable_unclip,
+        stable_unclip_prior=args.stable_unclip_prior,
+        clip_stats_path=args.clip_stats_path,
     )
     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
```
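For context, a hypothetical invocation of the updated script: the checkpoint and stats paths are invented placeholders, and `--checkpoint_path` is assumed from the script's pre-existing arguments. Per the help text above, `--clip_stats_path` is only needed when the model config sets `noise_aug_config.params.clip_stats_path`, and omitting `--stable_unclip_prior` for a `txt2img` conversion falls back to the karlo prior.

```bash
python scripts/convert_original_stable_diffusion_to_diffusers.py \
    --checkpoint_path ./sd21-unclip-l.ckpt \
    --dump_path ./stable-unclip-img2img \
    --stable_unclip img2img \
    --clip_stats_path ./clip_stats.pt
```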

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -119,6 +119,8 @@
     StableDiffusionPipeline,
     StableDiffusionPipelineSafe,
     StableDiffusionUpscalePipeline,
+    StableUnCLIPImg2ImgPipeline,
+    StableUnCLIPPipeline,
     UnCLIPImageVariationPipeline,
     UnCLIPPipeline,
     VersatileDiffusionDualGuidedPipeline,
```

src/diffusers/models/unet_2d_condition.py

Lines changed: 18 additions & 2 deletions

```diff
@@ -91,7 +91,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
         class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
-            summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
+            summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
@@ -102,7 +102,9 @@ class conditioning with `class_embed_type` equal to `None`.
         time_cond_proj_dim (`int`, *optional*, default to `None`):
             The dimension of `cond_proj` layer in timestep embedding.
         conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
-        conv_out_kernel (`int`, *optional*, default to `3`): the Kernel size of `conv_out` layer.
+        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
     """
 
     _supports_gradient_checkpointing = True
@@ -145,6 +147,7 @@ def __init__(
         time_cond_proj_dim: Optional[int] = None,
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
+        projection_class_embeddings_input_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -211,6 +214,19 @@ def __init__(
             self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
         elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        elif class_embed_type == "projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+            # 2. it projects from an arbitrary input dimension.
+            #
+            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
         else:
             self.class_embedding = None
 
```
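The comment block in the hunk above is the crux of this change: `TimestepEmbedding` is essentially linear layers plus an activation, so the "projection" branch can feed it arbitrary vectors such as CLIP image embeddings. A minimal sketch of that behavior, assuming `TimestepEmbedding` is importable from `diffusers.models.embeddings` with the positional `(input_dim, time_embed_dim)` signature used in the diff, and with 768/1280 as purely illustrative dimensions:

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding

# Project e.g. a 768-d CLIP image embedding up to a 1280-d time embedding.
class_embedding = TimestepEmbedding(768, 1280)

class_labels = torch.randn(2, 768)          # arbitrary vectors; no sinusoidal encoding applied first
class_emb = class_embedding(class_labels)   # in the UNet forward pass, this is summed with the time embedding
print(class_emb.shape)                      # torch.Size([2, 1280])
```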

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -55,6 +55,8 @@
         StableDiffusionLatentUpscalePipeline,
         StableDiffusionPipeline,
         StableDiffusionUpscalePipeline,
+        StableUnCLIPImg2ImgPipeline,
+        StableUnCLIPPipeline,
     )
     from .stable_diffusion_safe import StableDiffusionPipelineSafe
     from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
```

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -45,7 +45,10 @@ class StableDiffusionPipelineOutput(BaseOutput):
     from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
     from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
     from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
+    from .pipeline_stable_unclip import StableUnCLIPPipeline
+    from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
     from .safety_checker import StableDiffusionSafetyChecker
+    from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
 try:
     if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
```
