From 1ca69e1a9df11ace6fde257308f7f06ee2f850b0 Mon Sep 17 00:00:00 2001 From: Daniel Hug Date: Wed, 7 Sep 2022 20:12:30 -0400 Subject: [PATCH] Add typing to scheduling_sde_ve init, set_timesteps, and set_sigmas functions --- src/diffusers/schedulers/scheduling_sde_ve.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index f6b0ba936eea..f23bf012b8e8 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -58,13 +58,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( self, - num_train_timesteps=2000, - snr=0.15, - sigma_min=0.01, - sigma_max=1348, - sampling_eps=1e-5, - correct_steps=1, - tensor_format="pt", + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + tensor_format: str = "pt", ): # self.sigmas = None # self.discrete_sigmas = None @@ -78,7 +78,7 @@ def __init__( self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps, sampling_eps=None): + def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": @@ -88,7 +88,9 @@ def set_timesteps(self, num_inference_steps, sampling_eps=None): else: raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None): + def set_sigmas( + self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None + ): sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps