diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 1e2272c1fe70..bf7d54b82ec2 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -88,4 +88,6 @@ class FlaxTimesteps(nn.Module): @nn.compact def __call__(self, timesteps): - return get_sinusoidal_embeddings(timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift) + return get_sinusoidal_embeddings( + timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True + )