diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 857fdd1b0b33..be9b7dee5e86 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -189,7 +189,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): ```""" return self._cast_floating_to(params, jnp.float16, mask) - def init_weights(self, rng: jax.random.PRNGKey) -> Dict: + def init_weights(self, rng: jax.random.KeyArray) -> Dict: raise NotImplementedError(f"init_weights method has to be implemented for {self}") @classmethod diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 3a3f1d9e146d..fb9caca9e1ed 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): flip_sin_to_cos: bool = True freq_shift: int = 0 - def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 7ecda9a6e9a0..9e18d66bc9f4 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -806,7 +806,7 @@ def setup(self): dtype=self.dtype, ) - def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index f8fd304776d7..848174230287 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -17,7 +17,7 @@ import importlib import inspect import os -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np @@ -475,6 +475,51 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = pipeline_class(**init_kwargs, dtype=dtype) return model, params + @staticmethod + def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - set(["self"]) + return expected_modules, optional_parameters + + @property + def components(self) -> Dict[str, Any]: + r""" + + The `self.components` property can be useful to run different pipelines with the same weights and + configurations to not have to re-allocate memory. + + Examples: + + ```py + >>> from diffusers import ( + ... FlaxStableDiffusionPipeline, + ... FlaxStableDiffusionImg2ImgPipeline, + ... ) + + >>> text2img = FlaxStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16 + ... ) + >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components) + ``` + + Returns: + A dictionary containing all the modules needed to initialize the pipeline. + """ + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + + if set(components.keys()) != expected_modules: + raise ValueError( + f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" + f" {expected_modules} to be defined, but {components} are defined." + ) + + return components + @staticmethod def numpy_to_pil(images): """ diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 50fbbedc996c..705c9e33e2a4 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -764,7 +764,7 @@ def components(self) -> Dict[str, Any]: ``` Returns: - A dictionaly containing all the modules needed to initialize the pipeline. + A dictionary containing all the modules needed to initialize the pipeline. """ expected_modules, optional_parameters = self._get_signature_keys(self) components = { diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 09760cef2d60..b4f5766869f0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -184,18 +184,14 @@ def _generate( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.PRNGKey, - num_inference_steps: int = 50, - height: Optional[int] = None, - width: Optional[int] = None, - guidance_scale: float = 7.5, + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, latents: Optional[jnp.array] = None, - neg_prompt_ids: jnp.array = None, + neg_prompt_ids: Optional[jnp.array] = None, ): - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -281,15 +277,15 @@ def __call__( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.PRNGKey, + prng_seed: jax.random.KeyArray, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, return_dict: bool = True, jit: bool = False, - neg_prompt_ids: jnp.array = None, ): r""" Function invoked when calling the pipeline for generation. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 7b0b35f89e00..5eeede7d660e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -14,7 +14,7 @@ import warnings from functools import partial -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -41,6 +41,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): r""" @@ -106,6 +109,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): if not isinstance(prompt, (str, list)): @@ -116,10 +120,8 @@ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image if isinstance(image, Image.Image): image = [image] - processed_image = [] - for img in image: - processed_image.append(preprocess(img, self.dtype)) - processed_image = jnp.array(processed_image).squeeze() + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) text_input = self.tokenizer( prompt, @@ -128,7 +130,7 @@ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image truncation=True, return_tensors="np", ) - return text_input.input_ids, processed_image + return text_input.input_ids, processed_images def _get_has_nsfw_concepts(self, features, params): has_nsfw_concepts = self.safety_checker(features, params) @@ -164,12 +166,11 @@ def _run_safety_checker(self, images, safety_model_params, jit=False): return images, has_nsfw_concepts - def get_timestep_start(self, num_inference_steps, strength, scheduler_state): + def get_timestep_start(self, num_inference_steps, strength): # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - t_start = max(num_inference_steps - init_timestep + offset, 0) + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) return t_start @@ -178,13 +179,14 @@ def _generate( prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.PRNGKey, - strength: float = 0.8, - num_inference_steps: int = 50, - height: int = 512, - width: int = 512, - guidance_scale: float = 7.5, - debug: bool = False, + prng_seed: jax.random.KeyArray, + start_timestep: int, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + noise: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -197,18 +199,32 @@ def _generate( batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) + latents_shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if noise is None: + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if noise.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}") + # Create init_latents init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) init_latents = 0.18215 * init_latents - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) def loop_body(step, args): latents, scheduler_state = args @@ -241,19 +257,19 @@ def loop_body(step, args): params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape ) - t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state) - latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size) - init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) - latents = init_latents + latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) - if debug: + latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: # run with python for loop - for i in range(t_start, len(scheduler_state.timesteps)): + for i in range(start_timestep, num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: - latents, _ = jax.lax.fori_loop( - t_start, len(scheduler_state.timesteps), loop_body, (latents, scheduler_state) - ) + latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents @@ -268,14 +284,15 @@ def __call__( image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.KeyArray, - num_inference_steps: int = 50, - height: int = 512, - width: int = 512, - guidance_scale: float = 7.5, strength: float = 0.8, + num_inference_steps: int = 50, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: Union[float, jnp.array] = 7.5, + noise: jnp.array = None, + neg_prompt_ids: jnp.array = None, return_dict: bool = True, jit: bool = False, - debug: bool = False, ): r""" Function invoked when calling the pipeline for generation. @@ -287,12 +304,17 @@ def __call__( Array representing an image batch, that will be used as the starting point for the process. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). @@ -300,18 +322,17 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in + noise (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. tensor will ge generated + by sampling using the supplied random `generator`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. jit (`bool`, defaults to `False`): Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. - debug (`bool`, *optional*, defaults to `False`): Whether to make use of python forloop or lax.fori_loop + Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a @@ -319,76 +340,109 @@ def __call__( element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + start_timestep = self.get_timestep_start(num_inference_steps, strength) + if jit: - image = _p_generate( + images = _p_generate( self, prompt_ids, image, params, prng_seed, - strength, + start_timestep, num_inference_steps, height, width, guidance_scale, - debug, + noise, + neg_prompt_ids, ) else: - image = self._generate( + images = self._generate( prompt_ids, image, params, prng_seed, - strength, + start_timestep, num_inference_steps, height, width, guidance_scale, - debug, + noise, + neg_prompt_ids, ) if self.safety_checker is not None: safety_params = params["safety_checker"] - image_uint8_casted = (image * 255).round().astype("uint8") - num_devices, batch_size = image.shape[:2] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] - image_uint8_casted = np.asarray(image_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - image_uint8_casted, has_nsfw_concept = self._run_safety_checker(image_uint8_casted, safety_params, jit) - image = np.asarray(image) + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): if is_nsfw: - image[i] = np.asarray(image_uint8_casted[i]) + images[i] = np.asarray(images_uint8_casted[i]) - image = image.reshape(num_devices, batch_size, height, width, 3) + images = images.reshape(num_devices, batch_size, height, width, 3) else: + images = np.asarray(images) has_nsfw_concept = False if not return_dict: - return (image, has_nsfw_concept) + return (images, has_nsfw_concept) - return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -# TODO: maybe use a config dict instead of so many static argnums -@partial(jax.pmap, static_broadcasted_argnums=(0, 5, 6, 7, 8, 9, 10)) +# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 5, 6, 7, 8), +) def _p_generate( pipe, prompt_ids, image, params, prng_seed, - strength, + start_timestep, num_inference_steps, height, width, guidance_scale, - debug, + noise, + neg_prompt_ids, ): return pipe._generate( - prompt_ids, image, params, prng_seed, strength, num_inference_steps, height, width, guidance_scale, debug + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, ) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index e1f669d22b76..71b7306134a5 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -87,7 +87,7 @@ def __init__( module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor clip_input = jax.random.normal(rng, input_shape)