Merged
examples/inference/inpainting.py (29 changes: 15 additions & 14 deletions)
```diff
@@ -11,7 +11,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


-def preprocess(image):
+def preprocess_image(image):
     w, h = image.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
```
```diff
@@ -20,15 +20,16 @@ def preprocess(image):
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0

+
 def preprocess_mask(mask):
-    mask=mask.convert("L")
+    mask = mask.convert("L")
     w, h = mask.size
     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w//8, h//8), resample=PIL.Image.NEAREST)
+    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
     mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask,(4,1,1))
-    mask = mask[None].transpose(0, 1, 2, 3)#what does this step do?
-    mask = 1 - mask #repaint white, keep black
+    mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+    mask = 1 - mask  # repaint white, keep black
     mask = torch.from_numpy(mask)
     return mask

```
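A note on the `# what does this step do?` comment kept in the new version: `transpose(0, 1, 2, 3)` is the identity permutation, so the line only adds a batch dimension via `mask[None]`. A minimal sketch of what `preprocess_mask` returns, assuming a plain all-white 512x512 mask (white = repaint, black = keep):

```python
import numpy as np
import PIL.Image
import torch

mask = PIL.Image.new("L", (512, 512), 255)  # all-white mask: repaint everything
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h))  # 512 is already a multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)  # match latent resolution
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))  # broadcast over the 4 latent channels
mask = mask[None]  # same result as mask[None].transpose(0, 1, 2, 3)
mask = 1 - mask  # white (1.0) -> 0: repaint; black (0.0) -> 1: keep
mask = torch.from_numpy(mask)
print(mask.shape)  # torch.Size([1, 4, 64, 64]), matching the init latents below
```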

```diff
@@ -90,25 +91,25 @@ def __call__(
         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

-        #preprocess image
-        init_image = preprocess(init_image).to(self.device)
+        # preprocess image
+        init_image = preprocess_image(init_image).to(self.device)

         # encode the init image into latents and scale the latents
         init_latents = self.vae.encode(init_image).sample()
         init_latents = 0.18215 * init_latents

-        # prepare init_latents noise to latents
-        init_latents = torch.cat([init_latents] * batch_size)
-
         init_latents_orig = init_latents

         # preprocess mask
         mask = preprocess_mask(mask_image).to(self.device)
         mask = torch.cat([mask] * batch_size)

-        #check sizes
+        # check sizes
         if not mask.shape == init_latents.shape:
             raise ValueError(f"The mask and init_image should be the same size!")

+        # prepare init_latents noise to latents
+        init_latents = torch.cat([init_latents] * batch_size)
+
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
```
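The `strength` arithmetic in the last lines of this hunk decides how much of the schedule actually runs. A worked example, assuming the values used in the README example below (`num_inference_steps=50` by default, `strength=0.75`, `offset=0`):

```python
num_inference_steps, strength, offset = 50, 0.75, 0  # assumed values
init_timestep = int(num_inference_steps * strength) + offset  # 37
init_timestep = min(init_timestep, num_inference_steps)  # still 37
# The denoising loop then runs only the last 37 of the 50 scheduled steps,
# so higher strength adds more noise and re-imagines the masked region more freely.
print(init_timestep)
```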
```diff
@@ -172,9 +173,9 @@ def __call__(
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

-            #masking
+            # masking
             init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
-            latents = ( init_latents_proper * mask ) + ( latents * (1-mask) )
+            latents = (init_latents_proper * mask) + (latents * (1 - mask))

             # scale and decode the image latents with vae
             latents = 1 / 0.18215 * latents
```
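This final hunk is the heart of the in-painting loop: after every denoising step, the region the mask protects is overwritten with the original image latents re-noised to the current timestep, so only the masked-out region is ever sampled from the model (the `1 / 0.18215` afterwards undoes the scaling applied at encoding time). A self-contained toy version of that update (a sketch with made-up tensors; `add_noise` is approximated by a fixed blend, not the scheduler's real coefficients):

```python
import torch

latents = torch.randn(1, 4, 64, 64)            # current sample x_t from the model
init_latents_orig = torch.randn(1, 4, 64, 64)  # clean latents of the original image
noise = torch.randn_like(init_latents_orig)
mask = torch.ones(1, 4, 64, 64)                # 1 = keep original, 0 = repaint
mask[:, :, 16:48, 16:48] = 0                   # repaint a square in the middle

alpha = 0.5  # stand-in for the scheduler's noise level at timestep t
init_latents_proper = alpha**0.5 * init_latents_orig + (1 - alpha) ** 0.5 * noise
latents = init_latents_proper * mask + latents * (1 - mask)
print(latents.shape)  # torch.Size([1, 4, 64, 64])
```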
examples/inference/readme.md (43 changes: 43 additions & 0 deletions)
@@ -52,3 +52,46 @@ (the first three lines below are existing context; everything from "## In-painting using Stable Diffusion" on is new)
## Tweak prompts reusing seeds and latents

You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).


## In-painting using Stable Diffusion

The `inpainting.py` script implements `StableDiffusionInpaintingPipeline`, which lets you edit specific parts of an image by providing a mask and a text prompt.

### How to use it

```python
from io import BytesIO

import requests
import PIL
import torch
from torch import autocast

from inpainting import StableDiffusionInpaintingPipeline

def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

device = "cuda"
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to(device)

prompt = "a cat sitting on a bench"
with autocast("cuda"):
    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]

images[0].save("cat_on_bench.png")
```
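`strength` controls how much of the denoising schedule is re-run on the masked region (see the `init_timestep` computation in the diff above); values closer to 1.0 give the model more freedom. A variation of the call above, with everything else unchanged:

```python
# Higher strength: the repainted region follows the prompt more and the
# original image less (strength must be between 0 and 1).
with autocast("cuda"):
    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.9)["sample"]
images[0].save("cat_on_bench_strength_0.9.png")
```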

You can also run this example on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)