From f3ba28b668737f1ec3c081bdc1fabc4d7a7858e6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 29 Oct 2022 16:26:55 +0200 Subject: [PATCH 1/4] initial get_sinusoidal_embeddings --- src/diffusers/models/embeddings_flax.py | 51 +++++++++++++++++-------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index e2d607499c79..df73bdca0d58 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -17,23 +17,42 @@ import jax.numpy as jnp -# This is like models.embeddings.get_timestep_embedding (PyTorch) but -# less general (only handles the case we currently need). -def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1): +def get_sinusoidal_embeddings( + timesteps: jnp.ndarray, + embedding_dim: int, + freq_shift: float = 1, + min_timescale: float = 1.0, + max_timescale: float = 1.0e4, + flip_sin_to_cos: bool = False, + scale: float = 1.0, +) -> jnp.ndarray: + """Returns the positional encoding (same as Tensor2Tensor). + Args: + timesteps: An array of shape [batch_size]. + embedding_dim: The number of output channels. + min_timescale: The smallest time unit (should probably be 0.0). + max_timescale: The largest time unit. + Returns: + a Tensor of timing signals [1, length, num_channels] """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + assert timesteps.ndim == 1 + assert embedding_dim % 2 == 0 + num_timescales = float(embedding_dim // 2) + log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) + inv_timescales = min_timescale * jnp.exp( + jnp.arange(num_timescales - freq_shift, dtype=jnp.float32) * -log_timescale_increment + ) + scaled_time = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) - :param timesteps: a 1-D tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] tensor of positional embeddings. - """ - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - freq_shift) - emb = jnp.exp(jnp.arange(half_dim) * -emb) - emb = timesteps[:, None] * emb[None, :] - emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) - return emb + # scale embeddings + scaled_time = scale * scaled_time + + if flip_sin_to_cos: + signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) + else: + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) + signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) + return signal class FlaxTimestepEmbedding(nn.Module): @@ -70,4 +89,4 @@ class FlaxTimesteps(nn.Module): @nn.compact def __call__(self, timesteps): - return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift) + return get_sinusoidal_embeddings(timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift) From 3500cf085312d671ca712b43d88f570f6fc4dab8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 31 Oct 2022 12:44:58 +0300 Subject: [PATCH 2/4] added asserts --- src/diffusers/models/embeddings_flax.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index df73bdca0d58..69bbe2beeff7 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -21,7 +21,7 @@ def get_sinusoidal_embeddings( timesteps: jnp.ndarray, embedding_dim: int, freq_shift: float = 1, - min_timescale: float = 1.0, + min_timescale: float = 1, max_timescale: float = 1.0e4, flip_sin_to_cos: bool = False, scale: float = 1.0, @@ -35,13 +35,11 @@ def get_sinusoidal_embeddings( Returns: a Tensor of timing signals [1, length, num_channels] """ - assert timesteps.ndim == 1 - assert embedding_dim % 2 == 0 + assert timesteps.ndim == 1, "Timesteps should be a 1d-array" + assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" num_timescales = float(embedding_dim // 2) log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) - inv_timescales = min_timescale * jnp.exp( - jnp.arange(num_timescales - freq_shift, dtype=jnp.float32) * -log_timescale_increment - ) + inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) scaled_time = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) # scale embeddings From 46b1567642f4fda253cd4423ce2803e97608862a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 31 Oct 2022 12:49:27 +0300 Subject: [PATCH 3/4] better var name --- src/diffusers/models/embeddings_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 69bbe2beeff7..e178f1409d6b 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -40,10 +40,10 @@ def get_sinusoidal_embeddings( num_timescales = float(embedding_dim // 2) log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) - scaled_time = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) + emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) # scale embeddings - scaled_time = scale * scaled_time + scaled_time = scale * emb if flip_sin_to_cos: signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) From fd7c2896015d582c10b79c68df7a679e68a646a1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 31 Oct 2022 16:19:56 +0300 Subject: [PATCH 4/4] fix docs --- src/diffusers/models/embeddings_flax.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index e178f1409d6b..1e2272c1fe70 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -28,12 +28,13 @@ def get_sinusoidal_embeddings( ) -> jnp.ndarray: """Returns the positional encoding (same as Tensor2Tensor). Args: - timesteps: An array of shape [batch_size]. - embedding_dim: The number of output channels. - min_timescale: The smallest time unit (should probably be 0.0). - max_timescale: The largest time unit. + timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + embedding_dim: The number of output channels. + min_timescale: The smallest time unit (should probably be 0.0). + max_timescale: The largest time unit. Returns: - a Tensor of timing signals [1, length, num_channels] + a Tensor of timing signals [N, num_channels] """ assert timesteps.ndim == 1, "Timesteps should be a 1d-array" assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"