From 67e245c2b5bd85b175219fba17b5ba59d9d8a801 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 12 Sep 2022 18:23:21 +0200 Subject: [PATCH 1/3] First UNet Flax modeling blocks. Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization. --- src/diffusers/models/attention_flax.py | 181 ++++++++++++ src/diffusers/models/embeddings_flax.py | 56 ++++ src/diffusers/models/resnet_flax.py | 111 ++++++++ .../models/unet_2d_condition_flax.py | 257 +++++++++++++++++ src/diffusers/models/unet_blocks_flax.py | 263 ++++++++++++++++++ 5 files changed, 868 insertions(+) create mode 100644 src/diffusers/models/attention_flax.py create mode 100644 src/diffusers/models/embeddings_flax.py create mode 100644 src/diffusers/models/resnet_flax.py create mode 100644 src/diffusers/models/unet_2d_condition_flax.py create mode 100644 src/diffusers/models/unet_blocks_flax.py diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py new file mode 100644 index 000000000000..77e5ad9c75fb --- /dev/null +++ b/src/diffusers/models/attention_flax.py @@ -0,0 +1,181 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
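+#
+# Usage sketch (illustrative only; the public API is still open per the commit
+# message, and this assumes the standard Flax `init` / `apply` pattern):
+#
+#     import jax
+#     import jax.numpy as jnp
+#     from diffusers.models.attention_flax import FlaxAttentionBlock
+#
+#     attn = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40)
+#     hidden_states = jnp.zeros((1, 64 * 64, 320))         # (batch, sequence, query_dim)
+#     params = attn.init(jax.random.PRNGKey(0), hidden_states)["params"]
+#     out = attn.apply({"params": params}, hidden_states)  # (1, 4096, 320)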
+ +import flax.linen as nn +import jax.numpy as jnp + + +class FlaxAttentionBlock(nn.Module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim_head * self.heads + self.scale = self.dim_head**-0.5 + + self.to_q = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) + self.to_k = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) + self.to_v = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) + + self.to_out = nn.Dense(self.query_dim, dtype=self.dtype) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def __call__(self, hidden_states, context=None, deterministic=True): + context = hidden_states if context is None else context + + q = self.to_q(hidden_states) + k = self.to_k(context) + v = self.to_v(context) + + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + # compute attentions + attn_weights = jnp.einsum("b i d, b j d->b i j", q, k) + attn_weights = attn_weights * self.scale + attn_weights = nn.softmax(attn_weights, axis=2) + + ## attend to values + hidden_states = jnp.einsum("b i j, b j d -> b i d", attn_weights, v) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.to_out(hidden_states) + return hidden_states + + +class FlaxBasicTransformerBlock(nn.Module): + dim: int + n_heads: int + d_head: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # self attention + self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + # cross attention + self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) + self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + + def __call__(self, hidden_states, context, deterministic=True): + # self attention + residual = hidden_states + hidden_states = self.self_attn(self.norm1(hidden_states)) + hidden_states = hidden_states + residual + + # cross attention + residual = hidden_states + hidden_states = self.cross_attn(self.norm2(hidden_states), context) + hidden_states = hidden_states + residual + + # feed forward + residual = hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states = hidden_states + residual + + return hidden_states + + +class FlaxSpatialTransformer(nn.Module): + in_channels: int + n_heads: int + d_head: int + depth: int = 1 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) + + inner_dim = self.n_heads * self.d_head + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 
1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=self.dtype,
+        )
+
+        self.transformer_blocks = [
+            FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
+            for _ in range(self.depth)
+        ]
+
+        self.proj_out = nn.Conv(
+            inner_dim,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="VALID",
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states, context, deterministic=True):
+        batch, height, width, channels = hidden_states.shape
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+
+        # inputs are already NHWC, so no transpose is needed before flattening the spatial dims
+        hidden_states = hidden_states.reshape(batch, height * width, channels)
+
+        for transformer_block in self.transformer_blocks:
+            hidden_states = transformer_block(hidden_states, context)
+
+        hidden_states = hidden_states.reshape(batch, height, width, channels)
+
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class FlaxGluFeedForward(nn.Module):
+    dim: int
+    dropout: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        inner_dim = self.dim * 4
+        self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True):
+        hidden_states = self.dense1(hidden_states)
+        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
+        hidden_states = hidden_linear * nn.gelu(hidden_gelu)
+        hidden_states = self.dense2(hidden_states)
+        return hidden_states
diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py
new file mode 100644
index 000000000000..63442ab997b4
--- /dev/null
+++ b/src/diffusers/models/embeddings_flax.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+
+# This is like models.embeddings.get_timestep_embedding (PyTorch) but
+# less general (only handles the case we currently need).
+def get_sinusoidal_embeddings(timesteps, embedding_dim):
+    """
+    Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
+
+    :param timesteps: a 1-D tensor of N indices, one per batch element. These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :return: an [N x dim] tensor of positional embeddings.
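+
+    Example (illustrative; shapes follow from the implementation below):
+
+        emb = get_sinusoidal_embeddings(jnp.array([0.0, 10.0, 999.0]), embedding_dim=320)
+        emb.shape  # (3, 320): cosine terms in the first half, sine terms in the second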
+ """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = jnp.exp(jnp.arange(half_dim) * -emb) + emb = timesteps[:, None] * emb[None, :] + emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) + return emb + + +class FlaxTimestepEmbedding(nn.Module): + time_embed_dim: int = 32 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, temb): + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) + temb = nn.silu(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) + return temb + + +class FlaxTimesteps(nn.Module): + dim: int = 32 + + @nn.compact + def __call__(self, timesteps): + return get_sinusoidal_embeddings(timesteps, self.dim) diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py new file mode 100644 index 000000000000..46ccee35adcc --- /dev/null +++ b/src/diffusers/models/resnet_flax.py @@ -0,0 +1,111 @@ +import flax.linen as nn +import jax +import jax.numpy as jnp + + +class FlaxUpsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), # padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + # hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + in_channels: int + out_channels: int = None + dropout_prob: float = 0.0 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) + + self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.dropout = nn.Dropout(self.dropout_prob) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, temb, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(nn.swish(temb)) + temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + 
hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py new file mode 100644 index 000000000000..85c1e965a951 --- /dev/null +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -0,0 +1,257 @@ +from typing import Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import FlaxModelMixin +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .unet_blocks_flax import ( + FlaxDownBlock2D, + FlaxCrossAttnDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, + FlaxUpBlock2D, + FlaxCrossAttnUpBlock2D, +) + + +# Configuration - we may not need this any more +class FlaxUNet2DConfig(ConfigMixin): + def __init__( + self, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(224, 448, 672, 896), + layers_per_block=2, + attention_head_dim=8, + cross_attention_dim=768, + dropout=0.1, + **kwargs, + ): + super().__init__(**kwargs) + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.down_block_types = down_block_types + self.up_block_types = up_block_types + self.block_out_channels = block_out_channels + self.layers_per_block = layers_per_block + self.attention_head_dim = attention_head_dim + self.cross_attention_dim = cross_attention_dim + self.dropout = dropout + + +# This is TBD. 
We may not need the module + the class +class FlaxUNet2DModule(nn.Module): + config: FlaxUNet2DConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + config = self.config + + self.sample_size = config.sample_size + block_out_channels = config.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps(block_out_channels[0]) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + # down + down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(config.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=config.dropout, + num_layers=config.layers_per_block, + attn_num_head_channels=config.attention_head_dim, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=config.dropout, + num_layers=config.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # mid + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + dropout=config.dropout, + attn_num_head_channels=config.attention_head_dim, + dtype=self.dtype, + ) + + # up + up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(config.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "CrossAttnUpBlock2D": + up_block = FlaxCrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=config.layers_per_block + 1, + attn_num_head_channels=config.attention_head_dim, + add_upsample=not is_final_block, + dropout=config.dropout, + dtype=self.dtype, + ) + else: + up_block = FlaxUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=config.layers_per_block + 1, + add_upsample=not is_final_block, + dropout=config.dropout, + dtype=self.dtype, + ) + + up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = up_blocks + + # out + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv_out = nn.Conv( + config.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True): + # 1. time + # broadcast to batch dimension + # timesteps = jnp.broadcast_to(timesteps, (sample.shape[0],) + timesteps.shape) + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. 
down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states) + else: + sample, res_samples = down_block(sample, t_emb) + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states) + + # 5. up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-(self.config.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.config.layers_per_block + 1)] + if isinstance(up_block, FlaxCrossAttnUpBlock2D): + sample = up_block( + sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + ) + else: + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = nn.silu(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxUNet2DConditionModel(nn.Module, ConfigMixin, FlaxModelMixin): + module_class = FlaxUNet2DModule + config_class = FlaxUNet2DConfig + base_model_prefix = "model" + module_class: nn.Module = None + + def __init__( + self, + config: FlaxUNet2DConfig, + input_shape: Tuple = (1, 32, 32, 4), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + ): + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + jnp.array(sample), + jnp.array(timesteps, dtype=jnp.int32), + encoder_hidden_states, + not train, + rngs=rngs, + ) + + +# class UNet2D(UNet2DPretrainedModel): +# module_class = UNet2DModule diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py new file mode 100644 index 000000000000..5de83bb2a559 --- /dev/null +++ b/src/diffusers/models/unet_blocks_flax.py @@ -0,0 +1,263 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import flax.linen as nn +import jax.numpy as jnp + +from .attention_flax import FlaxAttentionBlock, FlaxSpatialTransformer +from .resnet_flax import FlaxDownsample2D, FlaxUpsample2D, FlaxResnetBlock2D + +class FlaxCrossAttnDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + n_heads=self.attn_num_head_channels, + d_head=self.out_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, deterministic=True): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxCrossAttnUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + n_heads=self.attn_num_head_channels, + d_head=self.out_channels // 
self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2DCrossAttn(nn.Module): + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxSpatialTransformer( + in_channels=self.in_channels, + n_heads=self.attn_num_head_channels, + d_head=self.in_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states From c3fdbf95320f893701f05cd5ec2ec2f906f36bca Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 12 Sep 2022 18:39:30 +0200 Subject: [PATCH 2/3] Remove FlaxUNet2DConfig class. 
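
The hyperparameters previously held by FlaxUNet2DConfig now live directly on the
module and on the model constructor (via register_to_config). A rough usage sketch,
assuming the config attributes end up as type-annotated dataclass fields on
FlaxUNet2DModule (which is what flax.linen modules expect):

    import jax
    import jax.numpy as jnp
    from diffusers.models.unet_2d_condition_flax import FlaxUNet2DModule

    unet = FlaxUNet2DModule(block_out_channels=(224, 448, 672, 896), cross_attention_dim=768)
    sample = jnp.zeros((1, 32, 32, 4))                 # NHWC latents
    timesteps = jnp.ones((1,), dtype=jnp.int32)
    context = jnp.zeros((1, 77, 768))                  # text encoder hidden states
    params = unet.init(jax.random.PRNGKey(0), sample, timesteps, context)["params"]
    out = unet.apply({"params": params}, sample, timesteps, context)  # (1, 32, 32, 4)
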
--- .../models/unet_2d_condition_flax.py | 127 ++++++++++-------- 1 file changed, 68 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 85c1e965a951..d3b54cf9d743 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -17,45 +17,25 @@ ) -# Configuration - we may not need this any more -class FlaxUNet2DConfig(ConfigMixin): - def __init__( - self, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), - up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels=(224, 448, 672, 896), - layers_per_block=2, - attention_head_dim=8, - cross_attention_dim=768, - dropout=0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.sample_size = sample_size - self.in_channels = in_channels - self.out_channels = out_channels - self.down_block_types = down_block_types - self.up_block_types = up_block_types - self.block_out_channels = block_out_channels - self.layers_per_block = layers_per_block - self.attention_head_dim = attention_head_dim - self.cross_attention_dim = cross_attention_dim - self.dropout = dropout - - # This is TBD. We may not need the module + the class class FlaxUNet2DModule(nn.Module): - config: FlaxUNet2DConfig + # config args + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(224, 448, 672, 896), + layers_per_block=2, + attention_head_dim=8, + cross_attention_dim=768, + dropout=0.1, + + # model args dtype: jnp.dtype = jnp.float32 def setup(self): - config = self.config - - self.sample_size = config.sample_size - block_out_channels = config.block_out_channels + block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 # input @@ -74,7 +54,7 @@ def setup(self): # down down_blocks = [] output_channel = block_out_channels[0] - for i, down_block_type in enumerate(config.down_block_types): + for i, down_block_type in enumerate(self.down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 @@ -83,9 +63,9 @@ def setup(self): down_block = FlaxCrossAttnDownBlock2D( in_channels=input_channel, out_channels=output_channel, - dropout=config.dropout, - num_layers=config.layers_per_block, - attn_num_head_channels=config.attention_head_dim, + dropout=self.dropout, + num_layers=self.layers_per_block, + attn_num_head_channels=self.attention_head_dim, add_downsample=not is_final_block, dtype=self.dtype, ) @@ -93,8 +73,8 @@ def setup(self): down_block = FlaxDownBlock2D( in_channels=input_channel, out_channels=output_channel, - dropout=config.dropout, - num_layers=config.layers_per_block, + dropout=self.dropout, + num_layers=self.layers_per_block, add_downsample=not is_final_block, dtype=self.dtype, ) @@ -105,8 +85,8 @@ def setup(self): # mid self.mid_block = FlaxUNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], - dropout=config.dropout, - attn_num_head_channels=config.attention_head_dim, + dropout=self.dropout, + attn_num_head_channels=self.attention_head_dim, dtype=self.dtype, ) @@ -114,7 +94,7 @@ def setup(self): up_blocks = [] 
reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(config.up_block_types): + for i, up_block_type in enumerate(self.up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] @@ -126,10 +106,10 @@ def setup(self): in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - num_layers=config.layers_per_block + 1, - attn_num_head_channels=config.attention_head_dim, + num_layers=self.layers_per_block + 1, + attn_num_head_channels=self.attention_head_dim, add_upsample=not is_final_block, - dropout=config.dropout, + dropout=self.dropout, dtype=self.dtype, ) else: @@ -137,9 +117,9 @@ def setup(self): in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - num_layers=config.layers_per_block + 1, + num_layers=self.layers_per_block + 1, add_upsample=not is_final_block, - dropout=config.dropout, + dropout=self.dropout, dtype=self.dtype, ) @@ -150,7 +130,7 @@ def setup(self): # out self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) self.conv_out = nn.Conv( - config.out_channels, + self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), @@ -181,8 +161,8 @@ def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True) # 5. up for up_block in self.up_blocks: - res_samples = down_block_res_samples[-(self.config.layers_per_block + 1) :] - down_block_res_samples = down_block_res_samples[: -(self.config.layers_per_block + 1)] + res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] if isinstance(up_block, FlaxCrossAttnUpBlock2D): sample = up_block( sample, @@ -202,29 +182,58 @@ def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True) class FlaxUNet2DConditionModel(nn.Module, ConfigMixin, FlaxModelMixin): - module_class = FlaxUNet2DModule - config_class = FlaxUNet2DConfig base_model_prefix = "model" - module_class: nn.Module = None + module_class = FlaxUNet2DModule + @register_to_config def __init__( self, - config: FlaxUNet2DConfig, + # config args + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(224, 448, 672, 896), + layers_per_block=2, + attention_head_dim=8, + cross_attention_dim=768, + dropout=0.1, + + # model args - to be ignored for config input_shape: Tuple = (1, 32, 32, 4), seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, **kwargs, ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + module = self.module_class( + sample_size=sample_size, + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + dropout=dropout, + dtype=dtype, **kwargs) + super().__init__( + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + 
_do_init=_do_init + ) + # Note: input_shape is ignored def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: # init input tensors - sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels) + sample_shape = (1, self.module.sample_size, self.module.sample_size, self.module.in_channels) sample = jnp.zeros(sample_shape, dtype=jnp.float32) timesteps = jnp.ones((1,), dtype=jnp.int32) - encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32) + encoder_hidden_states = jnp.zeros((1, 1, self.module.cross_attention_dim), dtype=jnp.float32) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} From 1067e3415527ea442b871708cf085170cc15dd0e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 12 Sep 2022 18:45:33 +0200 Subject: [PATCH 3/3] ignore_for_config non-config args. --- src/diffusers/models/unet_2d_condition_flax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index d3b54cf9d743..249390a16294 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -184,6 +184,7 @@ def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True) class FlaxUNet2DConditionModel(nn.Module, ConfigMixin, FlaxModelMixin): base_model_prefix = "model" module_class = FlaxUNet2DModule + ignore_for_config = ["input_shape", "seed", "dtype", "_do_init"] @register_to_config def __init__( @@ -200,7 +201,7 @@ def __init__( cross_attention_dim=768, dropout=0.1, - # model args - to be ignored for config + # model args input_shape: Tuple = (1, 32, 32, 4), seed: int = 0, dtype: jnp.dtype = jnp.float32,
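
With this, `register_to_config` records only the model hyperparameters, while the
runtime arguments listed in `ignore_for_config` (input_shape, seed, dtype, _do_init)
stay out of the saved configuration. A rough sketch of the intended effect, assuming
the FlaxModelMixin plumbing is in place and `model.config` behaves like a dict:

    model = FlaxUNet2DConditionModel(sample_size=32, seed=0)
    assert model.config["sample_size"] == 32
    assert "seed" not in model.config  # runtime-only args are ignored for the config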