From ed0c49b4b23849306e167d74625a2b4d6ad3205e Mon Sep 17 00:00:00 2001
From: Joseph Smidt
Date: Thu, 8 Aug 2024 19:15:56 -0600
Subject: [PATCH] Update transformer_flux.py. Change float64 to float32

dtype=torch.float64 is overkill here, and float64 is not supported on some
devices, such as Apple Silicon (the mps backend).
---
 src/diffusers/models/transformers/transformer_flux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 3983606e46ac..0d9c2d3663b3 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -38,7 +38,7 @@
 def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     assert dim % 2 == 0, "The dimension must be even."
 
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
 
     batch_size, seq_length = pos.shape
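
Note (not part of the patch): a minimal sketch of the failure mode the
commit message describes, assuming PyTorch on an Apple Silicon machine
with the mps backend available. The exact exception type and message may
vary across PyTorch versions.

    import torch

    if torch.backends.mps.is_available():
        try:
            # The MPS backend does not implement float64 tensors, so the
            # pre-patch arange call is expected to raise here.
            torch.arange(0, 8, 2, dtype=torch.float64, device="mps")
        except (TypeError, RuntimeError) as e:
            print(f"float64 on mps failed: {e}")

        # The float32 version of the same computation as in rope() works.
        scale = torch.arange(0, 8, 2, dtype=torch.float32, device="mps") / 8
        print(scale)  # tensor([0.0000, 0.2500, 0.5000, 0.7500], device='mps:0')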