diff --git a/docs/source/api/pipelines/latent_diffusion.mdx b/docs/source/api/pipelines/latent_diffusion.mdx
index 6d63cd5cbe5b..4ade13e67b15 100644
--- a/docs/source/api/pipelines/latent_diffusion.mdx
+++ b/docs/source/api/pipelines/latent_diffusion.mdx
@@ -33,6 +33,7 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
 | Pipeline | Tasks | Colab
 |---|---|:---:|
 | [pipeline_latent_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) | *Text-to-Image Generation* | - |
+| [pipeline_latent_diffusion_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py) | *Super Resolution* | - |
 
 ## Examples:
 
@@ -40,3 +41,7 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
 ## LDMTextToImagePipeline
 [[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
 	- __call__
+
+## LDMSuperResolutionPipeline
+[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
+	- __call__
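For orientation (not part of the diff itself), here is a minimal usage sketch of the new pipeline. The checkpoint id and test image are the ones used by the integration test at the bottom of this PR; any LDM super-resolution checkpoint should work the same way:

```python
# Usage sketch; checkpoint and image taken from the integration test below.
import torch

from diffusers import LDMSuperResolutionPipeline
from diffusers.utils import load_image

device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
pipeline = pipeline.to(device)

low_res = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/vq_diffusion/teddy_bear_pool.png"
).resize((64, 64))

# 100 denoising steps is the pipeline default; per the integration test,
# this checkpoint upscales 64x64 -> 256x256.
upscaled = pipeline(low_res, num_inference_steps=100).images[0]
upscaled.save("upscaled_teddy_bear.png")
```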
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index da56dc888138..86eda7371fe9 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -35,6 +35,7 @@
         DDPMPipeline,
         KarrasVePipeline,
         LDMPipeline,
+        LDMSuperResolutionPipeline,
         PNDMPipeline,
         RePaintPipeline,
         ScoreSdeVePipeline,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index eb0635f6ee07..ef4d23e5e6d0 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -5,6 +5,7 @@
     from .dance_diffusion import DanceDiffusionPipeline
     from .ddim import DDIMPipeline
     from .ddpm import DDPMPipeline
+    from .latent_diffusion import LDMSuperResolutionPipeline
     from .latent_diffusion_uncond import LDMPipeline
     from .pndm import PNDMPipeline
     from .repaint import RePaintPipeline
diff --git a/src/diffusers/pipelines/latent_diffusion/__init__.py b/src/diffusers/pipelines/latent_diffusion/__init__.py
index c481b38cf5e0..5544527ff587 100644
--- a/src/diffusers/pipelines/latent_diffusion/__init__.py
+++ b/src/diffusers/pipelines/latent_diffusion/__init__.py
@@ -1,5 +1,6 @@
 # flake8: noqa
 from ...utils import is_transformers_available
+from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
 
 
 if is_transformers_available():
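The three `__init__.py` changes above wire the new class through the package hierarchy. As a quick check (mine, not part of the diff), each of these imports should resolve once the PR is applied:

```python
# Each level is covered by one of the __init__.py edits above.
from diffusers import LDMSuperResolutionPipeline
from diffusers.pipelines import LDMSuperResolutionPipeline  # noqa: F811
from diffusers.pipelines.latent_diffusion import LDMSuperResolutionPipeline  # noqa: F811
```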
+ """ + + if isinstance(init_image, PIL.Image.Image): + batch_size = 1 + elif isinstance(init_image, torch.Tensor): + batch_size = init_image.shape[0] + else: + raise ValueError( + f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}" + ) + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + height, width = init_image.shape[-2:] + + # in_channels should be 6: 3 for latents, 3 for low resolution image + latents_shape = (batch_size, self.unet.in_channels // 2, height, width) + latents_dtype = next(self.unet.parameters()).dtype + + if self.device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype) + latents = latents.to(self.device) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + + init_image = init_image.to(device=self.device, dtype=latents_dtype) + + # set timesteps and move to the correct device + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps_tensor = self.scheduler.timesteps + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(timesteps_tensor): + # concat latents and low resolution image in the channel dimension. 
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 9d296d29977d..af2e0c7c61d6 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -227,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class LDMSuperResolutionPipeline(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class PNDMPipeline(metaclass=DummyObject):
     _backends = ["torch"]
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
new file mode 100644
index 000000000000..f5ec56d1bdc2
--- /dev/null
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
+from diffusers.utils import floats_tensor, load_image, slow, torch_device
+from diffusers.utils.testing_utils import require_torch
+
+from ...test_pipelines_common import PipelineTesterMixin
+
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    @property
+    def dummy_image(self):
+        batch_size = 1
+        num_channels = 3
+        sizes = (32, 32)
+
+        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
+        return image
+
+    @property
+    def dummy_uncond_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=6,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        return model
+
+    @property
+    def dummy_vq_model(self):
+        torch.manual_seed(0)
+        model = VQModel(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=3,
+        )
+        return model
+
+    def test_inference_superresolution(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDIMScheduler()
+        vqvae = self.dummy_vq_model
+
+        ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
+        ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
+
+        init_image = self.dummy_image.to(torch_device)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+            _ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+
+
+@slow
+@require_torch
+class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
+    def test_inference_superresolution(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/vq_diffusion/teddy_bear_pool.png"
+        )
+        init_image = init_image.resize((64, 64), resample=PIL.Image.LANCZOS)
+
+        ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
+        ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = ldm(init_image, generator=generator, num_inference_steps=20, output_type="numpy").images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.7418, 0.7472, 0.7424, 0.7422, 0.7463, 0.726, 0.7382, 0.7248, 0.6828])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
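As a standalone illustration (not part of the diff) of the `inspect.signature` pattern used in the pipeline's `__call__`: `eta` is forwarded only to schedulers whose `step` accepts it, such as `DDIMScheduler`, and silently skipped for the rest:

```python
import inspect

from diffusers import DDIMScheduler, PNDMScheduler

for scheduler in (DDIMScheduler(), PNDMScheduler()):
    # same membership test as in the pipeline body above
    accepts_eta = "eta" in inspect.signature(scheduler.step).parameters
    print(type(scheduler).__name__, "accepts eta:", accepts_eta)
# prints: DDIMScheduler accepts eta: True / PNDMScheduler accepts eta: False
```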