From 1ee4cdeca53a69f05159aa81bcdae5ee79e8ef25 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Fri, 17 Mar 2023 20:28:51 +0000
Subject: [PATCH 01/13] add controlnet flax

---
 src/diffusers/__init__.py               |   1 +
 src/diffusers/models/__init__.py        |   1 +
 src/diffusers/models/controlnet_flax.py | 381 ++++++++++++++++++
 .../models/unet_2d_condition_flax.py    |  17 +
 4 files changed, 400 insertions(+)
 create mode 100644 src/diffusers/models/controlnet_flax.py

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index f480b4100907..e22b2c0dacfb 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -179,6 +179,7 @@
     from .models.modeling_flax_utils import FlaxModelMixin
     from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .models.vae_flax import FlaxAutoencoderKL
+    from .models.controlnet_flax import FlaxControlNetModel
     from .pipelines import FlaxDiffusionPipeline
     from .schedulers import (
         FlaxDDIMScheduler,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index e0b2cddd4bf9..c71b26a17df5 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -29,4 +29,5 @@
 if is_flax_available():
     from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+    from .controlnet_flax import FlaxControlNetModel
     from .vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py
new file mode 100644
index 000000000000..74b8fb9fff57
--- /dev/null
+++ b/src/diffusers/models/controlnet_flax.py
@@ -0,0 +1,381 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple, Union
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+
+from ..configuration_utils import ConfigMixin, flax_register_to_config
+from ..utils import BaseOutput
+from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from .modeling_flax_utils import FlaxModelMixin
+from .unet_2d_blocks_flax import (
+    FlaxCrossAttnDownBlock2D,
+    FlaxDownBlock2D,
+    FlaxUNetMidBlock2DCrossAttn,
+)
+
+
+@flax.struct.dataclass
+class FlaxControlNetOutput(BaseOutput):
+    """
+    Args:
+        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    down_block_res_samples: jnp.ndarray
+    mid_block_res_sample: jnp.ndarray
+
+
+class FlaxControlNetConditioningEmbedding(nn.Module):
+    conditioning_embedding_channels: int
+    block_out_channels: Tuple[int] = (16, 32, 96, 256)
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.conv_in = nn.Conv(
+            self.block_out_channels[0],
+            kernel_size=(3, 3),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+        blocks = []
+        for i in range(len(self.block_out_channels) - 1):
+            channel_in = self.block_out_channels[i]
+            channel_out = self.block_out_channels[i + 1]
+            conv1 = nn.Conv(
+                channel_in,
+                kernel_size=(3, 3),
+                padding=((1, 1), (1, 1)),
+                dtype=self.dtype,
+            )
+            blocks.append(conv1)
+            conv2 = nn.Conv(
+                channel_out,
+                kernel_size=(3, 3),
+                strides=(2, 2),
+                padding=((1, 1), (1, 1)),
+                dtype=self.dtype,
+            )
+            blocks.append(conv2)
+        self.blocks = blocks
+
+        self.conv_out = nn.Conv(
+            self.conditioning_embedding_channels,
+            kernel_size=(3, 3),
+            padding=((1, 1), (1, 1)),
+            kernel_init=nn.initializers.zeros_init(),
+            bias_init=nn.initializers.zeros_init(),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, conditioning):
+        embedding = self.conv_in(conditioning)
+        embedding = nn.silu(embedding)
+
+        for block in self.blocks:
+            embedding = block(embedding)
+            embedding = nn.silu(embedding)
+
+        embedding = self.conv_out(embedding)
+
+        return embedding
+
+
+@flax_register_to_config
+class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
+    r"""
+    FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
+    timestep and returns sample shaped output.
+
+    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all the models (such as downloading or saving, etc.)
+
+    Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
+    general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        sample_size (`int`, *optional*):
+            The size of the input sample.
+        in_channels (`int`, *optional*, defaults to 4):
+            The number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4):
+            The number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
+            "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+            The tuple of upsample blocks to use.
The corresponding class names will be: "FlaxUpBlock2D",
+            "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2):
+            The number of layers per block.
+        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
+            The dimension of the attention heads.
+        cross_attention_dim (`int`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
+        dropout (`float`, *optional*, defaults to 0):
+            Dropout probability for down, up and bottleneck blocks.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+
+    """
+    sample_size: int = 32
+    in_channels: int = 4
+    down_block_types: Tuple[str] = (
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "DownBlock2D",
+    )
+    only_cross_attention: Union[bool, Tuple[bool]] = False
+    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    layers_per_block: int = 2
+    attention_head_dim: Union[int, Tuple[int]] = 8
+    cross_attention_dim: int = 1280
+    dropout: float = 0.0
+    use_linear_projection: bool = False
+    dtype: jnp.dtype = jnp.float32
+    flip_sin_to_cos: bool = True
+    freq_shift: int = 0
+    controlnet_conditioning_channel_order: str = "rgb"
+    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
+
+    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
+        # init input tensors
+        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
+        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+        timesteps = jnp.ones((1,), dtype=jnp.int32)
+        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
+        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
+        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
+
+    def setup(self):
+        block_out_channels = self.block_out_channels
+        time_embed_dim = block_out_channels[0] * 4
+
+        # input
+        self.conv_in = nn.Conv(
+            block_out_channels[0],
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding=((1, 1), (1, 1)),
+            dtype=self.dtype,
+        )
+
+        # time
+        self.time_proj = FlaxTimesteps(
+            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
+        )
+        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+
+        self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
+            conditioning_embedding_channels=block_out_channels[0],
+            block_out_channels=self.conditioning_embedding_out_channels,
+        )
+
+        only_cross_attention = self.only_cross_attention
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
+
+        attention_head_dim = self.attention_head_dim
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+
+        # down
+        down_blocks = []
+        controlnet_down_blocks = []
+
+        output_channel = block_out_channels[0]
+
+        controlnet_block = nn.Conv(
+            output_channel,
+            kernel_size=(1,1),
+            padding=((1, 1), (1, 1)),
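+            # kernel and bias are zero-initialized below (the "zero convolution" trick from the
+            # ControlNet paper), so this projection outputs zeros at init and the untrained
+            # ControlNet does not perturb the behavior of the UNet it is attached to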
+            kernel_init=nn.initializers.zeros_init(),
+            bias_init=nn.initializers.zeros_init(),
+            dtype=self.dtype,
+        )
+        controlnet_down_blocks.append(controlnet_block)
+
+        for i, down_block_type in enumerate(self.down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            if down_block_type == "CrossAttnDownBlock2D":
+                down_block = FlaxCrossAttnDownBlock2D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    dropout=self.dropout,
+                    num_layers=self.layers_per_block,
+                    attn_num_head_channels=attention_head_dim[i],
+                    add_downsample=not is_final_block,
+                    use_linear_projection=self.use_linear_projection,
+                    only_cross_attention=only_cross_attention[i],
+                    dtype=self.dtype,
+                )
+            else:
+                down_block = FlaxDownBlock2D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    dropout=self.dropout,
+                    num_layers=self.layers_per_block,
+                    add_downsample=not is_final_block,
+                    dtype=self.dtype,
+                )
+
+            down_blocks.append(down_block)
+
+            for _ in range(self.layers_per_block):
+                controlnet_block = nn.Conv(
+                    output_channel,
+                    kernel_size=(1,1),
+                    padding=((1, 1), (1, 1)),
+                    kernel_init=nn.initializers.zeros_init(),
+                    bias_init=nn.initializers.zeros_init(),
+                    dtype=self.dtype,
+                )
+                controlnet_down_blocks.append(controlnet_block)
+
+            if not is_final_block:
+                controlnet_block = nn.Conv(
+                    output_channel,
+                    kernel_size=(1,1),
+                    padding=((1, 1), (1, 1)),
+                    kernel_init=nn.initializers.zeros_init(),
+                    bias_init=nn.initializers.zeros_init(),
+                    dtype=self.dtype,
+                )
+                controlnet_down_blocks.append(controlnet_block)
+
+        self.down_blocks = down_blocks
+        self.controlnet_down_blocks = controlnet_down_blocks
+
+        # mid
+        mid_block_channel = block_out_channels[-1]
+        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+            in_channels=mid_block_channel,
+            dropout=self.dropout,
+            attn_num_head_channels=attention_head_dim[-1],
+            use_linear_projection=self.use_linear_projection,
+            dtype=self.dtype,
+        )
+
+        self.controlnet_mid_block = nn.Conv(
+            mid_block_channel,
+            kernel_size=(1,1),
+            padding=((1, 1), (1, 1)),
+            kernel_init=nn.initializers.zeros_init(),
+            bias_init=nn.initializers.zeros_init(),
+            dtype=self.dtype,
+        )
+
+    def __call__(
+        self,
+        sample,
+        timesteps,
+        encoder_hidden_states,
+        controlnet_cond,
+        conditioning_scale: float = 1.0,
+        return_dict: bool = True,
+        train: bool = False,
+    ) -> Union[FlaxControlNetOutput, Tuple]:
+        r"""
+        Args:
+            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`jnp.ndarray` or `float` or `int`): timesteps
+            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.controlnet_flax.FlaxControlNetOutput`] instead of a
+                plain tuple.
+            train (`bool`, *optional*, defaults to `False`):
+                Use deterministic functions and disable dropout when not training.
+
+        Returns:
+            [`~models.controlnet_flax.FlaxControlNetOutput`] or `tuple`:
+            [`~models.controlnet_flax.FlaxControlNetOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is the sample tensor.
+        """
+        channel_order = self.controlnet_conditioning_channel_order
+        if channel_order == "bgr":
+            controlnet_cond = jnp.flip(controlnet_cond, axis=1)
+
+        # 1. time
+        if not isinstance(timesteps, jnp.ndarray):
+            timesteps = jnp.array([timesteps], dtype=jnp.int32)
+        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+            timesteps = timesteps.astype(dtype=jnp.float32)
+            timesteps = jnp.expand_dims(timesteps, 0)
+
+        t_emb = self.time_proj(timesteps)
+        t_emb = self.time_embedding(t_emb)
+
+        # 2. pre-process
+        sample = jnp.transpose(sample, (0, 2, 3, 1))
+        sample = self.conv_in(sample)
+
+        controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
+        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+        sample += controlnet_cond
+
+        # 3. down
+        down_block_res_samples = (sample,)
+        for down_block in self.down_blocks:
+            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
+                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+            else:
+                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
+            down_block_res_samples += res_samples
+
+        # 4. mid
+        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+
+        # 5. controlnet blocks
+        controlnet_down_block_res_samples = ()
+        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+            down_block_res_sample = controlnet_block(down_block_res_sample)
+            controlnet_down_block_res_samples += (down_block_res_sample,)
+
+        down_block_res_samples = controlnet_down_block_res_samples
+
+        mid_block_res_sample = self.controlnet_mid_block(sample)
+
+        # 6. scaling
+        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+        mid_block_res_sample *= conditioning_scale
+
+        if not return_dict:
+            return (down_block_res_samples, mid_block_res_sample)
+
+        return FlaxControlNetOutput(down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample)
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index a40473a25f55..69c4e92bb63e 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -250,6 +250,9 @@ def __call__(
         timesteps,
         encoder_hidden_states,
         return_dict: bool = True,
+        down_block_additional_residuals = None,
+        mid_block_additional_residual = None,
+        return_dict: bool = True,
         train: bool = False,
     ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
         r"""
@@ -290,10 +293,24 @@
             else:
                 sample, res_samples = down_block(sample, t_emb, deterministic=not train)
             down_block_res_samples += res_samples
+
+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample += down_block_additional_residual
+                new_down_block_res_samples += (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples

         # 4. mid
         sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

+        if mid_block_additional_residual is not None:
+            sample += mid_block_additional_residual
+
         # 5.
up for up_block in self.up_blocks: res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] From 02a1bf861064c9c9043f43137b7fce6b33083448 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 17 Mar 2023 20:49:04 +0000 Subject: [PATCH 02/13] fix --- src/diffusers/models/unet_2d_condition_flax.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 69c4e92bb63e..85d9bad173a3 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -249,7 +249,6 @@ def __call__( sample, timesteps, encoder_hidden_states, - return_dict: bool = True, down_block_additional_residuals = None, mid_block_additional_residual = None, return_dict: bool = True, @@ -310,7 +309,7 @@ def __call__( if mid_block_additional_residual is not None: sample += mid_block_additional_residual - + # 5. up for up_block in self.up_blocks: res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] From f2d66fd0a69f9d8d6e1b1a25adc45e6b35bc65bd Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 17 Mar 2023 23:50:22 +0000 Subject: [PATCH 03/13] fix padding --- src/diffusers/models/controlnet_flax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index 74b8fb9fff57..9f140cbb27c5 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -221,7 +221,7 @@ def setup(self): controlnet_block = nn.Conv( output_channel, kernel_size=(1,1), - padding=((1, 1), (1, 1)), + padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, @@ -261,7 +261,7 @@ def setup(self): controlnet_block = nn.Conv( output_channel, kernel_size=(1,1), - padding=((1, 1), (1, 1)), + padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, @@ -272,7 +272,7 @@ def setup(self): controlnet_block = nn.Conv( output_channel, kernel_size=(1,1), - padding=((1, 1), (1, 1)), + padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, @@ -295,7 +295,7 @@ def setup(self): self.controlnet_mid_block = nn.Conv( mid_block_channel, kernel_size=(1,1), - padding=((1, 1), (1, 1)), + padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, From 3e024da5165c582014411759701a085e0f78553e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 19 Mar 2023 05:52:48 +0000 Subject: [PATCH 04/13] add pipeline --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + .../pipelines/stable_diffusion/__init__.py | 1 + ...peline_flax_stable_diffusion_controlnet.py | 517 ++++++++++++++++++ 4 files changed, 520 insertions(+) create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e22b2c0dacfb..1b83dbb2f607 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -203,4 +203,5 @@ FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, + FlaxStableDiffusionControlNetPipeline, ) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5b6c729f80be..dbb75ef759de 100644 --- a/src/diffusers/pipelines/__init__.py +++ 
b/src/diffusers/pipelines/__init__.py @@ -124,4 +124,5 @@ FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, + FlaxStableDiffusionControlNetPipeline, ) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 54ec4dabc73e..8addee84cef9 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -129,4 +129,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline + from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py new file mode 100644 index 000000000000..76cd01b5fdad --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -0,0 +1,517 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxControlNetModel +from ...models.controlnet_flax import FlaxControlNetOutput +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline + + + >>> def create_key(seed=0): + ... 
return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_img = init_img.resize((768, 512)) + + >>> prompts = "A fantasy landscape, trending on artstation" + + >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... revision="flax", + ... dtype=jnp.bfloat16, + ... ) + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + >>> prompt_ids, processed_image = pipeline.prepare_inputs( + ... prompt=[prompts] * num_samples, image=[init_img] * num_samples + ... ) + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipeline( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... strength=0.75, + ... num_inference_steps=50, + ... jit=True, + ... height=512, + ... width=768, + ... ).images + + >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
+ """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + controlnet: FlaxControlNetModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." 
+ ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + controlnet_conditioning_scale: float = 1.0, + ): + height, width = image.shape[-2:] + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + latents_shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. 
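+            # (one pass conditioned on the prompt embeddings and one on the unconditional embeddings)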
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet.apply( + {"params": params["controlnet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + down_block_additional_residuals = down_block_res_samples, + mid_block_additional_residual = mid_block_res_sample + ).sample + + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + guidance_scale: Union[float, jnp.array] = 7.5, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt_ids (`jnp.array`): + The prompt or prompts to guide the image generation. + image (`jnp.array`): + Array representing an image batch, that will be used as the starting point for the process. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. 
When `strength` is 1, added noise will
+                be maximum and the denoising process will run for the full number of iterations specified in
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            noise (`jnp.array`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. A tensor will be generated
+                by sampling using the supplied random `generator`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+                a plain tuple.
+            jit (`bool`, defaults to `False`):
+                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
+                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+            element is a list of `bool`s denoting whether the corresponding generated image likely represents
+            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+        """
+        if isinstance(guidance_scale, float):
+            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+            # shape information, as they may be sharded (when `jit` is `True`), or not.
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            if len(prompt_ids.shape) > 2:
+                # Assume sharded
+                guidance_scale = guidance_scale[:, None]
+
+        if isinstance(controlnet_conditioning_scale, float):
+            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+            # shape information, as they may be sharded (when `jit` is `True`), or not.
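+            # e.g. with 8 devices and `jit=True`, sharded prompt_ids have shape (8, batch, 77),
+            # so the scale below becomes shape (8, 1) and broadcasts over each device's sub-batch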
+ controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, start_timestep, num_inference_steps. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 5, 6,), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... 
-> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + image = image_.convert("RGB") + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 From 919ba4993a7da3499c91dadf196c97998d4b8eaf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 19 Mar 2023 19:42:45 +0000 Subject: [PATCH 05/13] allow from_pretrained(...controlnet=controlnet) --- src/diffusers/pipelines/pipeline_flax_utils.py | 17 +++++++++++++++-- ...pipeline_flax_stable_diffusion_controlnet.py | 14 ++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index 30e32c3d66e9..9081a929f3d8 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -278,7 +278,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> from diffusers import FlaxDPMSolverMultistepScheduler >>> model_id = "runwayml/stable-diffusion-v1-5" - >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained( + >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... ) @@ -365,7 +365,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here - expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -469,6 +469,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loaded_sub_model = load_method(loadable_folder) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Potentially add passed objects if expected + missing_modules = set(expected_modules) - set(init_kwargs.keys()) + passed_modules = list(passed_class_obj.keys()) + + if len(missing_modules) > 0 and missing_modules <= set(passed_modules): + for module in missing_modules: + init_kwargs[module] = passed_class_obj.get(module, None) + elif len(missing_modules) > 0: + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs + raise ValueError( + f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." 
+ ) model = pipeline_class(**init_kwargs, dtype=dtype) return model, params diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index 76cd01b5fdad..3ad3b238946d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -164,6 +164,7 @@ def __init__( text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, @@ -259,6 +260,8 @@ def _generate( negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + image = jnp.concatenate([image] * 2) + latents_shape = ( batch_size, self.unet.in_channels, @@ -395,6 +398,9 @@ def __call__( element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + + height, width = image.shape[-2:] + if isinstance(guidance_scale, float): # Convert to a tensor so each device gets a copy. Follow the prompt_ids for # shape information, as they may be sharded (when `jit` is `True`), or not. @@ -467,8 +473,8 @@ def __call__( # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). @partial( jax.pmap, - in_axes=(None, 0, 0, 0, 0, None, None, 0, 0, 0), - static_broadcasted_argnums=(0, 5, 6,), + in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 5), ) def _p_generate( pipe, @@ -508,10 +514,10 @@ def unshard(x: jnp.ndarray): def preprocess(image, dtype): - image = image_.convert("RGB") + image = image.convert("RGB") w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) - return 2.0 * image - 1.0 + return image From bfddf0d2c6a740b46fdb9c3fe24a3c6f5f7c472b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 20 Mar 2023 06:44:33 +0000 Subject: [PATCH 06/13] prepare_text_inputs + prepare_image_inputs --- ...peline_flax_stable_diffusion_controlnet.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index 3ad3b238946d..5a43f98be736 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -171,10 +171,21 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): + def prepare_text_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + return text_input.input_ids + + def 
prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): if not isinstance(image, (Image.Image, list)): raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") @@ -183,14 +194,7 @@ def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - return text_input.input_ids, processed_images + return processed_images, def _get_has_nsfw_concepts(self, features, params): has_nsfw_concepts = self.safety_checker(features, params) From 9f6bedf640a4406990c55a4682b533bece9a8e0d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 20 Mar 2023 06:48:45 +0000 Subject: [PATCH 07/13] fix --- .../pipeline_flax_stable_diffusion_controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index 5a43f98be736..14caad0e1e6a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -171,7 +171,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - def prepare_text_inputs(self, prompt: Union[str, List[str]]): + def prepare_text_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") From 37d4ac11b59938cf33c1bb38ef87649593e9a53b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 20 Mar 2023 07:44:06 +0000 Subject: [PATCH 08/13] fix doc string --- src/diffusers/models/controlnet_flax.py | 26 +++--- ...peline_flax_stable_diffusion_controlnet.py | 81 ++++++++++--------- 2 files changed, 54 insertions(+), 53 deletions(-) diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index 9f140cbb27c5..54ad78c1b4d0 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -32,12 +32,6 @@ @flax.struct.dataclass class FlaxControlNetOutput(BaseOutput): - """ - Args: - sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - down_block_res_samples: jnp.ndarray mid_block_res_sample: jnp.ndarray @@ -101,8 +95,12 @@ def __call__(self, conditioning): @flax_register_to_config class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" - FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a - timestep and returns sample shaped output. + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. 
We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+    model) to encode image-space conditions ... into feature maps ..."

     This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
     implements for all the models (such as downloading or saving, etc.)
@@ -122,14 +120,9 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The size of the input sample.
         in_channels (`int`, *optional*, defaults to 4):
             The number of channels in the input sample.
-        out_channels (`int`, *optional*, defaults to 4):
-            The number of channels in the output.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
             The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
             "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
-            The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D",
-            "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
             The number of layers per block.
@@ -143,6 +136,11 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
         flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
+            The channel order of the conditional image. Will be converted to `rgb` if given as `bgr`.
+        conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
+            The tuple of output channels for each block in the conditioning_embedding layer.
+
     """
     sample_size: int = 32
     in_channels: int = 4
     down_block_types: Tuple[str] = (
@@ -316,6 +314,8 @@ def __call__(
             sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
             timestep (`jnp.ndarray` or `float` or `int`): timesteps
             encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
+            controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
+            conditioning_scale (`float`): the scale factor for controlnet outputs
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.controlnet_flax.FlaxControlNetOutput`] instead of a
                 plain tuple.
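The hunks above settle the model-level docstrings. As a quick sanity check of the documented surface, the new module can be exercised end to end with dummy inputs. This is a sketch for review purposes only, not part of the patch: the zero key and arrays are illustrative, and the shapes follow the defaults declared in `controlnet_flax.py` (`sample_size=32`, `cross_attention_dim=1280`, conditioning image at 8x the latent resolution, as in `init_weights`):

```py
import jax
import jax.numpy as jnp

from diffusers import FlaxControlNetModel

controlnet = FlaxControlNetModel()  # all-default config, float32
params = controlnet.init_weights(rng=jax.random.PRNGKey(0))

sample = jnp.zeros((1, 4, 32, 32))                # noisy latents, NCHW as in init_weights
timesteps = jnp.ones((1,), dtype=jnp.int32)       # one diffusion timestep per batch element
encoder_hidden_states = jnp.zeros((1, 77, 1280))  # dummy text-encoder states (77 tokens)
controlnet_cond = jnp.zeros((1, 3, 256, 256))     # conditioning image, 8x the latent size

out = controlnet.apply(
    {"params": params},
    sample,
    timesteps,
    encoder_hidden_states,
    controlnet_cond,
)
# one residual per UNet skip connection (12 with the default config), plus a mid-block residual
print(len(out.down_block_res_samples), out.mid_block_res_sample.shape)
```

These residuals are exactly what the pipeline below feeds into `FlaxUNet2DConditionModel` through the new `down_block_additional_residuals` and `mid_block_additional_residual` arguments.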
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index 14caad0e1e6a..c4f4f4d3a704 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -52,53 +52,58 @@ >>> import jax.numpy as jnp >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard - >>> import requests - >>> from io import BytesIO + >>> from diffusers.utils import load_image >>> from PIL import Image - >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline + >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + >>> def image_grid(imgs, rows, cols): + ... w,h = imgs[0].size + ... grid = Image.new('RGB', size=(cols*w, rows*h)) + ... for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h)) + ... return grid >>> def create_key(seed=0): ... return jax.random.PRNGKey(seed) - >>> rng = create_key(0) - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - >>> response = requests.get(url) - >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") - >>> init_img = init_img.resize((768, 512)) + >>> # get canny image + >>> canny_image = load_image("https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg") - >>> prompts = "A fantasy landscape, trending on artstation" + >>> prompts = "best quality, extremely detailed" + >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" - >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( - ... "CompVis/stable-diffusion-v1-4", - ... revision="flax", - ... dtype=jnp.bfloat16, - ... ) + >>> # load control net and stable diffusion v1-5 + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32) + >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype = jnp.float32) + >>> params['controlnet'] = controlnet_params >>> num_samples = jax.device_count() >>> rng = jax.random.split(rng, jax.device_count()) - >>> prompt_ids, processed_image = pipeline.prepare_inputs( - ... prompt=[prompts] * num_samples, image=[init_img] * num_samples - ... ) + + >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) + >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + >>> p_params = replicate(params) >>> prompt_ids = shard(prompt_ids) + >>> negative_prompt_ids = shard(negative_prompt_ids) >>> processed_image = shard(processed_image) - >>> output = pipeline( + >>> output = pipe( ... prompt_ids=prompt_ids, ... image=processed_image, ... params=p_params, ... prng_seed=rng, - ... strength=0.75, ... num_inference_steps=50, + ... neg_prompt_ids=negative_prompt_ids, ... jit=True, - ... height=512, - ... width=768, ... 
).images

    >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
    >>> output_images = image_grid(output_images, num_samples//4, 4)
    >>> output_images.save("generated_image.png")
    ```
"""
@@ -121,6 +126,8 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        controlnet ([`FlaxControlNetModel`]):
+            Provides additional conditioning to the unet during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
@@ -194,7 +201,7 @@

         processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

-        return processed_images,
+        return processed_images

     def _get_has_nsfw_concepts(self, features, params):
         has_nsfw_concepts = self.safety_checker(features, params)
@@ -243,8 +250,8 @@ def _generate(
         controlnet_conditioning_scale: float = 1.0,
     ):
         height, width = image.shape[-2:]
-        if height % 32 != 0 or width % 32 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+        if height % 64 != 0 or width % 64 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")

         # get prompt text embeddings
         prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
@@ -361,31 +368,25 @@ def __call__(
             prompt_ids (`jnp.array`):
                 The prompt or prompts to guide the image generation.
             image (`jnp.array`):
-                Array representing an image batch, that will be used as the starting point for the process.
+                Array representing the ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet.
             params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
             prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
-            strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
-                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
-                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
-                be maximum and the denoising process will run for the full number of iterations specified in
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - noise (`jnp.array`, *optional*): + latents (`jnp.array`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. tensor will ge generated - by sampling using the supplied random `generator`. + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. @@ -473,7 +474,7 @@ def __call__( return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -# Static argnums are pipe, start_timestep, num_inference_steps. A change would trigger recompilation. +# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). @partial( jax.pmap, @@ -520,7 +521,7 @@ def unshard(x: jnp.ndarray): def preprocess(image, dtype): image = image.convert("RGB") w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) From 4097137651f34565aeef542e35be675d03825f7f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 22 Mar 2023 02:10:52 +0000 Subject: [PATCH 09/13] add doc --- docs/source/en/api/models.mdx | 6 ++++++ .../source/en/api/pipelines/stable_diffusion/controlnet.mdx | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index dc425e98628c..39ccc1808f52 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -87,3 +87,9 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## FlaxAutoencoderKL [[autodoc]] FlaxAutoencoderKL + +## FlaxControlNetOutput +[[autodoc]] models.controlnet_flax.FlaxControlNetOutput + +## FlaxControlNetModel +[[autodoc]] FlaxControlNetModel diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx index b5fa350e5f04..90df682264af 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx @@ -165,3 +165,9 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h - disable_vae_slicing - enable_xformers_memory_efficient_attention - disable_xformers_memory_efficient_attention + +## FlaxStableDiffusionControlNetPipeline +[[autodoc]] FlaxStableDiffusionControlNetPipeline + - all + - __call__ + From 84e1caf8a890893eb8b259b3f29c8c5cf2655f86 Mon Sep 17 00:00:00 2001 From: 
yiyixuxu Date: Wed, 22 Mar 2023 02:11:20 +0000 Subject: [PATCH 10/13] add tests --- .../test_stable_diffusion_flax_controlnet.py | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py new file mode 100644 index 000000000000..227b9cdb2431 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline +from diffusers.utils import is_flax_available, slow, load_image +from diffusers.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from flax.jax_utils import replicate + from flax.training.common_utils import shard + + +@slow +@require_flax +class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def test_canny(self): + controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + "lllyasviel/sd-controlnet-canny", + from_pt=True, + dtype=jnp.bfloat16 + ) + pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + controlnet=controlnet, from_pt=True, + dtype = jnp.bfloat16 + ) + params['controlnet'] = controlnet_params + + prompts = "bird" + num_samples = jax.device_count() + prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + + canny_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + rng = jax.random.PRNGKey(0) + rng = jax.random.split(rng, jax.device_count()) + + p_params = replicate(params) + prompt_ids = shard(prompt_ids) + processed_image = shard(processed_image) + + images = pipe( + prompt_ids=prompt_ids, + image=processed_image, + params=p_params, + prng_seed=rng, + num_inference_steps=50, + jit=True, + ).images + assert images.shape == (jax.device_count(), 1, 768, 512, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array([0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 + + def test_pose(self): + controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + "lllyasviel/sd-controlnet-openpose", + from_pt=True, + dtype=jnp.bfloat16 + ) + 
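# NOTE: `from_pretrained` returns the Flax module and its parameters separately; + # the ControlNet parameters are merged into the pipeline `params` dict below. +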
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + controlnet=controlnet, from_pt=True, + dtype = jnp.bfloat16 + ) + params['controlnet'] = controlnet_params + + prompts = "Chef in the kitchen" + num_samples = jax.device_count() + prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + + pose_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png" + ) + processed_image = pipe.prepare_image_inputs([pose_image] * num_samples) + + rng = jax.random.PRNGKey(0) + rng = jax.random.split(rng, jax.device_count()) + + p_params = replicate(params) + prompt_ids = shard(prompt_ids) + processed_image = shard(processed_image) + + images = pipe( + prompt_ids=prompt_ids, + image=processed_image, + params=p_params, + prng_seed=rng, + num_inference_steps=50, + jit=True, + ).images + assert images.shape == (jax.device_count(), 1, 768, 512, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array([[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 + + From 5d42d23f1897551cc7875a4f87ef9a18142f814c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 22 Mar 2023 02:15:01 +0000 Subject: [PATCH 11/13] make style --- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/controlnet_flax.py | 70 ++++++++++--------- .../models/unet_2d_condition_flax.py | 8 +-- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/pipeline_flax_utils.py | 2 +- .../pipelines/stable_diffusion/__init__.py | 2 +- ...peline_flax_stable_diffusion_controlnet.py | 13 ++-- .../test_stable_diffusion_flax_controlnet.py | 40 +++++------ 9 files changed, 69 insertions(+), 74 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1b83dbb2f607..603d585065e1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -176,10 +176,10 @@ except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: + from .models.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL - from .models.controlnet_flax import FlaxControlNetModel from .pipelines import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, @@ -200,8 +200,8 @@ from .utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .pipelines import ( + FlaxStableDiffusionControlNetPipeline, FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, - FlaxStableDiffusionControlNetPipeline, ) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c71b26a17df5..bb56701e716c 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -28,6 +28,6 @@ from .vq_model import VQModel if is_flax_available(): - from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .controlnet_flax import FlaxControlNetModel + from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL diff --git 
a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index 54ad78c1b4d0..3adefa84ea68 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -47,27 +47,27 @@ def setup(self): kernel_size=(3, 3), padding=((1, 1), (1, 1)), dtype=self.dtype, - ) - + ) + blocks = [] - for i in range(len(self.block_out_channels) -1): + for i in range(len(self.block_out_channels) - 1): channel_in = self.block_out_channels[i] channel_out = self.block_out_channels[i + 1] conv1 = nn.Conv( - channel_in, - kernel_size=(3, 3), - padding=((1, 1), (1, 1)), + channel_in, + kernel_size=(3, 3), + padding=((1, 1), (1, 1)), dtype=self.dtype, - ) + ) blocks.append(conv1) conv2 = nn.Conv( - channel_out, - kernel_size=(3, 3), - strides=(2, 2), - padding=((1, 1), (1, 1)), - dtype = self.dtype, + channel_out, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), + dtype=self.dtype, ) - blocks.append(conv2) + blocks.append(conv2) self.blocks = blocks self.conv_out = nn.Conv( @@ -77,7 +77,7 @@ def setup(self): kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, - ) + ) def __call__(self, conditioning): embedding = self.conv_in(conditioning) @@ -139,7 +139,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): The channel order of conditional image. Will convert it to `rgb` if it's `bgr` conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in conditioning_embedding layer + The tuple of output channel for each block in conditioning_embedding layer """ @@ -196,7 +196,7 @@ def setup(self): block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift ) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) - + self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=self.conditioning_embedding_out_channels, @@ -217,13 +217,13 @@ def setup(self): output_channel = block_out_channels[0] controlnet_block = nn.Conv( - output_channel, - kernel_size=(1,1), + output_channel, + kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, - ) + ) controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(self.down_block_types): @@ -257,24 +257,24 @@ def setup(self): for _ in range(self.layers_per_block): controlnet_block = nn.Conv( - output_channel, - kernel_size=(1,1), + output_channel, + kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, - ) + ) controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = nn.Conv( - output_channel, - kernel_size=(1,1), + output_channel, + kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, - ) + ) controlnet_down_blocks.append(controlnet_block) self.down_blocks = down_blocks @@ -291,13 +291,13 @@ def setup(self): ) self.controlnet_mid_block = nn.Conv( - mid_block_channel, - kernel_size=(1,1), + mid_block_channel, + kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, - ) + ) def 
__call__( self, @@ -306,7 +306,7 @@ def __call__( encoder_hidden_states, controlnet_cond, conditioning_scale: float = 1.0, - return_dict: bool = True, + return_dict: bool = True, train: bool = False, ) -> Union[FlaxControlNetOutput, Tuple]: r""" @@ -328,9 +328,9 @@ def __call__( When returning a tuple, the first element is the sample tensor. """ channel_order = self.controlnet_conditioning_channel_order - if channel_order == 'bgr': + if channel_order == "bgr": controlnet_cond = jnp.flip(controlnet_cond, axis=1) - + # 1. time if not isinstance(timesteps, jnp.ndarray): timesteps = jnp.array([timesteps], dtype=jnp.int32) @@ -344,7 +344,7 @@ def __call__( # 2. pre-process sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) - + controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1)) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample += controlnet_cond @@ -368,7 +368,7 @@ def __call__( controlnet_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples - + mid_block_res_sample = self.controlnet_mid_block(sample) # 6. scaling @@ -378,4 +378,6 @@ def __call__( if not return_dict: return (down_block_res_samples, mid_block_res_sample) - return FlaxControlNetOutput(down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample) + return FlaxControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 85d9bad173a3..812ca079db38 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -249,9 +249,9 @@ def __call__( sample, timesteps, encoder_hidden_states, - down_block_additional_residuals = None, - mid_block_additional_residual = None, - return_dict: bool = True, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + return_dict: bool = True, train: bool = False, ) -> Union[FlaxUNet2DConditionOutput, Tuple]: r""" @@ -292,7 +292,7 @@ def __call__( else: sample, res_samples = down_block(sample, t_emb, deterministic=not train) down_block_res_samples += res_samples - + if down_block_additional_residuals is not None: new_down_block_res_samples = () diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index dbb75ef759de..f83412aa799f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -121,8 +121,8 @@ from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .stable_diffusion import ( + FlaxStableDiffusionControlNetPipeline, FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, - FlaxStableDiffusionControlNetPipeline, ) diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index 9081a929f3d8..d3fc415ab4d7 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -469,7 +469,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loaded_sub_model = load_method(loadable_folder) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - + # 4. 
Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 8addee84cef9..b386ab04c167 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -127,7 +127,7 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline - from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index c4f4f4d3a704..47cdb255bcb0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -25,8 +25,7 @@ from PIL import Image from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel -from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel, FlaxControlNetModel -from ...models.controlnet_flax import FlaxControlNetOutput +from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, @@ -313,8 +312,8 @@ def loop_body(step, args): jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, - down_block_additional_residuals = down_block_res_samples, - mid_block_additional_residual = mid_block_res_sample + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, ).sample # perform guidance @@ -386,7 +385,7 @@ def __call__( tensor will ge generated by sampling using the supplied random `generator`. controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. + to the residual in the original unet. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. @@ -413,7 +412,7 @@ def __call__( if len(prompt_ids.shape) > 2: # Assume sharded guidance_scale = guidance_scale[:, None] - + if isinstance(controlnet_conditioning_scale, float): # Convert to a tensor so each device gets a copy. Follow the prompt_ids for # shape information, as they may be sharded (when `jit` is `True`), or not. 
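A minimal sketch of the scalar-to-per-device broadcast this comment describes, assuming 8 devices and the usual `(batch, 77)` token shape (the names and shapes here are illustrative, not part of the patch):

    import jax.numpy as jnp

    num_devices = 8  # stands in for jax.device_count()
    prompt_ids = jnp.zeros((num_devices, 1, 77), dtype=jnp.int32)  # sharded ids have rank 3

    scale = 1.0  # python float passed by the caller
    scale = jnp.array([scale] * prompt_ids.shape[0])  # one copy per device
    if len(prompt_ids.shape) > 2:  # sharded, i.e. jit=True
        scale = scale[:, None]  # shape (8, 1); pmap maps over the leading axis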
@@ -421,7 +420,7 @@ def __call__( if len(prompt_ids.shape) > 2: # Assume sharded controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] - + if jit: images = _p_generate( self, diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py index 227b9cdb2431..268c01320177 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py @@ -17,7 +17,7 @@ import unittest from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline -from diffusers.utils import is_flax_available, slow, load_image +from diffusers.utils import is_flax_available, load_image, slow from diffusers.utils.testing_utils import require_flax @@ -38,17 +38,13 @@ def tearDown(self): def test_canny(self): controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - "lllyasviel/sd-controlnet-canny", - from_pt=True, - dtype=jnp.bfloat16 + "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16 ) pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - controlnet=controlnet, from_pt=True, - dtype = jnp.bfloat16 + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 ) - params['controlnet'] = controlnet_params - + params["controlnet"] = controlnet_params + prompts = "bird" num_samples = jax.device_count() prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) @@ -72,30 +68,28 @@ def test_canny(self): prng_seed=rng, num_inference_steps=50, jit=True, - ).images + ).images assert images.shape == (jax.device_count(), 1, 768, 512, 3) images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) image_slice = images[0, 253:256, 253:256, -1] output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) - expected_slice = jnp.array([0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]) + expected_slice = jnp.array( + [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078] + ) print(f"output_slice: {output_slice}") assert jnp.abs(output_slice - expected_slice).max() < 1e-2 def test_pose(self): controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - "lllyasviel/sd-controlnet-openpose", - from_pt=True, - dtype=jnp.bfloat16 + "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16 ) pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - controlnet=controlnet, from_pt=True, - dtype = jnp.bfloat16 + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16 ) - params['controlnet'] = controlnet_params - + params["controlnet"] = controlnet_params + prompts = "Chef in the kitchen" num_samples = jax.device_count() prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) @@ -119,15 +113,15 @@ def test_pose(self): prng_seed=rng, num_inference_steps=50, jit=True, - ).images + ).images assert images.shape == (jax.device_count(), 1, 768, 512, 3) images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) image_slice = images[0, 253:256, 253:256, -1] output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) - expected_slice = jnp.array([[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]) + expected_slice = jnp.array( + 
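# reference slice of the generated image; the assert below allows a 1e-2 tolerance +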
[[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]] + ) print(f"output_slice: {output_slice}") assert jnp.abs(output_slice - expected_slice).max() < 1e-2 - - From 408d36835f82d5cb86d1fbcefbe5694acc61fb24 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 22 Mar 2023 02:33:59 +0000 Subject: [PATCH 12/13] fix-copies --- .../utils/dummy_flax_and_transformers_objects.py | 15 +++++++++++++++ src/diffusers/utils/dummy_flax_objects.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py index 5db4c7d58d1e..162bac1c4331 100644 --- a/src/diffusers/utils/dummy_flax_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax", "transformers"]) + + class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 7772c1a06b49..2bb80d136f33 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class FlaxControlNetModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxModelMixin(metaclass=DummyObject): _backends = ["flax"] From 949c989d493513acec98300f8aae175574aed339 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 22 Mar 2023 02:39:07 +0000 Subject: [PATCH 13/13] style --- ...peline_flax_stable_diffusion_controlnet.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index 47cdb255bcb0..4dc450cebc84 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -55,28 +55,37 @@ >>> from PIL import Image >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + >>> def image_grid(imgs, rows, cols): - ... w,h = imgs[0].size - ... grid = Image.new('RGB', size=(cols*w, rows*h)) - ... for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h)) + ... w, h = imgs[0].size + ... grid = Image.new("RGB", size=(cols * w, rows * h)) + ... for i, img in enumerate(imgs): + ... grid.paste(img, box=(i % cols * w, i // cols * h)) ... return grid + >>> def create_key(seed=0): ... 
return jax.random.PRNGKey(seed) + >>> rng = create_key(0) >>> # get canny image - >>> canny_image = load_image("https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg") + >>> canny_image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" + ... ) >>> prompts = "best quality, extremely detailed" >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" >>> # load control net and stable diffusion v1-5 - >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32) + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 + ... ) >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype = jnp.float32) - >>> params['controlnet'] = controlnet_params + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32 + ... ) + >>> params["controlnet"] = controlnet_params >>> num_samples = jax.device_count() >>> rng = jax.random.split(rng, jax.device_count()) @@ -101,7 +110,7 @@ ... ).images >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) - >>> output_images = image_grid(output_images, num_samples//4, 4) + >>> output_images = image_grid(output_images, num_samples // 4, 4) >>> output_images.save("generated_image.png") ``` """ @@ -367,7 +376,8 @@ def __call__( prompt_ids (`jnp.array`): The prompt or prompts to guide the image generation. image (`jnp.array`): - Array representing the ControlNet input condition. ControlNet uses this input condition to generate guidance to the UNet. + Array representing the ControlNet input condition. ControlNet uses this input condition to generate + guidance to the UNet. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key num_inference_steps (`int`, *optional*, defaults to 50):