diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index db8f4fd17297..eb5067c37700 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed( pos = torch.from_numpy(pos) # type: ignore # [S] theta = theta * ntk_factor - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] - freqs = freqs.to(pos.device) + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox