diff --git a/README.md b/README.md index ecbf7b802103..3ead623a2b26 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,55 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True). images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) ``` +Diffusers also has a Image-to-Image generation pipeline with Flax/Jax +```python +import jax +import numpy as np +import jax.numpy as jnp +from flax.jax_utils import replicate +from flax.training.common_utils import shard +import requests +from io import BytesIO +from PIL import Image +from diffusers import FlaxStableDiffusionImg2ImgPipeline + +def create_key(seed=0): + return jax.random.PRNGKey(seed) +rng = create_key(0) + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +response = requests.get(url) +init_img = Image.open(BytesIO(response.content)).convert("RGB") +init_img = init_img.resize((768, 512)) + +prompts = "A fantasy landscape, trending on artstation" + +pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="flax", + dtype=jnp.bfloat16, +) + +num_samples = jax.device_count() +rng = jax.random.split(rng, jax.device_count()) +prompt_ids, processed_image = pipeline.prepare_inputs(prompt=[prompts]*num_samples, image = [init_img]*num_samples) +p_params = replicate(params) +prompt_ids = shard(prompt_ids) +processed_image = shard(processed_image) + +output = pipeline( + prompt_ids=prompt_ids, + image=processed_image, + params=p_params, + prng_seed=rng, + strength=0.75, + num_inference_steps=50, + jit=True, + height=512, + width=768).images + +output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) +``` + ### Image-to-Image text-guided generation with Stable Diffusion The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5f028f4a000f..e315b150a5d9 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -154,10 +154,11 @@ FlaxScoreSdeVeScheduler, ) + try: if not (is_flax_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_flax_and_transformers_objects import * # noqa F403 else: - from .pipelines import FlaxStableDiffusionPipeline + from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9f7d1a05dbca..75e9efa99db0 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -90,4 +90,4 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: - from .stable_diffusion import FlaxStableDiffusionPipeline + from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 88dc5dbc0754..95d1952cae10 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -98,4 +98,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py new file mode 100644 index 000000000000..7b0b35f89e00 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -0,0 +1,413 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Union + +import numpy as np + +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...pipeline_flax_utils import FlaxDiffusionPipeline +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for image-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + processed_image = [] + for img in image: + processed_image.append(preprocess(img, self.dtype)) + processed_image = jnp.array(processed_image).squeeze() + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_image + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def get_timestep_start(self, num_inference_steps, strength, scheduler_state): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + t_start = max(num_inference_steps - init_timestep + offset, 0) + + return t_start + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.PRNGKey, + strength: float = 0.8, + num_inference_steps: int = 50, + height: int = 512, + width: int = 512, + guidance_scale: float = 7.5, + debug: bool = False, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + context = jnp.concatenate([uncond_embeddings, text_embeddings]) + + # Create init_latents + init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist + init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) + init_latents = 0.18215 * init_latents + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state) + latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size) + init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) + latents = init_latents + + if debug: + # run with python for loop + for i in range(t_start, len(scheduler_state.timesteps)): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop( + t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state) + ) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + def __call__( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + height: int = 512, + width: int = 512, + guidance_scale: float = 7.5, + strength: float = 0.8, + return_dict: bool = True, + jit: bool = False, + debug: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt_ids (`jnp.array`): + The prompt or prompts to guide the image generation. + image (`jnp.array`): + Array representing an image batch, that will be used as the starting point for the process. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + debug (`bool`, *optional*, defaults to `False`): Whether to make use of python forloop or lax.fori_loop + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + if jit: + image = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + strength, + num_inference_steps, + height, + width, + guidance_scale, + debug, + ) + else: + image = self._generate( + prompt_ids, + image, + params, + prng_seed, + strength, + num_inference_steps, + height, + width, + guidance_scale, + debug, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + image_uint8_casted = (image * 255).round().astype("uint8") + num_devices, batch_size = image.shape[:2] + + image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit) + image = np.asarray(image) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + image[i] = np.asarray(image_uint8_casted[i]) + + image = image.reshape(num_devices, batch_size, height, width, 3) + else: + has_nsfw_concept = False + + if not return_dict: + return (image, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +# TODO: maybe use a config dict instead of so many static argnums +@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10)) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + strength, + num_inference_steps, + height, + width, + guidance_scale, + debug, +): + return pipe._generate( + prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py index 14830bca2898..3f27bf413977 100644 --- a/src/diffusers/utils/dummy_flax_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -4,6 +4,21 @@ from ..utils import DummyObject, requires_backends +class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + class FlaxStableDiffusionPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"]