diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 00bc75f41be3..d0d02fb92e72 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -89,7 +89,7 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) -def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor: +def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). This version of the method comes from here: @@ -121,8 +121,8 @@ def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tens def apply_freeu( - resolution_idx: int, hidden_states: torch.Tensor, res_hidden_states: torch.Tensor, **freeu_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: + resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs +) -> Tuple["torch.Tensor", "torch.Tensor"]: """Applies the FreeU mechanism as introduced in https: //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.