@@ -85,25 +85,26 @@ def __init__(
8585 )
8686
8787 if trained_betas is not None :
88- self .betas = torch . from_numpy (trained_betas )
88+ self .betas = np . asarray (trained_betas )
8989 if beta_schedule == "linear" :
90- self .betas = torch .linspace (beta_start , beta_end , num_train_timesteps , dtype = torch .float32 )
90+ self .betas = np .linspace (beta_start , beta_end , num_train_timesteps , dtype = np .float32 )
9191 elif beta_schedule == "scaled_linear" :
9292 # this schedule is very specific to the latent diffusion model.
9393 self .betas = (
94- torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = torch .float32 ) ** 2
94+ np .linspace (beta_start ** 0.5 , beta_end ** 0.5 , num_train_timesteps , dtype = np .float32 ) ** 2
9595 )
9696 else :
9797 raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
9898
99- self .alphas = 1.0 - self .betas
100- self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
99+ self .alphas = np . array ( 1.0 - self .betas , dtype = np . float32 )
100+ self .alphas_cumprod = np .cumprod (self .alphas , axis = 0 )
101101
102- self .sigmas = ((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5
102+ sigmas = ((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5
103+ self .sigmas = torch .from_numpy (sigmas )
103104
104105 # setable values
105106 self .num_inference_steps = None
106- self .timesteps = np .arange (0 , num_train_timesteps )[::- 1 ] # to be consistent has to be smaller than sigmas by 1
107+ self .timesteps = np .arange (0 , num_train_timesteps )[::- 1 ]
107108 self .derivatives = []
108109
109110 def get_lms_coefficient (self , order , t , current_order ):
@@ -146,8 +147,8 @@ def set_timesteps(self, num_inference_steps: int):
146147 sigmas = (1 - frac ) * sigmas [low_idx ] + frac * sigmas [high_idx ]
147148 sigmas = np .concatenate ([sigmas , [0.0 ]]).astype (np .float32 )
148149 self .sigmas = torch .from_numpy (sigmas )
150+ self .timesteps = timesteps
149151
150- self .timesteps = timesteps .astype (int )
151152 self .derivatives = []
152153
153154 def step (
0 commit comments