Commit eeea6a8

Update documentation

Update Docs Add draft documentation and import code

1 parent f953739 commit eeea6a8

6 files changed: +263 -1 lines changed
docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -191,6 +191,8 @@
      title: MultiDiffusion Panorama
    - local: api/pipelines/stable_diffusion/controlnet
      title: Text-to-Image Generation with ControlNet Conditioning
+   - local: api/pipelines/stable_diffusion/diffedit
+     title: DiffEdit
    title: Stable Diffusion
  - local: api/pipelines/stable_diffusion_2
    title: Stable Diffusion 2
docs/source/en/api/pipelines/stable_diffusion/diffedit.mdx

Lines changed: 256 additions & 0 deletions

@@ -0,0 +1,256 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Zero-shot Diffusion-based Semantic Image Editing with Mask Guidance

## Overview

[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427).

The abstract from the paper is:

*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*
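At a high level, the automatic mask inference can be sketched as follows. This is only a conceptual illustration of the idea described in the paper, written against generic `diffusers` UNet/scheduler objects; `diffedit_mask_sketch` and all of its default values are hypothetical and do not necessarily correspond to the pipeline's actual implementation:

```py
import torch


@torch.no_grad()
def diffedit_mask_sketch(unet, scheduler, image_latents, source_embeds, target_embeds, n_samples=10, t_frac=0.5, threshold=0.5):
    # Conceptual sketch: add a moderate amount of noise to the image latents,
    # denoise once while conditioning on the source ("cat") and the target ("dog")
    # prompt embeddings, and keep the spatial locations where the two noise
    # estimates disagree the most. `scheduler.set_timesteps(...)` is assumed to
    # have been called already.
    t = scheduler.timesteps[int(len(scheduler.timesteps) * t_frac)]
    diffs = []
    for _ in range(n_samples):
        noise = torch.randn_like(image_latents)
        noisy_latents = scheduler.add_noise(image_latents, noise, t)
        eps_source = unet(noisy_latents, t, encoder_hidden_states=source_embeds).sample
        eps_target = unet(noisy_latents, t, encoder_hidden_states=target_embeds).sample
        # Average the disagreement over the latent channels.
        diffs.append((eps_target - eps_source).abs().mean(dim=1, keepdim=True))
    diff = torch.stack(diffs).mean(dim=0)
    diff = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8)  # normalize to [0, 1]
    return (diff > threshold).float()  # binary edit mask at latent resolution
```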
Resources:

* [Paper](https://arxiv.org/abs/2210.11427).
* [Blog post with demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
* [Community implementation on GitHub](https://github.com/Xiang-cd/DiffEdit-stable-diffusion).
## Tips

* The pipeline can be conditioned on real input images. Check out the code examples below to learn more.
* The pipeline exposes two arguments, namely `source_embeds` and `target_embeds`, that let you control the direction of the semantic edits in the final image to be generated. Let's say you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to `source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details.
* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van Gogh".
* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to (see the short sketch right after this list):
  * Swap the `source_embeds` and `target_embeds`.
  * Change the input prompt to include "dog".
* To learn more about how the source and target embeddings are generated, refer to the [original paper](https://arxiv.org/abs/2302.03027). Below, we also provide some directions on how to generate the embeddings.
* Note that the quality of the outputs generated with this pipeline is dependent on how good the `source_embeds` and `target_embeds` are. Please refer to [this discussion](#generating-source-and-target-embeddings) for some suggestions on the topic.
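For instance, here is a minimal sketch of that reversal, assuming a `pipeline` loaded as in the usage example below and `source_embeddings` ("cat") / `target_embeddings` ("dog") computed as shown in the [embedding generation section](#generating-source-and-target-embeddings); the argument values are illustrative only:

```py
# Hypothetical reversal sketch ("dog" -> "cat"): the prompt now mentions the new
# source concept ("dog"), and the roles of the two embeddings are swapped.
prompt = "a high resolution painting of a dog in the style of van Gogh"

images = pipeline(
    prompt,
    source_embeds=target_embeddings,  # the "dog" embeddings act as the source
    target_embeds=source_embeddings,  # the "cat" embeddings act as the target
    num_inference_steps=50,
    cross_attention_guidance_amount=0.15,
).images
images[0].save("edited_image_cat.png")
```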
## Available Pipelines:

| Pipeline | Tasks |
|---|---|
| [StableDiffusionDiffEditPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py) | *Text-Based Image Editing* |

<!-- TODO: add Colab -->
## Usage example

### Based on an input image

When the pipeline is conditioned on an input image, we first obtain inverted noise from it using a `DDIMInverseScheduler` with the help of a generated caption. Then the inverted noise is used to start the generation process.

First, let's load our pipeline:
```py
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline

captioner_id = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(captioner_id)
model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True)

sd_model_ckpt = "CompVis/stable-diffusion-v1-4"
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    sd_model_ckpt,
    caption_generator=model,
    caption_processor=processor,
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()
```
Then, we load an input image for conditioning and obtain a suitable caption for it:

```py
import requests
from PIL import Image

img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
caption = pipeline.generate_caption(raw_image)
```

Next, we use the generated caption and the input image to obtain the inverted latents:

```py
generator = torch.manual_seed(0)
inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents
```
We then use the source and target prompts to generate the editing mask:

```py
# See the "Generating source and target embeddings" section below to learn how to
# automate the generation of these captions with a pre-trained model like Flan-T5.
source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]

source_embeds = pipeline.get_embeds(source_prompts, batch_size=2)
target_embeds = pipeline.get_embeds(target_prompts, batch_size=2)
mask_image = pipeline.compute_mask(
    image=raw_image,
    prompt_embeds=target_embeds,
    mask_prompt_embeds=source_embeds,
    generator=generator,
)
```
Now, generate the image with the inverted latents and semantically generated mask:

```py
image = pipeline(
    prompt_embeds=target_embeds,
    num_inference_steps=50,
    generator=generator,
    inverted_latents=inv_latents,
    mask_image=mask_image,
    negative_prompt=caption,
).images[0]
image.save("edited_image.png")
```
## Generating source and target embeddings

The authors originally used the [GPT-3 API](https://openai.com/api/) to generate the source and target captions for discovering edit directions. However, we can also leverage open source and public models for the same purpose. Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model for generating captions and [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for computing embeddings on the generated captions.

**1. Load the generation model**:

```py
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
```
**2. Construct a starting prompt**:

```py
source_concept = "cat"
target_concept = "dog"

source_text = (
    f"Provide a caption for images containing a {source_concept}. "
    "The captions should be in English and should be no longer than 150 characters."
)

target_text = (
    f"Provide a caption for images containing a {target_concept}. "
    "The captions should be in English and should be no longer than 150 characters."
)
```

Here, we're interested in the "cat -> dog" direction.
**3. Generate captions**:

We can use a utility function like the following for this purpose:

```py
def generate_captions(input_prompt):
    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")

    outputs = model.generate(
        input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
```

And then we just call it to generate our captions:

```py
source_captions = generate_captions(source_text)
target_captions = generate_captions(target_text)
```
We encourage you to play around with the different parameters supported by the `generate()` method ([documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate)) to get the generation quality you are looking for.
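As a rough illustration (the helper name and the sampling values below are arbitrary, not a tuned recipe), you could expose those knobs through a small variant of the utility above:

```py
def generate_captions_with(input_prompt, **generate_kwargs):
    # Same utility as above, but forwarding arbitrary `generate()` kwargs
    # so that different sampling strategies can be tried easily.
    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")
    outputs = model.generate(input_ids, do_sample=True, max_new_tokens=128, **generate_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


# For example, nucleus sampling with a lower temperature tends to produce
# more conservative captions than the top-k settings used earlier.
conservative_captions = generate_captions_with(source_text, temperature=0.6, top_p=0.9, num_return_sequences=8)
```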
**4. Load the embedding model**:

Here, we need to use the same text encoder model used by the subsequent Stable Diffusion model.

```py
from diffusers import StableDiffusionPix2PixZeroPipeline

pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipeline = pipeline.to("cuda")
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
```
**5. Compute embeddings**:

```py
import torch

def embed_captions(sentences, tokenizer, text_encoder, device="cuda"):
    with torch.no_grad():
        embeddings = []
        for sent in sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            embeddings.append(prompt_embeds)
    return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)

source_embeddings = embed_captions(source_captions, tokenizer, text_encoder)
target_embeddings = embed_captions(target_captions, tokenizer, text_encoder)
```
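As a quick sanity check (assuming the `CompVis/stable-diffusion-v1-4` text encoder loaded above, i.e. a CLIP text encoder with hidden size 768 and a maximum sequence length of 77 tokens), each averaged embedding should be a single prompt-shaped tensor:

```py
print(source_embeddings.shape)  # expected: torch.Size([1, 77, 768])
print(target_embeddings.shape)  # expected: torch.Size([1, 77, 768])
```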
And you're done! [Here](https://colab.research.google.com/drive/1tz2C1EdfZYAPlzXXbTnf-5PRBiR8_R1F?usp=sharing) is a Colab Notebook that you can use to interact with the entire process.

Now, you can use these embeddings directly while calling the pipeline:

```py
from diffusers import DDIMScheduler

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# As noted in the Tips section, the prompt should mention the source concept ("cat").
prompt = "a high resolution painting of a cat in the style of van Gogh"

images = pipeline(
    prompt,
    source_embeds=source_embeddings,
    target_embeds=target_embeddings,
    num_inference_steps=50,
    cross_attention_guidance_amount=0.15,
).images
images[0].save("edited_image_dog.png")
```
## StableDiffusionDiffEditPipeline
[[autodoc]] StableDiffusionDiffEditPipeline
    - __call__
    - all

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -120,6 +120,7 @@
      StableDiffusionAttendAndExcitePipeline,
      StableDiffusionControlNetPipeline,
      StableDiffusionDepth2ImgPipeline,
+     StableDiffusionDiffEditPipeline,
      StableDiffusionImageVariationPipeline,
      StableDiffusionImg2ImgPipeline,
      StableDiffusionInpaintPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@
      StableDiffusionAttendAndExcitePipeline,
      StableDiffusionControlNetPipeline,
      StableDiffusionDepth2ImgPipeline,
+     StableDiffusionDiffEditPipeline,
      StableDiffusionImageVariationPipeline,
      StableDiffusionImg2ImgPipeline,
      StableDiffusionInpaintPipeline,

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -74,10 +74,12 @@ class StableDiffusionPipelineOutput(BaseOutput):
  except OptionalDependencyNotAvailable:
      from ...utils.dummy_torch_and_transformers_objects import (
          StableDiffusionDepth2ImgPipeline,
+         StableDiffusionDiffEditPipeline,
          StableDiffusionPix2PixZeroPipeline,
      )
  else:
      from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline
+     from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
      from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py

Lines changed: 1 addition & 1 deletion

@@ -232,7 +232,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline):
          feature_extractor ([`CLIPImageProcessor`]):
              Model that extracts features from generated images to be used as inputs for the `safety_checker`.
      """
-     _optional_components = ["safety_checker", "feature_extractor"]
+     _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]

      def __init__(
          self,
