181 changes: 181 additions & 0 deletions src/diffusers/models/attention_flax.py
@@ -0,0 +1,181 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp


class FlaxAttentionBlock(nn.Module):
    # Multi-head dot-product attention; attends to `context` when provided, otherwise performs self-attention.
query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32

def setup(self):
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5

self.to_q = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)
self.to_k = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)
self.to_v = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)

self.to_out = nn.Dense(self.query_dim, dtype=self.dtype)

def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor

def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def __call__(self, hidden_states, context=None, deterministic=True):
context = hidden_states if context is None else context

q = self.to_q(hidden_states)
k = self.to_k(context)
v = self.to_v(context)

q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)

        # compute attention scores
attn_weights = jnp.einsum("b i d, b j d->b i j", q, k)
attn_weights = attn_weights * self.scale
attn_weights = nn.softmax(attn_weights, axis=2)

        # attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attn_weights, v)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.to_out(hidden_states)
return hidden_states


class FlaxBasicTransformerBlock(nn.Module):
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32

def setup(self):
# self attention
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.self_attn(self.norm1(hidden_states))
hidden_states = hidden_states + residual

# cross attention
residual = hidden_states
hidden_states = self.cross_attn(self.norm2(hidden_states), context)
hidden_states = hidden_states + residual

# feed forward
residual = hidden_states
hidden_states = self.ff(self.norm3(hidden_states))
hidden_states = hidden_states + residual

return hidden_states


class FlaxSpatialTransformer(nn.Module):
    # Runs `depth` transformer blocks over the flattened spatial positions of an NHWC feature map.
in_channels: int
n_heads: int
d_head: int
depth: int = 1
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32

def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

inner_dim = self.n_heads * self.d_head
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

self.transformer_blocks = [
            FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
for _ in range(self.depth)
]

self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)

        # flatten the spatial dimensions into a single sequence dimension
        hidden_states = hidden_states.reshape(batch, height * width, channels)

for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context)

        # restore the spatial layout before the output projection
        hidden_states = hidden_states.reshape(batch, height, width, channels)

hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual

return hidden_states


class FlaxGluFeedForward(nn.Module):
    # Feed-forward block with a gated GELU (GEGLU) activation.
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32

def setup(self):
inner_dim = self.dim * 4
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)

def __call__(self, hidden_states, deterministic=True):
hidden_states = self.dense1(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
hidden_states = self.dense2(hidden_states)
return hidden_states
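
A minimal usage sketch for FlaxAttentionBlock as defined above; the shapes and PRNG seed are illustrative assumptions, not taken from the PR:

import jax
import jax.numpy as jnp

# Self-attention over a (batch, seq_len, query_dim) tensor; `context` defaults to the input itself.
attn = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40)
hidden_states = jnp.ones((1, 64, 320))
params = attn.init(jax.random.PRNGKey(0), hidden_states)
out = attn.apply(params, hidden_states)  # shape (1, 64, 320), same as the input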
56 changes: 56 additions & 0 deletions src/diffusers/models/embeddings_flax.py
@@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp


# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim):
    """
    Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.

    :param timesteps: a 1-D array of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :return: an [N x embedding_dim] array of positional embeddings. The maximum period is fixed at 10000.
    """
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb


class FlaxTimestepEmbedding(nn.Module):
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32

@nn.compact
def __call__(self, temb):
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
temb = nn.silu(temb)
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
return temb


class FlaxTimesteps(nn.Module):
dim: int = 32

@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim)
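
A short sketch of how the two embedding modules above compose; the dimensions (320 and 1280) are illustrative assumptions:

import jax
import jax.numpy as jnp

timesteps = jnp.array([0.0, 250.0, 999.0])

# Fixed sinusoidal features; FlaxTimesteps has no learnable parameters.
t_proj = FlaxTimesteps(dim=320)
temb = t_proj.apply(t_proj.init(jax.random.PRNGKey(0), timesteps), timesteps)  # (3, 320)

# Learned two-layer MLP that lifts the features into the UNet's time-embedding space.
t_embed = FlaxTimestepEmbedding(time_embed_dim=1280)
params = t_embed.init(jax.random.PRNGKey(1), temb)
temb = t_embed.apply(params, temb)  # (3, 1280)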
111 changes: 111 additions & 0 deletions src/diffusers/models/resnet_flax.py
@@ -0,0 +1,111 @@
import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32

def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)

def __call__(self, hidden_states):
batch, height, width, channels = hidden_states.shape
hidden_states = jax.image.resize(
hidden_states,
shape=(batch, height * 2, width * 2, channels),
method="nearest",
)
hidden_states = self.conv(hidden_states)
return hidden_states


class FlaxDownsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32

def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1)), # padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states):
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
# hidden_states = jnp.pad(hidden_states, pad_width=pad)
hidden_states = self.conv(hidden_states)
return hidden_states


class FlaxResnetBlock2D(nn.Module):
in_channels: int
    out_channels: int = None  # defaults to `in_channels` when not set
dropout_prob: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32

def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels

self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)

self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.dropout = nn.Dropout(self.dropout_prob)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)

use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

self.conv_shortcut = None
if use_nin_shortcut:
self.conv_shortcut = nn.Conv(
out_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states, temb, deterministic=True):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.conv1(hidden_states)

temb = self.time_emb_proj(nn.swish(temb))
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
hidden_states = hidden_states + temb

hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)

if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)

return hidden_states + residual
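
A minimal forward-pass sketch for FlaxResnetBlock2D above; the NHWC shapes and the 512-dimensional time embedding are illustrative assumptions:

import jax
import jax.numpy as jnp

block = FlaxResnetBlock2D(in_channels=64, out_channels=128)
hidden_states = jnp.ones((2, 32, 32, 64))  # NHWC feature map
temb = jnp.ones((2, 512))                  # time embedding, projected to out_channels inside the block
params = block.init(jax.random.PRNGKey(0), hidden_states, temb)
out = block.apply(params, hidden_states, temb)  # (2, 32, 32, 128); the 1x1 shortcut matches the channels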