From e23a58f361be7492962f1736609c39d21d08e746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 15 Jun 2025 16:53:03 +0300 Subject: [PATCH] Fix dimensionality in `apply_rotary_emb` functions' comments. --- src/diffusers/models/embeddings.py | 4 ++-- src/diffusers/models/transformers/transformer_ltx.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index cfc501c47ed9..6119a453666d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1199,11 +1199,11 @@ def apply_rotary_emb( if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: # Used for Stable Audio, OmniGen, CogView4 and Cosmos - x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2ae2418098f6..79781da8a740 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -481,7 +481,7 @@ def forward( def apply_rotary_emb(x, freqs): cos, sin = freqs - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out