From b3f6343890531ff7a39be7708dc559901feab450 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 21 Nov 2022 10:49:06 -0500 Subject: [PATCH 01/24] add flax img2img pipeline --- .../pipeline_flax_stable_diffusion_img2img.py | 384 ++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py new file mode 100644 index 000000000000..f71629a1b494 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -0,0 +1,384 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import numpy as np + +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...pipeline_flax_utils import FlaxDiffusionPipeline +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import logging, PIL_INTERPOLATION +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 
+ scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
+ ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def prepare_inputs(self, prompt: Union[str, List[str]], init_image: Union[Image.Image, List[Image.Image]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(init_image, (Image.Image, list)): + raise ValueError(f"init_image has to be of type `PIL.Image.Image` or list but is {type(init_image)}") + + if isinstance(init_image, Image.Image): + init_image = [init_image] + processed_images = [] + for img in init_image: + processed_images.append(preprocess(img, self.dtype)) + processed_images = jnp.array(processed_images).squeeze() + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + 
images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def get_timesteps(self, num_inference_steps, strength, scheduler_state): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = scheduler_state.timesteps[t_start:] + + return timesteps + + def _generate( + self, + prompt_ids: jnp.array, + init_image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.PRNGKey, + strength: float = 0.8, + num_inference_steps: int = 50, + height: int = 512, + width: int = 512, + guidance_scale: float = 7.5, + debug: bool = False, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + context = jnp.concatenate([uncond_embeddings, text_embeddings]) + + # Create init_latents + init_latent_dist = self.vae.apply({"params": params["vae"]}, init_image, method=self.vae.encode).latent_dist + 
init_latents = init_latent_dist.sample(key=prng_seed) + init_latents = 0.18215 * init_latents + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + timesteps = self.get_timesteps(num_inference_steps, strength, scheduler_state) + latent_timestep = timesteps[:1].repeat(batch_size) + init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) + latents = init_latents + + if debug: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # 
scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + def __call__( + self, + prompt_ids: jnp.array, + init_images: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.PRNGKey, + num_inference_steps: int = 50, + height: int = 512, + width: int = 512, + guidance_scale: float = 7.5, + strength: float = 0.8, + return_dict: bool = True, + jit: bool = False, + debug: bool = False, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. 
+ """ + if jit: + images = _p_generate( + self, prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + ) + else: + images = self._generate( + prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# TODO: maybe use a config dict instead of so many static argnums +@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7,8, 9, 10)) +def _p_generate( + pipe, prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug +): + return pipe._generate( + prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... 
-> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + +def preprocess(image,dtype): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 \ No newline at end of file From e637e83f06b14fb83433aa87c809686ec9b11bb7 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 21 Nov 2022 11:09:21 -0500 Subject: [PATCH 02/24] update pipeline --- .../pipeline_flax_stable_diffusion_img2img.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index f71629a1b494..9f4921bf1769 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -41,8 +41,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): +class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. 
@@ -207,10 +206,11 @@ def _generate( # Create init_latents init_latent_dist = self.vae.apply({"params": params["vae"]}, init_image, method=self.vae.encode).latent_dist - init_latents = init_latent_dist.sample(key=prng_seed) + init_latents = init_latent_dist.sample(key=prng_seed).transpose((0,3,1,2)) init_latents = 0.18215 * init_latents latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + def loop_body(step, args): latents, scheduler_state = args # For classifier free guidance, we need to do two forward passes. @@ -355,7 +355,7 @@ def __call__( # TODO: maybe use a config dict instead of so many static argnums -@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7,8, 9, 10)) +@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10)) def _p_generate( pipe, prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug ): From e80efe9dad22ccce337009b62d28d9186df846ec Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Tue, 22 Nov 2022 10:22:25 -0500 Subject: [PATCH 03/24] black format file --- .../pipeline_flax_stable_diffusion_img2img.py | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 9f4921bf1769..13367b30734e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -41,6 +41,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. 
@@ -109,7 +110,7 @@ def __init__( def prepare_inputs(self, prompt: Union[str, List[str]], init_image: Union[Image.Image, List[Image.Image]]): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + if not isinstance(init_image, (Image.Image, list)): raise ValueError(f"init_image has to be of type `PIL.Image.Image` or list but is {type(init_image)}") @@ -168,7 +169,6 @@ def get_timesteps(self, num_inference_steps, strength, scheduler_state): offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = scheduler_state.timesteps[t_start:] @@ -206,7 +206,7 @@ def _generate( # Create init_latents init_latent_dist = self.vae.apply({"params": params["vae"]}, init_image, method=self.vae.encode).latent_dist - init_latents = init_latent_dist.sample(key=prng_seed).transpose((0,3,1,2)) + init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) init_latents = 0.18215 * init_latents latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) @@ -241,7 +241,7 @@ def loop_body(step, args): scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape ) - + timesteps = self.get_timesteps(num_inference_steps, strength, scheduler_state) latent_timestep = timesteps[:1].repeat(batch_size) init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) @@ -322,11 +322,30 @@ def __call__( """ if jit: images = _p_generate( - self, prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + self, + prompt_ids, + init_images, + params, + prng_seed, + strength, + num_inference_steps, + 
height, + width, + guidance_scale, + debug, ) else: images = self._generate( - prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + prompt_ids, + init_images, + params, + prng_seed, + strength, + num_inference_steps, + height, + width, + guidance_scale, + debug, ) if self.safety_checker is not None: @@ -357,7 +376,17 @@ def __call__( # TODO: maybe use a config dict instead of so many static argnums @partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10)) def _p_generate( - pipe, prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + pipe, + prompt_ids, + init_images, + params, + prng_seed, + strength, + num_inference_steps, + height, + width, + guidance_scale, + debug, ): return pipe._generate( prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug @@ -375,10 +404,11 @@ def unshard(x: jnp.ndarray): rest = x.shape[2:] return x.reshape(num_devices * batch_size, *rest) -def preprocess(image,dtype): + +def preprocess(image, dtype): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) - return 2.0 * image - 1.0 \ No newline at end of file + return 2.0 * image - 1.0 From 9603d75cb14b7e89f6913840d82183166fa77299 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Tue, 22 Nov 2022 10:24:23 -0500 Subject: [PATCH 04/24] remove arg from get_timesteps --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f543d564fe84..ada428ed4050 100644 --- 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -391,7 +391,7 @@ def check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." ) - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset From a1b27ef1e4e8df8c750acb341d3749143c5243e9 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Tue, 22 Nov 2022 11:52:05 -0500 Subject: [PATCH 05/24] update get_timesteps --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index ada428ed4050..b59332c595d7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -530,7 +530,7 @@ def __call__( # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps = self.get_timesteps(num_inference_steps, strength) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. 
Prepare latent variables From 7122386e2c45f62dbddf71cf76e1d2a28e9f75d5 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Thu, 24 Nov 2022 21:59:16 -0500 Subject: [PATCH 06/24] fix bug: make use of timesteps for for_loop --- .../pipeline_flax_stable_diffusion_img2img.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 13367b30734e..ca5441ba2b46 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -212,13 +212,13 @@ def _generate( noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) def loop_body(step, args): - latents, scheduler_state = args + latents, timesteps, scheduler_state = args # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = jnp.concatenate([latents] * 2) - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + t = jnp.array(timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) @@ -236,7 +236,7 @@ def loop_body(step, args): # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, scheduler_state + return latents, timestep,scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape @@ -249,10 +249,10 @@ def loop_body(step, args): if debug: # run with python for loop - for i in range(num_inference_steps): - latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + for i in range(len(timesteps)): + latents, timesteps, scheduler_state = loop_body(i, (latents, timesteps, scheduler_state)) else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + latents, _, _ = jax.lax.fori_loop(0, len(timesteps), loop_body, (latents, timesteps, scheduler_state)) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents From 727fa1d6cde35df520c7f16c06b6eb637bd816da Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Fri, 25 Nov 2022 12:11:07 -0500 Subject: [PATCH 07/24] black file --- .../stable_diffusion/pipeline_flax_stable_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index ca5441ba2b46..62f7cebfafd7 100644 --- 
a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -236,7 +236,7 @@ def loop_body(step, args): # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, timestep,scheduler_state + return latents, timestep, scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape From e7d66876f57f22800a1748b46fa358b2940530d8 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Fri, 25 Nov 2022 12:19:09 -0500 Subject: [PATCH 08/24] black, isort, flake8 --- .../pipeline_flax_stable_diffusion_img2img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 62f7cebfafd7..50a2f3b3fd6c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -14,7 +14,7 @@ import warnings from functools import partial -from typing import Dict, List, Optional, Union +from typing import Dict, List, Union import numpy as np @@ -34,7 +34,7 @@ FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) -from ...utils import logging, PIL_INTERPOLATION +from ...utils import PIL_INTERPOLATION, logging from . 
import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker From c8787c8726e9793024a5e31f2b620238e42a92f2 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Fri, 25 Nov 2022 20:00:26 -0500 Subject: [PATCH 09/24] update docstring --- .../pipeline_flax_stable_diffusion_img2img.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 50a2f3b3fd6c..f91a9bd340b3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -266,7 +266,7 @@ def __call__( prompt_ids: jnp.array, init_images: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.PRNGKey, + prng_seed: Union[jax.random.KeyArray, jax.Array], num_inference_steps: int = 50, height: int = 512, width: int = 512, @@ -281,38 +281,38 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): + prompt_ids (`jnp.array`): The prompt or prompts to guide the image generation. + init_image (`jnp.array`): + Array representing an image batch, that will be used as the starting point for the + process. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. 
- num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`jnp.array`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - jit (`bool`, defaults to `False`): - Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument - exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. 
When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. - + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + debug (`bool`, *optional*, defaults to `False`): Whether to make use of python forloop or lax.fori_loop Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a From 94a3e936fd95dc14772e5bcae533a1101e341c5b Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Fri, 2 Dec 2022 22:29:20 -0500 Subject: [PATCH 10/24] update readme --- README.md | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/README.md b/README.md index ff523d060c59..ca3f87db3238 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,51 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True). 
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) ``` +Diffusers also has a Image-to-Image generation pipeline with Flax/Jax +```python +import jax +import numpy as np +from flax.jax_utils import replicate +from flax.training.common_utils import shard +import requests +from io import BytesIO +from PIL import Image +from diffusers import FlaxStableDiffusionImg2ImgPipeline + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_img = Image.open(BytesIO(response.content)).convert("RGB") +init_img = init_img.resize((768, 512)) +prompts = "A fantasy landscape, trending on artstation" +dtype=jnp.bfloat16 +pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="flax", + dtype=dtype, +) +def create_key(seed=0): + return jax.random.PRNGKey(seed) +rng = create_key(0) +rng = jax.random.split(rng, jax.device_count()) + +prompt_ids, imgs = pipeline.prepare_inputs(prompt=[prompts]*jax.device_count(), init_image = [init_img]*jax.device_count()) +p_params = replicate(params) +prompt_ids = shard(prompt_ids) +imgs = shard(imgs) + +output = pipeline( + prompt_ids=prompt_ids, + init_images=imgs, + params=p_params, + prng_seed=rng, + strength=0.75, + num_inference_steps=50, + jit=True, + init_image=imgs, + guidance_scale=7.5, +height=512,width=768).images +``` + ### Image-to-Image text-guided generation with Stable Diffusion The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. 
From 7696bf754ffe9b1aab238c7bc37e987afdfc95a6 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Fri, 2 Dec 2022 22:31:58 -0500 Subject: [PATCH 11/24] update flax img2img readme --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index ca3f87db3238..fa2df225332d 100644 --- a/README.md +++ b/README.md @@ -288,8 +288,7 @@ output = pipeline( num_inference_steps=50, jit=True, init_image=imgs, - guidance_scale=7.5, -height=512,width=768).images + height=512,width=768).images ``` ### Image-to-Image text-guided generation with Stable Diffusion From c5a427580dd67c13f23360fd39a284af0e355b3a Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Fri, 2 Dec 2022 22:36:43 -0500 Subject: [PATCH 12/24] update sd pipeline init --- src/diffusers/pipelines/stable_diffusion/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 3c012dbab89d..a155d825ac52 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -75,4 +75,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker From 83f2f77a23eafb6f7e2f1ab07b926330b0fed3b1 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Sun, 4 Dec 2022 21:18:47 -0500 Subject: [PATCH 13/24] Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca --- .../stable_diffusion/pipeline_flax_stable_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py 
b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index f91a9bd340b3..f8afff5236b9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -44,7 +44,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): r""" - Pipeline for text-to-image generation using Stable Diffusion. + Pipeline for image-to-image generation using Stable Diffusion. This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) From bc4abd6123b3203273bd7b1c4f7d650fcb928461 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Tue, 6 Dec 2022 09:42:45 -0500 Subject: [PATCH 14/24] Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca --- .../stable_diffusion/pipeline_flax_stable_diffusion_img2img.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index f8afff5236b9..b0ac6370c576 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -275,7 +275,6 @@ def __call__( return_dict: bool = True, jit: bool = False, debug: bool = False, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. 
From ff2be25d6b0d413e2cbbfdc4aed49ad26002b37c Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Tue, 6 Dec 2022 13:08:39 -0500 Subject: [PATCH 15/24] update inits --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4a6661b6b393..96febb0ce3db 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -115,5 +115,6 @@ if is_flax_available() and is_transformers_available(): from .pipelines import FlaxStableDiffusionPipeline + from .pipelines import FlaxStableDiffusionImg2ImgPipeline else: from .utils.dummy_flax_and_transformers_objects import * # noqa F403 diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9f4cef4b73e6..ebb92f423ad5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -45,3 +45,4 @@ if is_transformers_available() and is_flax_available(): from .stable_diffusion import FlaxStableDiffusionPipeline + from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline From 6cec0b8c6ee02773cd96c72fafe182150b527f46 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Tue, 6 Dec 2022 13:39:21 -0500 Subject: [PATCH 16/24] revert change --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 914d0de619fd..d86847fad653 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -428,7 +428,7 @@ def check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." 
) - def get_timesteps(self, num_inference_steps, strength): + def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset @@ -567,7 +567,7 @@ def __call__( # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength) + timesteps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables From e124ed9cda29929226f43fdd8c7ed22d543e572f Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Sun, 11 Dec 2022 23:53:18 +0000 Subject: [PATCH 17/24] update var name to image, typo --- .../pipeline_flax_stable_diffusion_img2img.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index b0ac6370c576..0362dd0a20dc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -107,19 +107,19 @@ def __init__( feature_extractor=feature_extractor, ) - def prepare_inputs(self, prompt: Union[str, List[str]], init_image: Union[Image.Image, List[Image.Image]]): + def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if not isinstance(init_image, (Image.Image, list)): - raise ValueError(f"init_image has to be of type `PIL.Image.Image` or list but is {type(init_image)}") + if not isinstance(image, (Image.Image, list)): + raise 
ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") - if isinstance(init_image, Image.Image): - init_image = [init_image] - processed_images = [] - for img in init_image: - processed_images.append(preprocess(img, self.dtype)) - processed_images = jnp.array(processed_images).squeeze() + if isinstance(image, Image.Image): + image = [image] + processed_image = [] + for img in image: + processed_image.append(preprocess(img, self.dtype)) + processed_image = jnp.array(processed_image).squeeze() text_input = self.tokenizer( prompt, @@ -128,7 +128,7 @@ def prepare_inputs(self, prompt: Union[str, List[str]], init_image: Union[Image. truncation=True, return_tensors="np", ) - return text_input.input_ids, processed_images + return text_input.input_ids, processed_image def _get_has_nsfw_concepts(self, features, params): has_nsfw_concepts = self.safety_checker(features, params) @@ -177,7 +177,7 @@ def get_timesteps(self, num_inference_steps, strength, scheduler_state): def _generate( self, prompt_ids: jnp.array, - init_image: jnp.array, + image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, strength: float = 0.8, @@ -205,7 +205,7 @@ def _generate( context = jnp.concatenate([uncond_embeddings, text_embeddings]) # Create init_latents - init_latent_dist = self.vae.apply({"params": params["vae"]}, init_image, method=self.vae.encode).latent_dist + init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) init_latents = 0.18215 * init_latents latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) @@ -236,7 +236,7 @@ def loop_body(step, args): # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, timestep, scheduler_state + return latents, timesteps, scheduler_state 
scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape @@ -264,7 +264,7 @@ def loop_body(step, args): def __call__( self, prompt_ids: jnp.array, - init_images: jnp.array, + image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: Union[jax.random.KeyArray, jax.Array], num_inference_steps: int = 50, @@ -282,7 +282,7 @@ def __call__( Args: prompt_ids (`jnp.array`): The prompt or prompts to guide the image generation. - init_image (`jnp.array`): + image (`jnp.array`): Array representing an image batch, that will be used as the starting point for the process. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights @@ -301,8 +301,8 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in return_dict (`bool`, *optional*, defaults to `True`): @@ -320,10 +320,10 @@ def __call__( "not-safe-for-work" (nsfw) content, according to the `safety_checker`. 
""" if jit: - images = _p_generate( + image = _p_generate( self, prompt_ids, - init_images, + image, params, prng_seed, strength, @@ -334,9 +334,9 @@ def __call__( debug, ) else: - images = self._generate( + image = self._generate( prompt_ids, - init_images, + image, params, prng_seed, strength, @@ -349,27 +349,27 @@ def __call__( if self.safety_checker is not None: safety_params = params["safety_checker"] - images_uint8_casted = (images * 255).round().astype("uint8") - num_devices, batch_size = images.shape[:2] + image_uint8_casted = (image * 255).round().astype("uint8") + num_devices, batch_size = image.shape[:2] - images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) - images = np.asarray(images) + image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit) + image = np.asarray(image) # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): if is_nsfw: - images[i] = np.asarray(images_uint8_casted[i]) + image[i] = np.asarray(image_uint8_casted[i]) - images = images.reshape(num_devices, batch_size, height, width, 3) + image = image.reshape(num_devices, batch_size, height, width, 3) else: has_nsfw_concept = False if not return_dict: - return (images, has_nsfw_concept) + return (image, has_nsfw_concept) - return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) # TODO: maybe use a config dict instead of so many static argnums @@ -377,7 +377,7 @@ def __call__( def _p_generate( pipe, prompt_ids, - init_images, + image, params, prng_seed, strength, @@ -388,7 +388,7 @@ def _p_generate( debug, ): return 
pipe._generate( - prompt_ids, init_images, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug ) From c1e699619be299b6bdd0831e79f40481deaa1553 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 12 Dec 2022 00:02:46 +0000 Subject: [PATCH 18/24] update readme --- README.md | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index fa2df225332d..053a1362a851 100644 --- a/README.md +++ b/README.md @@ -251,6 +251,7 @@ Diffusers also has a Image-to-Image generation pipeline with Flax/Jax ```python import jax import numpy as np +import jax.numpy as jnp from flax.jax_utils import replicate from flax.training.common_utils import shard import requests @@ -258,37 +259,41 @@ from io import BytesIO from PIL import Image from diffusers import FlaxStableDiffusionImg2ImgPipeline -url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +def create_key(seed=0): + return jax.random.PRNGKey(seed) +rng = create_key(0) +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" response = requests.get(url) init_img = Image.open(BytesIO(response.content)).convert("RGB") init_img = init_img.resize((768, 512)) + prompts = "A fantasy landscape, trending on artstation" -dtype=jnp.bfloat16 + pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", revision="flax", - dtype=dtype, + dtype=jnp.bfloat16, ) -def create_key(seed=0): - return jax.random.PRNGKey(seed) -rng = create_key(0) -rng = jax.random.split(rng, jax.device_count()) -prompt_ids, imgs = pipeline.prepare_inputs(prompt=[prompts]*jax.device_count(), init_image = [init_img]*jax.device_count()) +num_samples = jax.device_count() +rng = 
jax.random.split(rng, jax.device_count()) +prompt_ids, processed_image = pipeline.prepare_inputs(prompt=[prompts]*num_samples, image = [init_img]*num_samples) p_params = replicate(params) prompt_ids = shard(prompt_ids) -imgs = shard(imgs) +processed_image = shard(processed_image) output = pipeline( prompt_ids=prompt_ids, - init_images=imgs, + image=processed_image, params=p_params, prng_seed=rng, strength=0.75, num_inference_steps=50, jit=True, - init_image=imgs, - height=512,width=768).images + height=512, + width=768).images + +output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) ``` ### Image-to-Image text-guided generation with Stable Diffusion From 932a74e382078d68c1629e865576261ce6d136d3 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 12 Dec 2022 00:33:34 +0000 Subject: [PATCH 19/24] return new t_start instead of modified timestep --- .../pipeline_flax_stable_diffusion_img2img.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 0362dd0a20dc..3e1d448322db 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -164,15 +164,14 @@ def _run_safety_checker(self, images, safety_model_params, jit=False): return images, has_nsfw_concepts - def get_timesteps(self, num_inference_steps, strength, scheduler_state): + def get_timestep_start(self, num_inference_steps, strength, scheduler_state): # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) t_start = max(num_inference_steps - init_timestep + offset, 0) - 
timesteps = scheduler_state.timesteps[t_start:] - return timesteps + return t_start def _generate( self, @@ -212,13 +211,13 @@ def _generate( noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) def loop_body(step, args): - latents, timesteps, scheduler_state = args + latents, scheduler_state = args # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = jnp.concatenate([latents] * 2) - t = jnp.array(timesteps, dtype=jnp.int32)[step] + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) @@ -236,23 +235,23 @@ def loop_body(step, args): # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, timesteps, scheduler_state + return latents, scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape ) - timesteps = self.get_timesteps(num_inference_steps, strength, scheduler_state) - latent_timestep = timesteps[:1].repeat(batch_size) + t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state) + latent_timestep = scheduler_state.timesteps[t_start:t_start+1].repeat(batch_size) init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) latents = init_latents if debug: # run with python for loop - for i in range(len(timesteps)): - latents, timesteps, scheduler_state = loop_body(i, (latents, timesteps, scheduler_state)) + for i in range(t_start, len(scheduler_state.timesteps)): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _, _ = jax.lax.fori_loop(0, len(timesteps), loop_body, (latents, 
timesteps, scheduler_state)) + latents, _ = jax.lax.fori_loop(t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents From e8dff83ecf19a9706d4f6363812511dd299a3857 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 12 Dec 2022 00:42:12 +0000 Subject: [PATCH 20/24] black format --- .../pipeline_flax_stable_diffusion_img2img.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 3e1d448322db..d7911195a6e0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -242,7 +242,7 @@ def loop_body(step, args): ) t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state) - latent_timestep = scheduler_state.timesteps[t_start:t_start+1].repeat(batch_size) + latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size) init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) latents = init_latents @@ -251,7 +251,9 @@ def loop_body(step, args): for i in range(t_start, len(scheduler_state.timesteps)): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _ = jax.lax.fori_loop(t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state)) + latents, _ = jax.lax.fori_loop( + t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state) + ) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents @@ -319,7 +321,7 @@ def __call__( "not-safe-for-work" (nsfw) content, according to the `safety_checker`. 
""" if jit: - image = _p_generate( + image = _p_generate( self, prompt_ids, image, From d06f8c38124ad76ac285fee5a33a4a68fd8ecdc7 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 12 Dec 2022 00:55:25 +0000 Subject: [PATCH 21/24] isort files --- src/diffusers/__init__.py | 3 +-- src/diffusers/pipelines/__init__.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eeddedfc992b..e315b150a5d9 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -161,5 +161,4 @@ except OptionalDependencyNotAvailable: from .utils.dummy_flax_and_transformers_objects import * # noqa F403 else: - from .pipelines import FlaxStableDiffusionPipeline - from .pipelines import FlaxStableDiffusionImg2ImgPipeline + from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 2f33ac29482e..75e9efa99db0 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -90,5 +90,4 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: - from .stable_diffusion import FlaxStableDiffusionPipeline - from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline + from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline From b9fa9b7996fd2e065bdede791f1a8f491c5a98d3 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 12 Dec 2022 01:04:34 +0000 Subject: [PATCH 22/24] update docs --- .../pipeline_flax_stable_diffusion_img2img.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index d7911195a6e0..43b868a48aad 100644 --- 
a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -284,8 +284,7 @@ def __call__( prompt_ids (`jnp.array`): The prompt or prompts to guide the image generation. image (`jnp.array`): - Array representing an image batch, that will be used as the starting point for the - process. + Array representing an image batch, that will be used as the starting point for the process. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key num_inference_steps (`int`, *optional*, defaults to 50): @@ -302,10 +301,10 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. 
From fb0af342d1f76188f4efb96a0c3174cd9c616224 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 19 Dec 2022 04:37:01 +0000 Subject: [PATCH 23/24] fix-copies --- .../utils/dummy_flax_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py index 14830bca2898..3f27bf413977 100644 --- a/src/diffusers/utils/dummy_flax_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -4,6 +4,21 @@ from ..utils import DummyObject, requires_backends +class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + class FlaxStableDiffusionPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] From cb88c46db924a5866b5c1214457069898efbe5ca Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Mon, 19 Dec 2022 04:54:06 +0000 Subject: [PATCH 24/24] update prng_seed typing --- .../stable_diffusion/pipeline_flax_stable_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 43b868a48aad..7b0b35f89e00 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -267,7 +267,7 @@ def __call__( prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: Union[jax.random.KeyArray, 
jax.Array], + prng_seed: jax.random.KeyArray, num_inference_steps: int = 50, height: int = 512, width: int = 512,