diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index de33ba616d0a..cc880f3e0b81 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -152,6 +152,8 @@
      title: DDPM
    - local: api/pipelines/dit
      title: DiT
+    - local: api/pipelines/if
+      title: IF
    - local: api/pipelines/latent_diffusion
      title: Latent Diffusion
    - local: api/pipelines/paint_by_example
diff --git a/docs/source/en/api/pipelines/if.mdx b/docs/source/en/api/pipelines/if.mdx
new file mode 100644
index 000000000000..5d3b292587f6
--- /dev/null
+++ b/docs/source/en/api/pipelines/if.mdx
@@ -0,0 +1,523 @@
+
+
+# IF
+
+## Overview
+
+DeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.
+The model is modular and composed of a frozen text encoder and three cascaded pixel diffusion modules:
+- Stage 1: a base model that generates a 64x64 px image based on a text prompt,
+- Stage 2: a 64x64 px => 256x256 px super-resolution model, and
+- Stage 3: a 256x256 px => 1024x1024 px super-resolution model.
+Stage 1 and Stage 2 utilize a frozen text encoder based on the T5 transformer to extract text embeddings,
+which are then fed into a UNet architecture enhanced with cross-attention and attention pooling.
+Stage 3 is [Stability's x4 Upscaling model](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler).
+The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset.
+Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.
+
+## Usage
+
+Before you can use IF, you need to accept its usage conditions. To do so:
+1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in
+2. Accept the license on the model card of [DeepFloyd/IF-I-IF-v1.0](https://huggingface.co/DeepFloyd/IF-I-IF-v1.0) and [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
+3. Make sure to log in locally. Install `huggingface_hub`
+```sh
+pip install huggingface_hub --upgrade
+```
+
+run the `login` function in a Python shell
+
+```py
+from huggingface_hub import login
+
+login()
+```
+
+and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
+
+Next we install `diffusers` and dependencies:
+
+```sh
+pip install diffusers accelerate transformers safetensors
+```
+
+The following sections give more detailed examples of how to use IF.
Specifically:
+
+- [Text-to-Image Generation](#text-to-image-generation)
+- [Image-to-Image Generation](#text-guided-image-to-image-generation)
+- [Inpainting](#text-guided-inpainting-generation)
+- [Reusing model weights](#converting-between-different-pipelines)
+- [Speed optimization](#optimizing-for-speed)
+- [Memory optimization](#optimizing-for-memory)
+
+**Available checkpoints**
+- *Stage-1*
+  - [DeepFloyd/IF-I-IF-v1.0](https://huggingface.co/DeepFloyd/IF-I-IF-v1.0)
+  - [DeepFloyd/IF-I-L-v1.0](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)
+  - [DeepFloyd/IF-I-M-v1.0](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)
+
+- *Stage-2*
+  - [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
+  - [DeepFloyd/IF-II-M-v1.0](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)
+
+- *Stage-3*
+  - [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)
+
+**Demo**
+[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF)
+
+**Google Colab**
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
+
+### Text-to-Image Generation
+
+By default, diffusers makes use of [model CPU offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings)
+to run the whole IF pipeline with as little as 14 GB of VRAM.
+
+```python
+from diffusers import DiffusionPipeline
+from diffusers.utils import pt_to_pil
+import torch
+
+# stage 1
+stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
+stage_1.enable_model_cpu_offload()
+
+# stage 2
+stage_2 = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+)
+stage_2.enable_model_cpu_offload()
+
+# stage 3
+safety_modules = {
+    "feature_extractor": stage_1.feature_extractor,
+    "safety_checker": stage_1.safety_checker,
+    "watermarker": stage_1.watermarker,
+}
+stage_3 = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
+)
+stage_3.enable_model_cpu_offload()
+
+prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
+generator = torch.manual_seed(1)
+
+# text embeds
+prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
+
+# stage 1
+image = stage_1(
+    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
+).images
+pt_to_pil(image)[0].save("./if_stage_I.png")
+
+# stage 2
+image = stage_2(
+    image=image,
+    prompt_embeds=prompt_embeds,
+    negative_prompt_embeds=negative_embeds,
+    generator=generator,
+    output_type="pt",
+).images
+pt_to_pil(image)[0].save("./if_stage_II.png")
+
+# stage 3
+image = stage_3(prompt=prompt, image=image, noise_level=100, generator=generator).images
+image[0].save("./if_stage_III.png")
+```
+
+### Text Guided Image-to-Image Generation
+
+The same IF model weights can be used for text-guided image-to-image translation or image variation.
+In this case, just make sure to load the weights using the [`IFImg2ImgPipeline`] and [`IFImg2ImgSuperResolutionPipeline`] pipelines.
+
+**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
+without loading them twice by making use of the [`~DiffusionPipeline.components()`] function as explained [here](#converting-between-different-pipelines).
+
+```python
+from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
+from diffusers.utils import pt_to_pil
+
+import torch
+
+from PIL import Image
+import requests
+from io import BytesIO
+
+# download image
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+response = requests.get(url)
+original_image = Image.open(BytesIO(response.content)).convert("RGB")
+original_image = original_image.resize((768, 512))
+
+# stage 1
+stage_1 = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
+stage_1.enable_model_cpu_offload()
+
+# stage 2
+stage_2 = IFImg2ImgSuperResolutionPipeline.from_pretrained(
+    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+)
+stage_2.enable_model_cpu_offload()
+
+# stage 3
+safety_modules = {
+    "feature_extractor": stage_1.feature_extractor,
+    "safety_checker": stage_1.safety_checker,
+    "watermarker": stage_1.watermarker,
+}
+stage_3 = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
+)
+stage_3.enable_model_cpu_offload()
+
+prompt = "A fantasy landscape in style minecraft"
+generator = torch.manual_seed(1)
+
+# text embeds
+prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
+
+# stage 1
+image = stage_1(
+    image=original_image,
+    prompt_embeds=prompt_embeds,
+    negative_prompt_embeds=negative_embeds,
+    generator=generator,
+    output_type="pt",
+).images
+pt_to_pil(image)[0].save("./if_stage_I.png")
+
+# stage 2
+image = stage_2(
+    image=image,
+    original_image=original_image,
+    prompt_embeds=prompt_embeds,
+    negative_prompt_embeds=negative_embeds,
+    generator=generator,
+    output_type="pt",
+).images
+pt_to_pil(image)[0].save("./if_stage_II.png")
+
+# stage 3
+image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
+image[0].save("./if_stage_III.png")
+```
+
+### Text Guided Inpainting Generation
+
+The same IF model weights can be used for text-guided inpainting, i.e. repainting only a masked area of an image.
+In this case, just make sure to load the weights using the [`IFInpaintingPipeline`] and [`IFInpaintingSuperResolutionPipeline`] pipelines.
+
+**Note**: You can also directly move the weights of the text-to-image pipelines to the inpainting pipelines
+without loading them twice by making use of the [`~DiffusionPipeline.components()`] function as explained [here](#converting-between-different-pipelines).
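+The mask follows the usual diffusers convention: white pixels mark the area that will be repainted, black pixels are kept.
+The example below downloads a ready-made image and mask from the Hub; if you want to build a mask yourself, a minimal
+PIL sketch (image size and rectangle coordinates are made up for illustration) could look like this:
+
+```python
+from PIL import Image, ImageDraw
+
+# an all-black mask keeps every pixel of the input image
+mask_image = Image.new("L", (768, 1024), 0)
+
+# paint the region that should be re-generated in white
+draw = ImageDraw.Draw(mask_image)
+draw.rectangle((250, 300, 500, 450), fill=255)
+```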
+ +```python +from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline +from diffusers.utils import pt_to_pil +import torch + +from PIL import Image +import requests +from io import BytesIO + +# download image +url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" +response = requests.get(url) +original_image = Image.open(BytesIO(response.content)).convert("RGB") +original_image = original_image + +# download mask +url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" +response = requests.get(url) +mask_image = Image.open(BytesIO(response.content)) +mask_image = mask_image + +# stage 1 +stage_1 = IFInpaintingPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16) +stage_1.enable_model_cpu_offload() + +# stage 2 +stage_2 = IFInpaintingSuperResolutionPipeline.from_pretrained( + "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 +) +stage_2.enable_model_cpu_offload() + +# stage 3 +safety_modules = { + "feature_extractor": stage_1.feature_extractor, + "safety_checker": stage_1.safety_checker, + "watermarker": stage_1.watermarker, +} +stage_3 = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 +) +stage_3.enable_model_cpu_offload() + +prompt = "blue sunglasses" +generator = torch.manual_seed(1) + +# text embeds +prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) + +# stage 1 +image = stage_1( + image=original_image, + mask_image=mask_image, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + generator=generator, + output_type="pt", +).images +pt_to_pil(image)[0].save("./if_stage_I.png") + +# stage 2 +image = stage_2( + image=image, + original_image=original_image, + mask_image=mask_image, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + generator=generator, + output_type="pt", +).images +pt_to_pil(image)[0].save("./if_stage_II.png") + +# stage 3 +image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images +image[0].save("./if_stage_III.png") +``` + +### Converting between different pipelines + +In addition to being loaded with `from_pretrained`, Pipelines can also be loaded directly from each other. + +```python +from diffusers import IFPipeline, IFSuperResolutionPipeline + +pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0") +pipe_2 = IFSuperResolutionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0") + + +from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline + +pipe_1 = IFImg2ImgPipeline(**pipe_1.components) +pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components) + + +from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline + +pipe_1 = IFInpaintingPipeline(**pipe_1.components) +pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components) +``` + +### Optimizing for speed + +The simplest optimization to run IF faster is to move all model components to the GPU. + +```py +pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16) +pipe.to("cuda") +``` + +You can also run the diffusion process for a shorter number of timesteps. 
+
+This can either be done with the `num_inference_steps` argument:
+
+```py
+pipe("", num_inference_steps=30)
+```
+
+Or with the `timesteps` argument:
+
+```py
+from diffusers.pipelines.deepfloyd_if import fast27_timesteps
+
+pipe("", timesteps=fast27_timesteps)
+```
+
+When doing image variation or inpainting, you can also decrease the number of timesteps
+with the `strength` argument. The `strength` argument is the amount of noise to add to
+the input image, which also determines how many steps to run in the denoising process.
+A smaller number will vary the image less but run faster.
+
+```py
+pipe = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(image=image, prompt="", strength=0.3).images
+```
+
+You can also use [`torch.compile`](../../optimization/torch2.0). Note that we have not exhaustively tested `torch.compile`
+with IF and it might not give expected results.
+
+```py
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+pipe.text_encoder = torch.compile(pipe.text_encoder)
+pipe.unet = torch.compile(pipe.unet)
+```
+
+### Optimizing for memory
+
+When optimizing for GPU memory, we can use the standard diffusers CPU offloading APIs.
+
+Either the model-based CPU offloading,
+
+```py
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+```
+
+or the more aggressive layer-based CPU offloading.
+
+```py
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.enable_sequential_cpu_offload()
+```
+
+Additionally, T5 can be loaded in 8-bit precision:
+
+```py
+from transformers import T5EncoderModel
+
+text_encoder = T5EncoderModel.from_pretrained(
+    "DeepFloyd/IF-I-IF-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
+)
+
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+    "DeepFloyd/IF-I-IF-v1.0",
+    text_encoder=text_encoder,  # pass the previously instantiated 8bit text encoder
+    unet=None,
+    device_map="auto",
+)
+
+prompt_embeds, negative_embeds = pipe.encode_prompt("")
+```
+
+For CPU RAM constrained machines like the Google Colab free tier, where we can't load all
+model components to the CPU at once, we can manually load the pipeline with only
+the text encoder or UNet when the respective model component is needed.
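+The key pattern in the example below is to delete each model component (`del text_encoder`, `del pipe`) as soon as it
+has done its job and to free the cached memory before the next component is loaded. The cleanup step itself is just:
+
+```py
+import gc
+
+import torch
+
+gc.collect()  # run Python garbage collection on the deleted references
+torch.cuda.empty_cache()  # return cached, unused GPU memory to the driver
+```
+
+The full walkthrough below applies this between encoding the prompt, running the stage 1 UNet, and running the stage 2 UNet.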
+ +```py +from diffusers import IFPipeline, IFSuperResolutionPipeline +import torch +import gc +from transformers import T5EncoderModel +from diffusers.utils import pt_to_pil + +text_encoder = T5EncoderModel.from_pretrained( + "DeepFloyd/IF-I-IF-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit" +) + +# text to image + +pipe = DiffusionPipeline.from_pretrained( + "DeepFloyd/IF-I-IF-v1.0", + text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder + unet=None, + device_map="auto", +) + +prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' +prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + +# Remove the pipeline so we can re-load the pipeline with the unet +del text_encoder +del pipe +gc.collect() +torch.cuda.empty_cache() + +pipe = IFPipeline.from_pretrained( + "DeepFloyd/IF-I-IF-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto" +) + +generator = torch.Generator().manual_seed(0) +image = pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + output_type="pt", + generator=generator, +).images + +pt_to_pil(image)[0].save("./if_stage_I.png") + +# Remove the pipeline so we can load the super-resolution pipeline +del pipe +gc.collect() +torch.cuda.empty_cache() + +# First super resolution + +pipe = IFSuperResolutionPipeline.from_pretrained( + "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto" +) + +generator = torch.Generator().manual_seed(0) +image = pipe( + image=image, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + output_type="pt", + generator=generator, +).images + +pt_to_pil(image)[0].save("./if_stage_II.png") +``` + + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_if.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - | +| [pipeline_if_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - | +| [pipeline_if_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py) | *Image-to-Image Generation* | - | +| [pipeline_if_img2img_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py) | *Image-to-Image Generation* | - | +| [pipeline_if_inpainting.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py) | *Image-to-Image Generation* | - | +| [pipeline_if_inpainting_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py) | *Image-to-Image Generation* | - | + +## IFPipeline +[[autodoc]] IFPipeline + - all + - __call__ + +## IFSuperResolutionPipeline +[[autodoc]] IFSuperResolutionPipeline + - all + - __call__ + +## IFImg2ImgPipeline +[[autodoc]] IFImg2ImgPipeline + - all + - __call__ + +## IFImg2ImgSuperResolutionPipeline +[[autodoc]] IFImg2ImgSuperResolutionPipeline + - all + - __call__ + +## IFInpaintingPipeline +[[autodoc]] IFInpaintingPipeline + - all + - __call__ + +## IFInpaintingSuperResolutionPipeline 
+[[autodoc]] IFInpaintingSuperResolutionPipeline + - all + - __call__ diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index 3c5331955513..91716784f8fe 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -51,6 +51,9 @@ available a colab notebook to directly try them out. | [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | | [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | +| [if](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb) +| [if_img2img](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb) +| [if_inpainting](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb) | [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | | [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 10a237f29278..46a985ac2f8d 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -58,6 +58,9 @@ The library has three main components: | [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | | [ddim](./api/pipelines/ddim) | [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | +| [if](./if) | [**IF**](./api/pipelines/if) | Image Generation | +| [if_img2img](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation | +| [if_inpainting](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | 
[latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | diff --git a/scripts/convert_if.py b/scripts/convert_if.py new file mode 100644 index 000000000000..66d7f694c8e1 --- /dev/null +++ b/scripts/convert_if.py @@ -0,0 +1,1257 @@ +import argparse +import inspect +import os + +import numpy as np +import torch +from torch.nn import functional as F +from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer + +from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel +from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker + + +try: + from omegaconf import OmegaConf +except ImportError: + raise ImportError( + "OmegaConf is required to convert the IF checkpoints. Please install it with `pip install" " OmegaConf`." + ) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", required=False, default=None, type=str) + + parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str) + + parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str) + + parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file") + + parser.add_argument( + "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file" + ) + + parser.add_argument( + "--unet_checkpoint_path_stage_2", + required=False, + default=None, + type=str, + help="Path to stage 2 unet checkpoint file", + ) + + parser.add_argument( + "--unet_checkpoint_path_stage_3", + required=False, + default=None, + type=str, + help="Path to stage 3 unet checkpoint file", + ) + + parser.add_argument("--p_head_path", type=str, required=True) + + parser.add_argument("--w_head_path", type=str, required=True) + + args = parser.parse_args() + + return args + + +def main(args): + tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl") + text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl") + + feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path) + + if args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None: + convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args) + + if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None: + convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2) + + if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None: + convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3) + + +def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args): + unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path) + + scheduler = DDPMScheduler( + variance_type="learned_range", + beta_schedule="squaredcos_cap_v2", + prediction_type="epsilon", + thresholding=True, + dynamic_thresholding_ratio=0.95, + sample_max_value=1.5, + ) + + pipe = IFPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, 
+ requires_safety_checker=True, + ) + + pipe.save_pretrained(args.dump_path) + + +def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage): + if stage == 2: + unet_checkpoint_path = args.unet_checkpoint_path_stage_2 + sample_size = None + dump_path = args.dump_path_stage_2 + elif stage == 3: + unet_checkpoint_path = args.unet_checkpoint_path_stage_3 + sample_size = 1024 + dump_path = args.dump_path_stage_3 + else: + assert False + + unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size) + + image_noising_scheduler = DDPMScheduler( + beta_schedule="squaredcos_cap_v2", + ) + + scheduler = DDPMScheduler( + variance_type="learned_range", + beta_schedule="squaredcos_cap_v2", + prediction_type="epsilon", + thresholding=True, + dynamic_thresholding_ratio=0.95, + sample_max_value=1.0, + ) + + pipe = IFSuperResolutionPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=True, + ) + + pipe.save_pretrained(dump_path) + + +def get_stage_1_unet(unet_config, unet_checkpoint_path): + original_unet_config = OmegaConf.load(unet_config) + original_unet_config = original_unet_config.params + + unet_diffusers_config = create_unet_diffusers_config(original_unet_config) + + unet = UNet2DConditionModel(**unet_diffusers_config) + + device = "cuda" if torch.cuda.is_available() else "cpu" + unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path + ) + + unet.load_state_dict(converted_unet_checkpoint) + + return unet + + +def convert_safety_checker(p_head_path, w_head_path): + state_dict = {} + + # p head + + p_head = np.load(p_head_path) + + p_head_weights = p_head["weights"] + p_head_weights = torch.from_numpy(p_head_weights) + p_head_weights = p_head_weights.unsqueeze(0) + + p_head_biases = p_head["biases"] + p_head_biases = torch.from_numpy(p_head_biases) + p_head_biases = p_head_biases.unsqueeze(0) + + state_dict["p_head.weight"] = p_head_weights + state_dict["p_head.bias"] = p_head_biases + + # w head + + w_head = np.load(w_head_path) + + w_head_weights = w_head["weights"] + w_head_weights = torch.from_numpy(w_head_weights) + w_head_weights = w_head_weights.unsqueeze(0) + + w_head_biases = w_head["biases"] + w_head_biases = torch.from_numpy(w_head_biases) + w_head_biases = w_head_biases.unsqueeze(0) + + state_dict["w_head.weight"] = w_head_weights + state_dict["w_head.bias"] = w_head_biases + + # vision model + + vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + vision_model_state_dict = vision_model.state_dict() + + for key, value in vision_model_state_dict.items(): + key = f"vision_model.{key}" + state_dict[key] = value + + # full model + + config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14") + safety_checker = IFSafetyChecker(config) + + safety_checker.load_state_dict(state_dict) + + return safety_checker + + +def create_unet_diffusers_config(original_unet_config, class_embed_type=None): + attention_resolutions = parse_list(original_unet_config.attention_resolutions) + attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions] + + channel_mult = 
parse_list(original_unet_config.channel_mult) + block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult] + + down_block_types = [] + resolution = 1 + + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnDownBlock2D" + elif original_unet_config.resblock_updown: + block_type = "ResnetDownsampleBlock2D" + else: + block_type = "DownBlock2D" + + down_block_types.append(block_type) + + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnUpBlock2D" + elif original_unet_config.resblock_updown: + block_type = "ResnetUpsampleBlock2D" + else: + block_type = "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + head_dim = original_unet_config.num_head_channels + + use_linear_projection = ( + original_unet_config.use_linear_in_transformer + if "use_linear_in_transformer" in original_unet_config + else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + projection_class_embeddings_input_dim = None + + if class_embed_type is None: + if "num_classes" in original_unet_config: + if original_unet_config.num_classes == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in original_unet_config + projection_class_embeddings_input_dim = original_unet_config.adm_in_channels + else: + raise NotImplementedError( + f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}" + ) + + config = { + "sample_size": original_unet_config.image_size, + "in_channels": original_unet_config.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": original_unet_config.num_res_blocks, + "cross_attention_dim": original_unet_config.encoder_channels, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "out_channels": original_unet_config.out_channels, + "up_block_types": tuple(up_block_types), + "upcast_attention": False, # TODO: guessing + "cross_attention_norm": "group_norm", + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "addition_embed_type": "text", + "act_fn": "gelu", + } + + if original_unet_config.use_scale_shift_norm: + config["resnet_time_scale_shift"] = "scale_shift" + + if "encoder_dim" in original_unet_config: + config["encoder_hid_dim"] = original_unet_config.encoder_dim + + return config + + +def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] in [None, "identity"]: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + + # TODO need better check than i in [4, 8, 12, 16] + block_type = config["down_block_types"][block_id] + if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [ + 4, + 8, + 12, + 16, + ]: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"} + else: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + old_path = f"input_blocks.{i}.1" + new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(attentions) + meta_path = {"old": old_path, "new": new_path} + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + old_path = "middle_block.1" + new_path = "mid_block.attentions.0" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + # len(output_block_list) == 1 -> resnet + # len(output_block_list) == 2 -> resnet, attention + # len(output_block_list) == 3 -> resnet, attention, upscale resnet + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in 
output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + old_path = f"output_blocks.{i}.1" + new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(attentions) + meta_path = { + "old": old_path, + "new": new_path, + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(output_block_list) == 3: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key] + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if "encoder_proj.weight" in unet_state_dict: + new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight") + new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias") + + if "encoder_pooling.0.weight" in unet_state_dict: + new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight") + new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias") + + new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop( + "encoder_pooling.1.positional_embedding" + ) + new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight") + new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias") + new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight") + new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias") + new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight") + new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias") + + new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight") + new_checkpoint["add_embedding.proj.bias"] = 
unet_state_dict.pop("encoder_pooling.2.bias") + + new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight") + new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias") + + return new_checkpoint + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + if "qkv" in new_item: + continue + + if "encoder_kv" in new_item: + continue + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = new_item.replace("norm_encoder.weight", "norm_cross.weight") + new_item = new_item.replace("norm_encoder.bias", "norm_cross.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config): + qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight") + qkv_weight = qkv_weight[:, :, 0] + + qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias") + + is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"] + + split = 1 if is_cross_attn_only else 3 + + weights, bias = split_attentions( + weight=qkv_weight, + bias=qkv_bias, + split=split, + chunk_size=config["attention_head_dim"], + ) + + if is_cross_attn_only: + query_weight, q_bias = weights, bias + new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0] + new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0] + else: + [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias + new_checkpoint[f"{new_path}.to_q.weight"] = query_weight + new_checkpoint[f"{new_path}.to_q.bias"] = q_bias + new_checkpoint[f"{new_path}.to_k.weight"] = key_weight + new_checkpoint[f"{new_path}.to_k.bias"] = k_bias + new_checkpoint[f"{new_path}.to_v.weight"] = value_weight + new_checkpoint[f"{new_path}.to_v.bias"] = v_bias + + encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight") + encoder_kv_weight = encoder_kv_weight[:, :, 0] + + encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias") + + [encoder_k_weight, encoder_v_weight], 
[encoder_k_bias, encoder_v_bias] = split_attentions( + weight=encoder_kv_weight, + bias=encoder_kv_bias, + split=2, + chunk_size=config["attention_head_dim"], + ) + + new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight + new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias + new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight + new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias + + +def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + for path in paths: + new_path = path["new"] + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +def parse_list(value): + if isinstance(value, str): + value = value.split(",") + value = [int(v) for v in value] + elif isinstance(value, list): + pass + else: + raise ValueError(f"Can't parse list for type: {type(value)}") + + return value + + +# below is copy and pasted from original convert_if_stage_2.py script + + +def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None): + orig_path = unet_checkpoint_path + + original_unet_config = OmegaConf.load(os.path.join(orig_path, "config.yml")) + original_unet_config = original_unet_config.params + + unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config) + unet_diffusers_config["time_embedding_dim"] = original_unet_config.model_channels * int( + original_unet_config.channel_mult.split(",")[-1] + ) + if original_unet_config.encoder_dim != original_unet_config.encoder_channels: + unet_diffusers_config["encoder_hid_dim"] = original_unet_config.encoder_dim + 
unet_diffusers_config["class_embed_type"] = "timestep" + unet_diffusers_config["addition_embed_type"] = "text" + + unet_diffusers_config["time_embedding_act_fn"] = "gelu" + unet_diffusers_config["resnet_skip_time_act"] = True + unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071 + unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071 + unet_diffusers_config["only_cross_attention"] = ( + bool(original_unet_config.disable_self_attentions) + if ( + "disable_self_attentions" in original_unet_config + and isinstance(original_unet_config.disable_self_attentions, int) + ) + else True + ) + + if sample_size is None: + unet_diffusers_config["sample_size"] = original_unet_config.image_size + else: + # The second upscaler unet's sample size is incorrectly specified + # in the config and is instead hardcoded in source + unet_diffusers_config["sample_size"] = sample_size + + unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu") + + if verify_param_count: + # check that architecture matches - is a bit slow + verify_param_count(orig_path, unet_diffusers_config) + + converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint( + unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path + ) + converted_keys = converted_unet_checkpoint.keys() + + model = UNet2DConditionModel(**unet_diffusers_config) + expected_weights = model.state_dict().keys() + + diff_c_e = set(converted_keys) - set(expected_weights) + diff_e_c = set(expected_weights) - set(converted_keys) + + assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}" + assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}" + + model.load_state_dict(converted_unet_checkpoint) + + return model + + +def superres_create_unet_diffusers_config(original_unet_config): + attention_resolutions = parse_list(original_unet_config.attention_resolutions) + attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions] + + channel_mult = parse_list(original_unet_config.channel_mult) + block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult] + + down_block_types = [] + resolution = 1 + + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnDownBlock2D" + elif original_unet_config.resblock_updown: + block_type = "ResnetDownsampleBlock2D" + else: + block_type = "DownBlock2D" + + down_block_types.append(block_type) + + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + if resolution in attention_resolutions: + block_type = "SimpleCrossAttnUpBlock2D" + elif original_unet_config.resblock_updown: + block_type = "ResnetUpsampleBlock2D" + else: + block_type = "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + head_dim = original_unet_config.num_head_channels + use_linear_projection = ( + original_unet_config.use_linear_in_transformer + if "use_linear_in_transformer" in original_unet_config + else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + class_embed_type = None + projection_class_embeddings_input_dim = None + + if "num_classes" in original_unet_config: + if original_unet_config.num_classes == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in original_unet_config + projection_class_embeddings_input_dim = 
original_unet_config.adm_in_channels + else: + raise NotImplementedError( + f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}" + ) + + config = { + "in_channels": original_unet_config.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": tuple(original_unet_config.num_res_blocks), + "cross_attention_dim": original_unet_config.encoder_channels, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "out_channels": original_unet_config.out_channels, + "up_block_types": tuple(up_block_types), + "upcast_attention": False, # TODO: guessing + "cross_attention_norm": "group_norm", + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "act_fn": "gelu", + } + + if original_unet_config.use_scale_shift_norm: + config["resnet_time_scale_shift"] = "scale_shift" + + return config + + +def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if "encoder_proj.weight" in unet_state_dict: + new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"] + new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"] + + if "encoder_pooling.0.weight" in unet_state_dict: + mapping = { + "encoder_pooling.0": "add_embedding.norm1", + "encoder_pooling.1": "add_embedding.pool", + "encoder_pooling.2": "add_embedding.proj", + "encoder_pooling.3": "add_embedding.norm2", + } + for key in unet_state_dict.keys(): + if key.startswith("encoder_pooling"): + prefix = key[: len("encoder_pooling.0")] + new_key = key.replace(prefix, mapping[prefix]) + new_checkpoint[new_key] = unet_state_dict[key] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = 
{ + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] + for layer_id in range(num_output_blocks) + } + if not isinstance(config["layers_per_block"], int): + layers_per_block_list = [e + 1 for e in config["layers_per_block"]] + layers_per_block_cumsum = list(np.cumsum(layers_per_block_list)) + downsampler_ids = layers_per_block_cumsum + else: + # TODO need better check than i in [4, 8, 12, 16] + downsampler_ids = [4, 8, 12, 16] + + for i in range(1, num_input_blocks): + if isinstance(config["layers_per_block"], int): + layers_per_block = config["layers_per_block"] + block_id = (i - 1) // (layers_per_block + 1) + layer_in_block_id = (i - 1) % (layers_per_block + 1) + else: + block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n) + passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0 + layer_in_block_id = (i - 1) - passed_blocks + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + + block_type = config["down_block_types"][block_id] + if ( + block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D" + ) and i in downsampler_ids: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"} + else: + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + old_path = f"input_blocks.{i}.1" + new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(attentions) + meta_path = {"old": old_path, "new": new_path} + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + old_path = "middle_block.1" + 
new_path = "mid_block.attentions.0" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + if not isinstance(config["layers_per_block"], int): + layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]])) + layers_per_block_cumsum = list(np.cumsum(layers_per_block_list)) + + for i in range(num_output_blocks): + if isinstance(config["layers_per_block"], int): + layers_per_block = config["layers_per_block"] + block_id = i // (layers_per_block + 1) + layer_in_block_id = i % (layers_per_block + 1) + else: + block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n) + passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0 + layer_in_block_id = i - passed_blocks + + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + # len(output_block_list) == 1 -> resnet + # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet + # len(output_block_list) == 3 -> resnet, attention, upscale resnet + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + + has_attention = True + if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]): + has_attention = False + + maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # this layer was no attention + has_attention = False + maybe_attentions = [] + + if has_attention: + old_path = f"output_blocks.{i}.1" + new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}" + + assign_attention_to_checkpoint( + new_checkpoint=new_checkpoint, + unet_state_dict=unet_state_dict, + old_path=old_path, + new_path=new_path, + config=config, + ) + + paths = renew_attention_paths(maybe_attentions) + meta_path = { + "old": old_path, + "new": new_path, + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0): + layer_id = len(output_block_list) - 1 + resnets = [key for key in output_blocks[i] if 
f"output_blocks.{i}.{layer_id}" in key] + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def verify_param_count(orig_path, unet_diffusers_config): + if "-II-" in orig_path: + from deepfloyd_if.modules import IFStageII + + if_II = IFStageII(device="cpu", dir_or_name=orig_path) + elif "-III-" in orig_path: + from deepfloyd_if.modules import IFStageIII + + if_II = IFStageIII(device="cpu", dir_or_name=orig_path) + else: + assert f"Weird name. Should have -II- or -III- in path: {orig_path}" + + unet = UNet2DConditionModel(**unet_diffusers_config) + + # in params + assert_param_count(unet.time_embedding, if_II.model.time_embed) + assert_param_count(unet.conv_in, if_II.model.input_blocks[:1]) + + # downblocks + assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4]) + assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7]) + assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11]) + + if "-II-" in orig_path: + assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17]) + assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:]) + if "-III-" in orig_path: + assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15]) + assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20]) + assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:]) + + # mid block + assert_param_count(unet.mid_block, if_II.model.middle_block) + + # up block + if "-II-" in orig_path: + assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6]) + assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12]) + assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16]) + assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19]) + assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:]) + if "-III-" in orig_path: + assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5]) + assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10]) + assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14]) + assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18]) + assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21]) + assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24]) + + # out params + assert_param_count(unet.conv_norm_out, if_II.model.out[0]) + assert_param_count(unet.conv_out, if_II.model.out[2]) + + # make sure all model architecture has same param count + assert_param_count(unet, if_II.model) + + +def assert_param_count(model_1, model_2): + count_1 = sum(p.numel() for p in model_1.parameters()) + count_2 = sum(p.numel() for p in model_2.parameters()) + assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}" + + +def superres_check_against_original(dump_path, unet_checkpoint_path): + model_path = dump_path + model = UNet2DConditionModel.from_pretrained(model_path) + 
model.to("cuda") + orig_path = unet_checkpoint_path + + if "-II-" in orig_path: + from deepfloyd_if.modules import IFStageII + + if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model + elif "-III-" in orig_path: + from deepfloyd_if.modules import IFStageIII + + if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model + + batch_size = 1 + channels = model.in_channels // 2 + height = model.sample_size + width = model.sample_size + height = 1024 + width = 1024 + + torch.manual_seed(0) + + latents = torch.randn((batch_size, channels, height, width), device=model.device) + image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype) + t = torch.tensor([5], device=model.device).to(model.dtype) + + seq_len = 64 + encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to( + model.dtype + ) + + fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype) + + with torch.no_grad(): + out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states) + + if_II_model.to("cpu") + del if_II_model + import gc + + torch.cuda.empty_cache() + gc.collect() + print(50 * "=") + + with torch.no_grad(): + noise_pred = model( + sample=latent_model_input, + encoder_hidden_states=encoder_hidden_states, + class_labels=fake_class_labels, + timestep=t, + ).sample + + print("Out shape", noise_pred.shape) + print("Diff", (out - noise_pred).abs().sum()) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 40029fcecfd1..e9d12bdb7cca 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -114,6 +114,12 @@ AltDiffusionPipeline, AudioLDMPipeline, CycleDiffusionPipeline, + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, LDMTextToImagePipeline, PaintByExamplePipeline, SemanticStableDiffusionPipeline, diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 772e119fbe97..af639de306ee 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -109,6 +109,7 @@ def register_to_config(self, **kwargs): # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, # or solve in a more general way. 
kwargs.pop("kwargs", None) + if not hasattr(self, "_internal_dict"): internal_dict = kwargs else: @@ -550,6 +551,9 @@ def to_json_saveable(value): return value config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} + # Don't save "_ignore_files" + config_dict.pop("_ignore_files", None) + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike]): diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d12e75344ba1..fa88bce305e6 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -377,3 +377,69 @@ def forward(self, timestep, class_labels, hidden_dtype=None): conditioning = timesteps_emb + class_labels # (N, D) return conditioning + + +class TextTimeEmbedding(nn.Module): + def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): + super().__init__() + self.norm1 = nn.LayerNorm(encoder_dim) + self.pool = AttentionPooling(num_heads, encoder_dim) + self.proj = nn.Linear(encoder_dim, time_embed_dim) + self.norm2 = nn.LayerNorm(time_embed_dim) + + def forward(self, hidden_states): + hidden_states = self.norm1(hidden_states) + hidden_states = self.pool(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) + + # (bs*n_heads, class_token_length, dim_per_head) + q = shape(self.q_proj(class_token)) + # (bs*n_heads, length+class_token_length, dim_per_head) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + # (bs*n_heads, class_token_length, length+class_token_length): + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # (bs*n_heads, dim_per_head, class_token_length) + a = torch.einsum("bts,bcs->bct", weight, v) + + # (bs, length+1, width) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] # cls_token diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5363e6330623..521e99fdd69c 100644 --- 
a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import inspect +import itertools import os from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union @@ -60,7 +61,8 @@ def get_parameter_device(parameter: torch.nn.Module): try: - return next(parameter.parameters()).device + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 @@ -75,7 +77,8 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def get_parameter_dtype(parameter: torch.nn.Module): try: - return next(parameter.parameters()).dtype + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).dtype except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b4997a257643..38e0fa3b5b2e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -23,7 +23,7 @@ from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -97,11 +97,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to None): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, default to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, default to `None`): + An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, default to `None`): Optional activation function to use on the time embeddings only one time before they as passed to the rest of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. 
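For orientation (this note is not part of the diff itself): setting `addition_embed_type="text"` makes the UNet instantiate the `TextTimeEmbedding` module added above, which attention-pools the encoder hidden states into one vector per sample and adds it to the timestep embedding. A minimal shape check with illustrative sizes, not the values of any released IF checkpoint:

```python
import torch

from diffusers.models.embeddings import TextTimeEmbedding

# Pool a (batch, seq_len, encoder_dim) tensor of text-encoder hidden states into a
# single (batch, time_embed_dim) vector that can be summed with the timestep embedding.
add_embedding = TextTimeEmbedding(encoder_dim=4096, time_embed_dim=1024, num_heads=64)
hidden_states = torch.randn(2, 77, 4096)  # stand-in for T5 hidden states
pooled = add_embedding(hidden_states)
print(pooled.shape)  # torch.Size([2, 1024])
```

In the converted IF checkpoints, this module receives the weights that the conversion script above maps from `encoder_pooling.*` to `add_embedding.*`.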
@@ -155,12 +160,14 @@ def __init__( dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, @@ -170,6 +177,7 @@ def __init__( class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, ): super().__init__() @@ -214,7 +222,7 @@ def __init__( # time if time_embedding_type == "fourier": - time_embed_dim = block_out_channels[0] * 2 + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( @@ -222,7 +230,7 @@ def __init__( ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": - time_embed_dim = block_out_channels[0] * 4 + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] @@ -273,6 +281,18 @@ def __init__( else: self.class_embedding = None + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.") + if time_embedding_act_fn is None: self.time_embed_act = None elif time_embedding_act_fn == "swish": @@ -684,6 +704,10 @@ def forward( else: emb = emb + class_emb + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + emb = emb + aug_emb + if self.time_embed_act is not None: emb = self.time_embed_act(emb) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 602cf028e2e9..10da653a1377 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -44,6 +44,14 @@ else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .audioldm import AudioLDMPipeline + from .deepfloyd_if import ( + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ) from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline diff --git a/src/diffusers/pipelines/deepfloyd_if/__init__.py b/src/diffusers/pipelines/deepfloyd_if/__init__.py new file mode 100644 index 000000000000..93414f20e733 --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/__init__.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL + +from ...utils import BaseOutput, OptionalDependencyNotAvailable, 
is_torch_available, is_transformers_available +from .timesteps import ( + fast27_timesteps, + smart27_timesteps, + smart50_timesteps, + smart100_timesteps, + smart185_timesteps, + super27_timesteps, + super40_timesteps, + super100_timesteps, +) + + +@dataclass +class IFPipelineOutput(BaseOutput): + """ + Args: + Output class for Stable Diffusion pipelines. + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content or a watermark. `None` if safety checking could not be performed. + watermark_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety + checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_detected: Optional[List[bool]] + watermark_detected: Optional[List[bool]] + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_if import IFPipeline + from .pipeline_if_img2img import IFImg2ImgPipeline + from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline + from .pipeline_if_inpainting import IFInpaintingPipeline + from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline + from .pipeline_if_superresolution import IFSuperResolutionPipeline + from .safety_checker import IFSafetyChecker + from .watermark import IFWatermarker diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py new file mode 100644 index 000000000000..a76e51a3ffe9 --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -0,0 +1,854 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . 
import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> safety_modules = { + ... "feature_extractor": pipe.feature_extractor, + ... "safety_checker": pipe.safety_checker, + ... "watermarker": pipe.watermarker, + ... } + >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 + ... ) + >>> super_res_2_pipe.enable_model_cpu_offload() + + >>> image = super_res_2_pipe( + ... prompt=prompt, + ... image=image, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. 
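            # (Added note: this hook is triggered manually in `__call__` right after the
            # prompt embeddings are computed, via `self.text_encoder_offload_hook.offload()`,
            # so the large T5 encoder is back on the CPU before the UNet denoising loop runs.)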
+ self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = 
re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... 
+ + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + self.unet.config.in_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. 
Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py new file mode 100644 index 000000000000..a31748450d4b --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -0,0 +1,979 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-IF-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... 
) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. 
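+        Compared to `enable_model_cpu_offload`, this saves more GPU memory but is slower, since submodules are
+        moved to the GPU one at a time as they are needed.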
+ """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 
3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... 
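+        # remove resolution-like tokens such as "1024x768" (the character class also matches "×" and the Cyrillic "х")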
+ + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + _, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+        image = image.repeat_interleave(num_images_per_prompt, dim=0)
+        image = self.scheduler.add_noise(image, noise, timestep)
+
+        return image
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: Union[
+            PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
+        ] = None,
+        strength: float = 0.7,
+        num_inference_steps: int = 80,
+        timesteps: List[int] = None,
+        guidance_scale: float = 10.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        clean_caption: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process.
+            strength (`float`, *optional*, defaults to 0.7):
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+                be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+            num_inference_steps (`int`, *optional*, defaults to 80):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+                timesteps are used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 10.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502.
Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. 
Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py new file mode 100644 index 000000000000..21e280654cf5 --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -0,0 +1,1097 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-IF-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... 
image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warn( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." 
+ ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. 
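+        # `remove_all_hooks` resets `final_offload_hook` (and the other offload hooks) back to `None`.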
+ self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = 
re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: 
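+                # an empty string is used as the unconditional prompt when no negative prompt is given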
+ uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] 
but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 255.0 for i in image] + + image = np.stack(image, axis=0) # to np + torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + 
_, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + image = self.scheduler.add_noise(image, noise, timestep) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor], + original_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.8, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.FloatTensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. 
Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. 
Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py new file mode 100644 index 000000000000..95eba1cc7d24 --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -0,0 +1,1098 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... 
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFInpaintingPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. 
+ """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+        """
+        if not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    @torch.no_grad()
+    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt,
+        do_classifier_free_guidance=True,
+        num_images_per_prompt=1,
+        device=None,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        clean_caption: bool = False,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device (`torch.device`, *optional*):
+                torch device to place the resulting embeddings on
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            clean_caption (`bool`, *optional*, defaults to `False`):
+                Whether or not to clean the prompt before tokenization. Requires `beautifulsoup4` and `ftfy` to be
+                installed.
+        """
+        if prompt is not None and negative_prompt is not None:
+            if type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+ ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] 
but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = 
re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + 
# the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + mask_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. 
Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. 
Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py new file mode 100644 index 000000000000..4eb0bf300fa5 --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -0,0 +1,1208 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... 
).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` + """ + + +class IFInpaintingSuperResolutionPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warn( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." 
+ ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. 
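+ # `hook` now points at the last model placed in the offload chain (the safety checker when one is loaded, otherwise the unet); `__call__` releases it explicitly via `final_offload_hook.offload()` once generation finishes.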
+ self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = 
re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: 
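+ # No negative prompt was given: fall back to empty-string unconditional prompts, one per prompt in the batch.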
+ uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] 
but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 255.0 for i in image] + + image = np.stack(image, axis=0) # to np + torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = 
torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.preprocess_mask_image + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + # the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor], + original_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + mask_image: Union[ + PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] + ] = None, + strength: float = 0.8, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 0, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.FloatTensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. 
of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 0): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. 
Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. prepare mask image + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + mask_image, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. 
Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py new file mode 100644 index 000000000000..bb1d4ee4ba66 --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -0,0 +1,947 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_accelerate_available, + is_accelerate_version, + is_bs4_available, + is_ftfy_available, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds + ... 
).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFSuperResolutionPipeline(DiffusionPipeline): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: Optional[CLIPImageProcessor] + safety_checker: Optional[IFSafetyChecker] + + watermarker: Optional[IFWatermarker] + + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: Optional[IFSafetyChecker], + feature_extractor: Optional[CLIPImageProcessor], + watermarker: Optional[IFWatermarker], + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet.config.in_channels != 6: + logger.warn( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. 
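+ Expect noticeably slower inference than with `enable_model_cpu_offload`, since weights are streamed to the GPU submodule by submodule.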
+ """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.text_encoder, + self.unet, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + + if self.text_encoder is not None: + _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) + + # Accelerate will move the next model to the device _before_ calling the offload hook of the + # previous model. This will cause both models to be present on the device at the same time. + # IF uses T5 for its text encoder which is really large. We can manually call the offload + # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to + # the GPU. + self.text_encoder_offload_hook = hook + + _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) + + # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet + self.unet_offload_hook = hook + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. 
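+ # The last hook in the chain (safety checker if present, otherwise the unet) is kept so the final model can be offloaded manually at the end of `__call__`.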
+ self.final_offload_hook = hook + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet, self.safety_checker]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = 
re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: 
+ uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})" + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] 
but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def preprocess_image(self, image, num_images_per_prompt, device): + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 255.0 for i in image] + + image = np.stack(image, axis=0) # to np + torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): + The image to be upscaled. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. 
Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + height = self.unet.config.sample_size + width = self.unet.config.sample_size + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare intermediate images + num_channels = self.unet.config.in_channels // 2 + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + num_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare upscaled image and noise level + image = self.preprocess_image(image, num_images_per_prompt, device) + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + image = intermediate_images + + if output_type == "pil": + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 11. Convert to PIL + image = self.numpy_to_pil(image) + + # 12. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. 
Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/src/diffusers/pipelines/deepfloyd_if/safety_checker.py b/src/diffusers/pipelines/deepfloyd_if/safety_checker.py new file mode 100644 index 000000000000..8ffeed580bbe --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/safety_checker.py @@ -0,0 +1,59 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class IFSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModelWithProjection(config.vision_config) + + self.p_head = nn.Linear(config.vision_config.projection_dim, 1) + self.w_head = nn.Linear(config.vision_config.projection_dim, 1) + + @torch.no_grad() + def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5): + image_embeds = self.vision_model(clip_input)[0] + + nsfw_detected = self.p_head(image_embeds) + nsfw_detected = nsfw_detected.flatten() + nsfw_detected = nsfw_detected > p_threshold + nsfw_detected = nsfw_detected.tolist() + + if any(nsfw_detected): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + for idx, nsfw_detected_ in enumerate(nsfw_detected): + if nsfw_detected_: + images[idx] = np.zeros(images[idx].shape) + + watermark_detected = self.w_head(image_embeds) + watermark_detected = watermark_detected.flatten() + watermark_detected = watermark_detected > w_threshold + watermark_detected = watermark_detected.tolist() + + if any(watermark_detected): + logger.warning( + "Potential watermarked content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." 
+ ) + + for idx, watermark_detected_ in enumerate(watermark_detected): + if watermark_detected_: + images[idx] = np.zeros(images[idx].shape) + + return images, nsfw_detected, watermark_detected diff --git a/src/diffusers/pipelines/deepfloyd_if/timesteps.py b/src/diffusers/pipelines/deepfloyd_if/timesteps.py new file mode 100644 index 000000000000..d44285c017bb --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/timesteps.py @@ -0,0 +1,579 @@ +fast27_timesteps = [ + 999, + 800, + 799, + 600, + 599, + 500, + 400, + 399, + 377, + 355, + 333, + 311, + 288, + 266, + 244, + 222, + 200, + 199, + 177, + 155, + 133, + 111, + 88, + 66, + 44, + 22, + 0, +] + +smart27_timesteps = [ + 999, + 976, + 952, + 928, + 905, + 882, + 858, + 857, + 810, + 762, + 715, + 714, + 572, + 429, + 428, + 286, + 285, + 238, + 190, + 143, + 142, + 118, + 95, + 71, + 47, + 24, + 0, +] + +smart50_timesteps = [ + 999, + 988, + 977, + 966, + 955, + 944, + 933, + 922, + 911, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 350, + 300, + 299, + 266, + 233, + 200, + 199, + 179, + 159, + 140, + 120, + 100, + 99, + 88, + 77, + 66, + 55, + 44, + 33, + 22, + 11, + 0, +] + +smart100_timesteps = [ + 999, + 995, + 992, + 989, + 985, + 981, + 978, + 975, + 971, + 967, + 964, + 961, + 957, + 956, + 951, + 947, + 942, + 937, + 933, + 928, + 923, + 919, + 914, + 913, + 908, + 903, + 897, + 892, + 887, + 881, + 876, + 871, + 870, + 864, + 858, + 852, + 846, + 840, + 834, + 828, + 827, + 820, + 813, + 806, + 799, + 792, + 785, + 784, + 777, + 770, + 763, + 756, + 749, + 742, + 741, + 733, + 724, + 716, + 707, + 699, + 698, + 688, + 677, + 666, + 656, + 655, + 645, + 634, + 623, + 613, + 612, + 598, + 584, + 570, + 569, + 555, + 541, + 527, + 526, + 505, + 484, + 483, + 462, + 440, + 439, + 396, + 395, + 352, + 351, + 308, + 307, + 264, + 263, + 220, + 219, + 176, + 132, + 88, + 44, + 0, +] + +smart185_timesteps = [ + 999, + 997, + 995, + 992, + 990, + 988, + 986, + 984, + 981, + 979, + 977, + 975, + 972, + 970, + 968, + 966, + 964, + 961, + 959, + 957, + 956, + 954, + 951, + 949, + 946, + 944, + 941, + 939, + 936, + 934, + 931, + 929, + 926, + 924, + 921, + 919, + 916, + 914, + 913, + 910, + 907, + 905, + 902, + 899, + 896, + 893, + 891, + 888, + 885, + 882, + 879, + 877, + 874, + 871, + 870, + 867, + 864, + 861, + 858, + 855, + 852, + 849, + 846, + 843, + 840, + 837, + 834, + 831, + 828, + 827, + 824, + 821, + 817, + 814, + 811, + 808, + 804, + 801, + 798, + 795, + 791, + 788, + 785, + 784, + 780, + 777, + 774, + 770, + 766, + 763, + 760, + 756, + 752, + 749, + 746, + 742, + 741, + 737, + 733, + 730, + 726, + 722, + 718, + 714, + 710, + 707, + 703, + 699, + 698, + 694, + 690, + 685, + 681, + 677, + 673, + 669, + 664, + 660, + 656, + 655, + 650, + 646, + 641, + 636, + 632, + 627, + 622, + 618, + 613, + 612, + 607, + 602, + 596, + 591, + 586, + 580, + 575, + 570, + 569, + 563, + 557, + 551, + 545, + 539, + 533, + 527, + 526, + 519, + 512, + 505, + 498, + 491, + 484, + 483, + 474, + 466, + 457, + 449, + 440, + 439, + 428, + 418, + 407, + 396, + 395, + 381, + 366, + 352, + 351, + 330, + 308, + 307, + 286, + 264, + 263, + 242, + 220, + 219, + 176, + 175, + 132, + 131, + 88, + 44, + 0, +] + +super27_timesteps = [ + 999, + 991, + 982, + 974, + 966, + 958, + 950, + 941, + 933, + 925, + 916, + 908, + 900, + 899, + 874, + 850, + 825, + 800, + 799, + 700, + 600, + 500, + 400, + 300, + 200, + 100, + 0, +] + +super40_timesteps = [ + 999, + 992, + 985, + 978, + 
971, + 964, + 957, + 949, + 942, + 935, + 928, + 921, + 914, + 907, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 300, + 299, + 200, + 199, + 100, + 99, + 0, +] + +super100_timesteps = [ + 999, + 996, + 992, + 989, + 985, + 982, + 979, + 975, + 972, + 968, + 965, + 961, + 958, + 955, + 951, + 948, + 944, + 941, + 938, + 934, + 931, + 927, + 924, + 920, + 917, + 914, + 910, + 907, + 903, + 900, + 899, + 891, + 884, + 876, + 869, + 861, + 853, + 846, + 838, + 830, + 823, + 815, + 808, + 800, + 799, + 788, + 777, + 766, + 755, + 744, + 733, + 722, + 711, + 700, + 699, + 688, + 677, + 666, + 655, + 644, + 633, + 622, + 611, + 600, + 599, + 585, + 571, + 557, + 542, + 528, + 514, + 500, + 499, + 485, + 471, + 457, + 442, + 428, + 414, + 400, + 399, + 379, + 359, + 340, + 320, + 300, + 299, + 279, + 259, + 240, + 220, + 200, + 199, + 166, + 133, + 100, + 99, + 66, + 33, + 0, +] diff --git a/src/diffusers/pipelines/deepfloyd_if/watermark.py b/src/diffusers/pipelines/deepfloyd_if/watermark.py new file mode 100644 index 000000000000..db33dec0ef9a --- /dev/null +++ b/src/diffusers/pipelines/deepfloyd_if/watermark.py @@ -0,0 +1,46 @@ +from typing import List + +import PIL +import torch +from PIL import Image + +from ...configuration_utils import ConfigMixin +from ...models.modeling_utils import ModelMixin +from ...utils import PIL_INTERPOLATION + + +class IFWatermarker(ModelMixin, ConfigMixin): + def __init__(self): + super().__init__() + + self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) + self.watermark_image_as_pil = None + + def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None): + # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 + + h = images[0].height + w = images[0].width + + sample_size = sample_size or h + + coef = min(h / sample_size, w / sample_size) + img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) + + S1, S2 = 1024**2, img_w * img_h + K = (S2 / S1) ** 0.5 + wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) + + if self.watermark_image_as_pil is None: + watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() + watermark_image = Image.fromarray(watermark_image, mode="RGBA") + self.watermark_image_as_pil = watermark_image + + wm_img = self.watermark_image_as_pil.resize( + (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None + ) + + for pil_img in images: + pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) + + return images diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2d61f1a3700f..8c028b64a8c8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -30,7 +30,6 @@ import torch from huggingface_hub import hf_hub_download, model_info, snapshot_download from packaging import version -from PIL import Image from tqdm.auto import tqdm import diffusers @@ -56,6 +55,7 @@ is_torch_version, is_transformers_available, logging, + numpy_to_pil, ) @@ -623,7 +623,9 @@ def module_is_sequentially_offloaded(module): if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): return False - return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + return hasattr(module, "_hf_hook") and not isinstance( + module._hf_hook, 
(accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook) + ) def module_is_offloaded(module): if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): @@ -653,7 +655,20 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: - module.to(torch_device, torch_dtype) + is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + + if is_loaded_in_8bit and torch_dtype is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." + ) + + if is_loaded_in_8bit and torch_device is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." + ) + else: + module.to(torch_device, torch_dtype) + if ( module.dtype == torch.float16 and str(torch_device) in ["cpu"] @@ -887,6 +902,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config_dict = cls.load_config(cached_folder) + # pop out "_ignore_files" as it is only needed for download + config_dict.pop("_ignore_files", None) + # 2. Define which model components should load variants # We retrieve the information by matching whether variant # model checkpoints exist in the subfolders @@ -1204,12 +1222,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) config_dict = cls._dict_from_json_file(config_file) + + ignore_filenames = config_dict.pop("_ignore_files", []) + # retrieve all folder_names that contain relevant files folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] filenames = {sibling.rfilename for sibling in info.siblings} model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + # remove ignored filenames + model_filenames = set(model_filenames) - set(ignore_filenames) + variant_filenames = set(variant_filenames) - set(ignore_filenames) + # if the whole pipeline is cached we don't have to ping the Hub if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version @@ -1370,16 +1395,7 @@ def numpy_to_pil(images): """ Convert a numpy image or a batch of images to a PIL image. """ - if images.ndim == 3: - images = images[None, ...] 
- images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images + return numpy_to_pil(images) def progress_bar(self, iterable=None, total=None): if not hasattr(self, "_progress_bar_config"): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 8db19c2b9109..56681391aeeb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -56,7 +56,18 @@ def __init__( scheduler: Any, max_noise_level: int = 350, ): - super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level) + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + watermarker=None, + max_noise_level=max_noise_level, + ) def __call__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 693208b18cdd..45b26de284af 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -13,18 +13,19 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import numpy as np import PIL import torch -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -76,6 +77,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" + _optional_components = ["watermarker", "safety_checker", "feature_extractor"] def __init__( self, @@ -85,12 +87,16 @@ def __init__( unet: UNet2DConditionModel, low_res_scheduler: DDPMScheduler, scheduler: KarrasDiffusionSchedulers, + safety_checker: Optional[Any] = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + watermarker: Optional[Any] = None, max_noise_level: int = 350, ): super().__init__() - if hasattr(vae, "config"): - # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + if hasattr( + vae, "config" + ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate is_vae_scaling_factor_set_to_0_08333 = ( hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 ) @@ -113,6 +119,9 @@ def __init__( unet=unet, low_res_scheduler=low_res_scheduler, scheduler=scheduler, + safety_checker=safety_checker, + watermarker=watermarker, + feature_extractor=feature_extractor, ) self.register_to_config(max_noise_level=max_noise_level) @@ -178,6 +187,23 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, @@ -678,10 +704,19 @@ def __call__( self.final_offload_hook.offload() # 11. Convert to PIL + # has_nsfw_concept = False if output_type == "pil": + image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.numpy_to_pil(image) + # 11. 
Apply watermark + if self.watermarker is not None: + image = self.watermarker.apply_watermark(image) + else: + has_nsfw_concept = None + if not return_dict: - return (image,) + return (image, has_nsfw_concept) - return ImagePipelineOutput(images=image) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 2a7b80d01da7..57e1abc7315b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -15,7 +15,7 @@ AttnProcessor, ) from ...models.dual_transformer_2d import DualTransformer2DModel -from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput from ...utils import logging @@ -183,11 +183,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to None): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. num_class_embeds (`int`, *optional*, defaults to None): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, default to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, default to `None`): + An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, default to `None`): Optional activation function to use on the time embeddings only one time before they as passed to the rest of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. 
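The new `addition_embed_type="text"` option documented above adds an attention-pooled projection of the text encoder hidden states on top of the timestep embedding in `forward`. A minimal sketch of that interaction, with small illustrative dimensions (the `TextTimeEmbedding` call signature follows the code added in this diff; the sizes and tensors are stand-ins):

```python
import torch

from diffusers.models.embeddings import TextTimeEmbedding

# illustrative sizes only (the tests in this PR use a 32-dim encoder and 2 heads)
encoder_dim, time_embed_dim = 32, 128

add_embedding = TextTimeEmbedding(encoder_dim, time_embed_dim, num_heads=2)

encoder_hidden_states = torch.randn(2, 77, encoder_dim)  # (batch, seq_len, encoder_dim) text encoder output
emb = torch.randn(2, time_embed_dim)  # stand-in for the projected timestep embedding

aug_emb = add_embedding(encoder_hidden_states)  # pooled text embedding, shape (batch, time_embed_dim)
emb = emb + aug_emb  # summed with the timestep embedding, mirroring the `forward` change below
```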
@@ -246,12 +251,14 @@ def __init__( dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, @@ -261,6 +268,7 @@ def __init__( class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, ): super().__init__() @@ -311,7 +319,7 @@ def __init__( # time if time_embedding_type == "fourier": - time_embed_dim = block_out_channels[0] * 2 + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( @@ -319,7 +327,7 @@ def __init__( ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": - time_embed_dim = block_out_channels[0] * 4 + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] @@ -370,6 +378,18 @@ def __init__( else: self.class_embedding = None + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.") + if time_embedding_act_fn is None: self.time_embed_act = None elif time_embedding_act_fn == "swish": @@ -781,6 +801,10 @@ def forward( else: emb = emb + class_emb + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + emb = emb + aug_emb + if self.time_embed_act is not None: emb = self.time_embed_act(emb) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c717d722f84c..1b8eca050c9e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -44,6 +44,7 @@ http_user_agent, ) from .import_utils import ( + BACKENDS_MAPPING, ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, USE_JAX, @@ -53,7 +54,9 @@ OptionalDependencyNotAvailable, is_accelerate_available, is_accelerate_version, + is_bs4_available, is_flax_available, + is_ftfy_available, is_inflect_available, is_k_diffusion_available, is_k_diffusion_version, @@ -76,7 +79,7 @@ ) from .logging import get_logger from .outputs import BaseOutput -from .pil_utils import PIL_INTERPOLATION +from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil from .torch_utils import is_compiled_module, randn_tensor diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index bda56d2ae8ae..bf4fe8d87ff9 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -62,6 +62,96 @@ def 
from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class IFImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFInpaintingPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class IFSuperResolutionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index fd7538b1b5e9..2d90cb9747a7 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -271,6 +271,23 @@ _compel_available = False +_ftfy_available = importlib.util.find_spec("ftfy") is not None +try: + _ftfy_version = importlib_metadata.version("ftfy") + logger.debug(f"Successfully imported ftfy version {_ftfy_version}") +except importlib_metadata.PackageNotFoundError: + _ftfy_available = False + + +_bs4_available = importlib.util.find_spec("bs4") is not None +try: + # importlib metadata under different name + _bs4_version = importlib_metadata.version("beautifulsoup4") + logger.debug(f"Successfully imported ftfy version {_bs4_version}") +except importlib_metadata.PackageNotFoundError: + _bs4_available = False + + def 
is_torch_available(): return _torch_available @@ -347,6 +364,14 @@ def is_compel_available(): return _compel_available +def is_ftfy_available(): + return _ftfy_available + + +def is_bs4_available(): + return _bs4_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -437,8 +462,23 @@ def is_compel_available(): {0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel` """ +# docstyle-ignore +BS4_IMPORT_ERROR = """ +{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: +`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +FTFY_IMPORT_ERROR = """ +{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the +installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + + BACKENDS_MAPPING = OrderedDict( [ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), @@ -454,6 +494,7 @@ def is_compel_available(): ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ("compel", (_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/pil_utils.py b/src/diffusers/utils/pil_utils.py index 39d0a15a4e2f..ad76a32230fb 100644 --- a/src/diffusers/utils/pil_utils.py +++ b/src/diffusers/utils/pil_utils.py @@ -1,6 +1,7 @@ import PIL.Image import PIL.ImageOps from packaging import version +from PIL import Image if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): @@ -19,3 +20,26 @@ "lanczos": PIL.Image.LANCZOS, "nearest": PIL.Image.NEAREST, } + + +def pt_to_pil(images): + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = numpy_to_pil(images) + return images + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
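+    # scale [0, 1] floats to uint8 before constructing PIL images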
+ images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images diff --git a/tests/pipelines/deepfloyd_if/__init__.py b/tests/pipelines/deepfloyd_if/__init__.py new file mode 100644 index 000000000000..094254a61875 --- /dev/null +++ b/tests/pipelines/deepfloyd_if/__init__.py @@ -0,0 +1,272 @@ +import tempfile + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import DDPMScheduler, UNet2DConditionModel +from diffusers.models.attention_processor import AttnAddedKVProcessor +from diffusers.pipelines.deepfloyd_if import IFWatermarker +from diffusers.utils.testing_utils import torch_device + +from ..test_pipelines_common import to_np + + +# WARN: the hf-internal-testing/tiny-random-t5 text encoder has some non-determinism in the `save_load` tests. + + +class IFPipelineTesterMixin: + def _get_dummy_components(self): + torch.manual_seed(0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + unet = UNet2DConditionModel( + sample_size=32, + layers_per_block=1, + block_out_channels=[32, 64], + down_block_types=[ + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + ], + mid_block_type="UNetMidBlock2DSimpleCrossAttn", + up_block_types=["SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"], + in_channels=3, + out_channels=6, + cross_attention_dim=32, + encoder_hid_dim=32, + attention_head_dim=8, + addition_embed_type="text", + addition_embed_type_num_heads=2, + cross_attention_norm="group_norm", + resnet_time_scale_shift="scale_shift", + act_fn="gelu", + ) + unet.set_attn_processor(AttnAddedKVProcessor()) # For reproducibility tests + + torch.manual_seed(0) + scheduler = DDPMScheduler( + num_train_timesteps=1000, + beta_schedule="squaredcos_cap_v2", + beta_start=0.0001, + beta_end=0.02, + thresholding=True, + dynamic_thresholding_ratio=0.95, + sample_max_value=1.0, + prediction_type="epsilon", + variance_type="learned_range", + ) + + torch.manual_seed(0) + watermarker = IFWatermarker() + + return { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "unet": unet, + "scheduler": scheduler, + "watermarker": watermarker, + "safety_checker": None, + "feature_extractor": None, + } + + def _get_superresolution_dummy_components(self): + torch.manual_seed(0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + unet = UNet2DConditionModel( + sample_size=32, + layers_per_block=[1, 2], + block_out_channels=[32, 64], + down_block_types=[ + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + ], + mid_block_type="UNetMidBlock2DSimpleCrossAttn", + up_block_types=["SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"], + in_channels=6, + out_channels=6, + cross_attention_dim=32, + encoder_hid_dim=32, + attention_head_dim=8, + addition_embed_type="text", + addition_embed_type_num_heads=2, + cross_attention_norm="group_norm", + resnet_time_scale_shift="scale_shift", + act_fn="gelu", + class_embed_type="timestep", + mid_block_scale_factor=1.414, + time_embedding_act_fn="gelu", 
+ time_embedding_dim=32,
+ )
+ unet.set_attn_processor(AttnAddedKVProcessor()) # For reproducibility tests
+
+ torch.manual_seed(0)
+ scheduler = DDPMScheduler(
+ num_train_timesteps=1000,
+ beta_schedule="squaredcos_cap_v2",
+ beta_start=0.0001,
+ beta_end=0.02,
+ thresholding=True,
+ dynamic_thresholding_ratio=0.95,
+ sample_max_value=1.0,
+ prediction_type="epsilon",
+ variance_type="learned_range",
+ )
+
+ torch.manual_seed(0)
+ image_noising_scheduler = DDPMScheduler(
+ num_train_timesteps=1000,
+ beta_schedule="squaredcos_cap_v2",
+ beta_start=0.0001,
+ beta_end=0.02,
+ )
+
+ torch.manual_seed(0)
+ watermarker = IFWatermarker()
+
+ return {
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "unet": unet,
+ "scheduler": scheduler,
+ "image_noising_scheduler": image_noising_scheduler,
+ "watermarker": watermarker,
+ "safety_checker": None,
+ "feature_extractor": None,
+ }
+
+ # This test is modified from the base class because IF pipelines set the text encoder
+ # as optional, with the intention that the user encodes the prompt once and then passes
+ # the embeddings directly to the pipeline. The base class test uses the unmodified
+ # arguments from `self.get_dummy_inputs`, which would pass the unencoded prompt to the
+ # pipeline when the text encoder is set to None, throwing an error. So we make the test
+ # reflect the intended usage of setting the text encoder to None.
+ def _test_save_load_optional_components(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ prompt = inputs["prompt"]
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ if "image" in inputs:
+ image = inputs["image"]
+ else:
+ image = None
+
+ if "mask_image" in inputs:
+ mask_image = inputs["mask_image"]
+ else:
+ mask_image = None
+
+ if "original_image" in inputs:
+ original_image = inputs["original_image"]
+ else:
+ original_image = None
+
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ }
+
+ if image is not None:
+ inputs["image"] = image
+
+ if mask_image is not None:
+ inputs["mask_image"] = mask_image
+
+ if original_image is not None:
+ inputs["original_image"] = original_image
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ pipe_loaded.unet.set_attn_processor(AttnAddedKVProcessor()) # For reproducibility tests
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ # inputs with prompt converted to embeddings
+ inputs =
{ + "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + "generator": generator, + "num_inference_steps": num_inference_steps, + "output_type": output_type, + } + + if image is not None: + inputs["image"] = image + + if mask_image is not None: + inputs["mask_image"] = mask_image + + if original_image is not None: + inputs["original_image"] = original_image + + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1e-4) + + # Modified from `PipelineTesterMixin` to set the attn processor as it's not serialized. + # This should be handled in the base test and then this method can be removed. + def _test_save_load_local(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + pipe_loaded.unet.set_attn_processor(AttnAddedKVProcessor()) # For reproducibility tests + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, 1e-4) diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py new file mode 100644 index 000000000000..e2204cb601a6 --- /dev/null +++ b/tests/pipelines/deepfloyd_if/test_if.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import torch + +from diffusers import ( + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, +) +from diffusers.models.attention_processor import AttnAddedKVProcessor +from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +from . 
import IFPipelineTesterMixin + + +@skip_mps +class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase): + pipeline_class = IFPipeline + params = TEXT_TO_IMAGE_PARAMS - {"width", "height", "latents"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + + test_xformers_attention = False + + def get_dummy_components(self): + return self._get_dummy_components() + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + + return inputs + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self): + # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder + self._test_save_load_float16(expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) + + def test_save_load_local(self): + self._test_save_load_local() + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-2, + ) + + +@slow +@require_torch_gpu +class IFPipelineSlowTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_all(self): + # if + + pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-IF-v1.0", variant="fp16", torch_dtype=torch.float16) + + pipe_2 = IFSuperResolutionPipeline.from_pretrained( + "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16, text_encoder=None, tokenizer=None + ) + + # pre compute text embeddings and remove T5 to save memory + + pipe_1.text_encoder.to("cuda") + + prompt_embeds, negative_prompt_embeds = pipe_1.encode_prompt("anime turtle", device="cuda") + + del pipe_1.tokenizer + del pipe_1.text_encoder + gc.collect() + + pipe_1.tokenizer = None + pipe_1.text_encoder = None + + pipe_1.enable_model_cpu_offload() + pipe_2.enable_model_cpu_offload() + + pipe_1.unet.set_attn_processor(AttnAddedKVProcessor()) + pipe_2.unet.set_attn_processor(AttnAddedKVProcessor()) + + self._test_if(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds) + + pipe_1.remove_all_hooks() + pipe_2.remove_all_hooks() + + # img2img + + pipe_1 = IFImg2ImgPipeline(**pipe_1.components) + pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components) + + pipe_1.enable_model_cpu_offload() + pipe_2.enable_model_cpu_offload() + + pipe_1.unet.set_attn_processor(AttnAddedKVProcessor()) + pipe_2.unet.set_attn_processor(AttnAddedKVProcessor()) + + self._test_if_img2img(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds) + + pipe_1.remove_all_hooks() + pipe_2.remove_all_hooks() + + # inpainting + + pipe_1 = IFInpaintingPipeline(**pipe_1.components) + pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components) + + pipe_1.enable_model_cpu_offload() + pipe_2.enable_model_cpu_offload() + + pipe_1.unet.set_attn_processor(AttnAddedKVProcessor()) + pipe_2.unet.set_attn_processor(AttnAddedKVProcessor()) + + self._test_if_inpainting(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds) + + def _test_if(self, 
pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds): + # pipeline 1 + + _start_torch_memory_measurement() + + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe_1( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + num_inference_steps=2, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (64, 64, 3) + + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes < 13 * 10**9 + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if.npy" + ) + assert_mean_pixel_difference(image, expected_image) + + # pipeline 2 + + _start_torch_memory_measurement() + + generator = torch.Generator(device="cpu").manual_seed(0) + + image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device) + + output = pipe_2( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image=image, + generator=generator, + num_inference_steps=2, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (256, 256, 3) + + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes < 4 * 10**9 + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_superresolution_stage_II.npy" + ) + assert_mean_pixel_difference(image, expected_image) + + def _test_if_img2img(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds): + # pipeline 1 + + _start_torch_memory_measurement() + + image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device) + + generator = torch.Generator(device="cpu").manual_seed(0) + + output = pipe_1( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image=image, + num_inference_steps=2, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (64, 64, 3) + + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes < 10 * 10**9 + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img.npy" + ) + assert_mean_pixel_difference(image, expected_image) + + # pipeline 2 + + _start_torch_memory_measurement() + + generator = torch.Generator(device="cpu").manual_seed(0) + + original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device) + image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device) + + output = pipe_2( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image=image, + original_image=original_image, + generator=generator, + num_inference_steps=2, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (256, 256, 3) + + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes < 4 * 10**9 + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img_superresolution_stage_II.npy" + ) + assert_mean_pixel_difference(image, expected_image) + + def _test_if_inpainting(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds): + # pipeline 1 + + _start_torch_memory_measurement() + + image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device) + mask_image = floats_tensor((1, 3, 64, 64), rng=random.Random(1)).to(torch_device) + + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe_1( + prompt_embeds=prompt_embeds, 
+ negative_prompt_embeds=negative_prompt_embeds, + image=image, + mask_image=mask_image, + num_inference_steps=2, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (64, 64, 3) + + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes < 10 * 10**9 + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting.npy" + ) + assert_mean_pixel_difference(image, expected_image) + + # pipeline 2 + + _start_torch_memory_measurement() + + generator = torch.Generator(device="cpu").manual_seed(0) + + image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device) + original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device) + mask_image = floats_tensor((1, 3, 256, 256), rng=random.Random(1)).to(torch_device) + + output = pipe_2( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image=image, + mask_image=mask_image, + original_image=original_image, + generator=generator, + num_inference_steps=2, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (256, 256, 3) + + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes < 4 * 10**9 + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting_superresolution_stage_II.npy" + ) + assert_mean_pixel_difference(image, expected_image) + + +def _start_torch_memory_measurement(): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py new file mode 100644 index 000000000000..b4c99a8ab93a --- /dev/null +++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import torch + +from diffusers import IFImg2ImgPipeline +from diffusers.utils import floats_tensor +from diffusers.utils.testing_utils import skip_mps, torch_device + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin +from . 
import IFPipelineTesterMixin + + +@skip_mps +class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase): + pipeline_class = IFImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"width", "height"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + + test_xformers_attention = False + + def get_dummy_components(self): + return self._get_dummy_components() + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + + return inputs + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self): + # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder + self._test_save_load_float16(expected_max_diff=1e-1) + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_float16_inference(self): + self._test_float16_inference(expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) + + def test_save_load_local(self): + self._test_save_load_local() + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-2, + ) diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py new file mode 100644 index 000000000000..626ab321f895 --- /dev/null +++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import torch + +from diffusers import IFImg2ImgSuperResolutionPipeline +from diffusers.utils import floats_tensor +from diffusers.utils.testing_utils import skip_mps, torch_device + +from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin +from . 
import IFPipelineTesterMixin + + +@skip_mps +class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase): + pipeline_class = IFImg2ImgSuperResolutionPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"width", "height"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"}) + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + + test_xformers_attention = False + + def get_dummy_components(self): + return self._get_superresolution_dummy_components() + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + original_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = floats_tensor((1, 3, 16, 16), rng=random.Random(seed)).to(device) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "original_image": original_image, + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + + return inputs + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self): + # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder + self._test_save_load_float16(expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) + + def test_save_load_local(self): + self._test_save_load_local() + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-2, + ) diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py new file mode 100644 index 000000000000..37d818c7a910 --- /dev/null +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import torch + +from diffusers import IFInpaintingPipeline +from diffusers.utils import floats_tensor +from diffusers.utils.testing_utils import skip_mps, torch_device + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin +from . 
import IFPipelineTesterMixin + + +@skip_mps +class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase): + pipeline_class = IFInpaintingPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"} + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + + test_xformers_attention = False + + def get_dummy_components(self): + return self._get_dummy_components() + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + + return inputs + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self): + # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder + self._test_save_load_float16(expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) + + def test_save_load_local(self): + self._test_save_load_local() + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-2, + ) diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py new file mode 100644 index 000000000000..30062cb2f8d0 --- /dev/null +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import torch + +from diffusers import IFInpaintingSuperResolutionPipeline +from diffusers.utils import floats_tensor +from diffusers.utils.testing_utils import skip_mps, torch_device + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin +from . 
import IFPipelineTesterMixin + + +@skip_mps +class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase): + pipeline_class = IFInpaintingSuperResolutionPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"} + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"}) + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + + test_xformers_attention = False + + def get_dummy_components(self): + return self._get_superresolution_dummy_components() + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = floats_tensor((1, 3, 16, 16), rng=random.Random(seed)).to(device) + original_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "original_image": original_image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + + return inputs + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self): + # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder + self._test_save_load_float16(expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) + + def test_save_load_local(self): + self._test_save_load_local() + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-2, + ) diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py new file mode 100644 index 000000000000..14acfa5415c2 --- /dev/null +++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py @@ -0,0 +1,77 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import torch + +from diffusers import IFSuperResolutionPipeline +from diffusers.utils import floats_tensor +from diffusers.utils.testing_utils import skip_mps, torch_device + +from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin +from . 
import IFPipelineTesterMixin + + +@skip_mps +class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase): + pipeline_class = IFSuperResolutionPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"width", "height"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} + + test_xformers_attention = False + + def get_dummy_components(self): + return self._get_superresolution_dummy_components() + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + + return inputs + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") + def test_save_load_float16(self): + # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder + self._test_save_load_float16(expected_max_diff=1e-1) + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) + + def test_save_load_local(self): + self._test_save_load_local() + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical( + expected_max_diff=1e-2, + ) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 8fb79f0c4057..168ff8106c52 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -575,6 +575,19 @@ def test_text_inversion_download(self): out = pipe(prompt, num_inference_steps=1, output_type="numpy").images assert out.shape == (1, 128, 128, 3) + def test_download_ignore_files(self): + # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4 + with tempfile.TemporaryDirectory() as tmpdirname: + # pipeline has Flax weights + tmpdirname = DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files") + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a pytorch file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack + assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files) + assert len(files) == 14 + class CustomPipelineTests(unittest.TestCase): def test_load_custom_pipeline(self): diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d0712bdec8f6..0278092282ba 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -339,6 +339,9 @@ def test_components_function(self): @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_float16_inference(self): + self._test_float16_inference() + + def _test_float16_inference(self, expected_max_diff=1e-2): components = self.get_dummy_components() pipe = self.pipeline_class(**components) 
pipe.to(torch_device) @@ -352,10 +355,13 @@ def test_float16_inference(self): output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0] max_diff = np.abs(to_np(output) - to_np(output_fp16)).max() - self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.") + self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.") @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): + self._test_save_load_float16() + + def _test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): if hasattr(module, "half"): @@ -384,7 +390,9 @@ def test_save_load_float16(self): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.") + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." + ) def test_save_load_optional_components(self): if not hasattr(self.pipeline_class, "_optional_components"):