diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 3168fd9a625f..1773ccd253c9 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -39,7 +39,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape