Skip to content

Commit e50c25d

Browse files
manuelbrackPatrickSchrMLpatrickvonplaten
authored
Add Safe Stable Diffusion Pipeline (#1244)
* Add pipeline_stable_diffusion_safe.py to pipelines * Fix repository consistency Ran make fix-copies after adding new pipline * Add Paper/Equation reference for parameters to doc string * Ensure code style and quality * Perform code refactoring * Fix copies inherited from merge with huggingface/main * Add docs * Fix code style * Fix errors in documentation * Fix refactoring error * remove debugging print statement * added Safe Latent Diffusion tests * Fix style * Fix style * Add pre-defined safety configurations * Fix line-break * fix some tests * finish * Change safety checker * Add missing safety_checker.py file * Remove unused imports Co-authored-by: PatrickSchrML <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 182eb95 commit e50c25d

File tree

12 files changed

+1450
-1
lines changed

12 files changed

+1450
-1
lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@
106106
title: "Score SDE VE"
107107
- local: api/pipelines/stable_diffusion
108108
title: "Stable Diffusion"
109+
- local: api/pipelines/stable_diffusion_safe
110+
title: "Safe Stable Diffusion"
109111
- local: api/pipelines/stochastic_karras_ve
110112
title: "Stochastic Karras VE"
111113
- local: api/pipelines/dance_diffusion

docs/source/api/pipelines/overview.mdx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ available a colab notebook to directly try them out.
5858
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
5959
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
6060
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
61-
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
61+
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
62+
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
6263
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
6364

6465

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# Safe Stable Diffusion
14+
15+
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content.
16+
Safe Stable Diffusion is an extension to the Stable Diffusion that drastically reduces content like this.
17+
18+
The abstract of the paper is the following:
19+
20+
*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*
21+
22+
23+
*Overview*:
24+
25+
| Pipeline | Tasks | Colab | Demo
26+
|---|---|:---:|:---:|
27+
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | -
28+
29+
## Tips
30+
31+
- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion).
32+
33+
### Run Safe Stable Diffusion
34+
35+
Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation).
36+
37+
### Interacting with the Safety Concept
38+
39+
To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`]
40+
```python
41+
>>> from diffusers import StableDiffusionPipelineSafe
42+
43+
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
44+
>>> pipeline.safety_concept
45+
```
46+
For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`].
47+
48+
### Using pre-defined safety configurations
49+
50+
You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows:
51+
52+
```python
53+
>>> from diffusers import StableDiffusionPipelineSafe
54+
>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
55+
56+
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
57+
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
58+
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
59+
```
60+
61+
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`.
62+
63+
### How to load and use different schedulers.
64+
65+
The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
66+
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
67+
68+
```python
69+
>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler
70+
71+
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
72+
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
73+
74+
>>> # or
75+
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler")
76+
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(
77+
... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler
78+
... )
79+
```
80+
81+
82+
## StableDiffusionSafePipelineOutput
83+
[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
84+
85+
## StableDiffusionPipelineSafe
86+
[[autodoc]] StableDiffusionPipelineSafe
87+
- __call__
88+
- enable_attention_slicing
89+
- disable_attention_slicing
90+

docs/source/index.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ available a colab notebook to directly try them out.
4848
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
4949
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
5050
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
51+
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
5152
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
5253
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
5354

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
StableDiffusionInpaintPipeline,
7474
StableDiffusionInpaintPipelineLegacy,
7575
StableDiffusionPipeline,
76+
StableDiffusionPipelineSafe,
7677
VQDiffusionPipeline,
7778
)
7879
else:

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
StableDiffusionInpaintPipelineLegacy,
2525
StableDiffusionPipeline,
2626
)
27+
from .stable_diffusion_safe import StableDiffusionPipelineSafe
2728
from .vq_diffusion import VQDiffusionPipeline
2829

2930
if is_transformers_available() and is_onnx_available():
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import List, Optional, Union
4+
5+
import numpy as np
6+
7+
import PIL
8+
from PIL import Image
9+
10+
from ...utils import BaseOutput, is_torch_available, is_transformers_available
11+
12+
13+
@dataclass
14+
class SafetyConfig(object):
15+
WEAK = {
16+
"sld_warmup_steps": 15,
17+
"sld_guidance_scale": 20,
18+
"sld_threshold": 0.0,
19+
"sld_momentum_scale": 0.0,
20+
"sld_mom_beta": 0.0,
21+
}
22+
MEDIUM = {
23+
"sld_warmup_steps": 10,
24+
"sld_guidance_scale": 1000,
25+
"sld_threshold": 0.01,
26+
"sld_momentum_scale": 0.3,
27+
"sld_mom_beta": 0.4,
28+
}
29+
STRONG = {
30+
"sld_warmup_steps": 7,
31+
"sld_guidance_scale": 2000,
32+
"sld_threshold": 0.025,
33+
"sld_momentum_scale": 0.5,
34+
"sld_mom_beta": 0.7,
35+
}
36+
MAX = {
37+
"sld_warmup_steps": 0,
38+
"sld_guidance_scale": 5000,
39+
"sld_threshold": 1.0,
40+
"sld_momentum_scale": 0.5,
41+
"sld_mom_beta": 0.7,
42+
}
43+
44+
45+
@dataclass
46+
class StableDiffusionSafePipelineOutput(BaseOutput):
47+
"""
48+
Output class for Safe Stable Diffusion pipelines.
49+
50+
Args:
51+
images (`List[PIL.Image.Image]` or `np.ndarray`)
52+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
53+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
54+
nsfw_content_detected (`List[bool]`)
55+
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
56+
(nsfw) content, or `None` if safety checking could not be performed.
57+
images (`List[PIL.Image.Image]` or `np.ndarray`)
58+
List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work"
59+
(nsfw) content, or `None` if no safety check was performed or no images were flagged.
60+
applied_safety_concept (`str`)
61+
The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
62+
"""
63+
64+
images: Union[List[PIL.Image.Image], np.ndarray]
65+
nsfw_content_detected: Optional[List[bool]]
66+
unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
67+
applied_safety_concept: Optional[str]
68+
69+
70+
if is_transformers_available() and is_torch_available():
71+
from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
72+
from .safety_checker import SafeStableDiffusionSafetyChecker

0 commit comments

Comments
 (0)