@@ -51,21 +51,6 @@ def __init__(
51
51
super ().__init__ ()
52
52
self .register_modules (unet = unet , scheduler = scheduler , mel = mel , vqvae = vqvae )
53
53
54
- def get_input_dims (self ) -> Tuple :
55
- """Returns dimension of input image
56
-
57
- Returns:
58
- `Tuple`: (height, width)
59
- """
60
- input_module = self .vqvae if self .vqvae is not None else self .unet
61
- # For backwards compatibility
62
- sample_size = (
63
- (input_module .config .sample_size , input_module .config .sample_size )
64
- if type (input_module .config .sample_size ) == int
65
- else input_module .config .sample_size
66
- )
67
- return sample_size
68
-
69
54
def get_default_steps (self ) -> int :
70
55
"""Returns default number of steps recommended for inference
71
56
@@ -123,8 +108,6 @@ def __call__(
123
108
# For backwards compatibility
124
109
if type (self .unet .config .sample_size ) == int :
125
110
self .unet .config .sample_size = (self .unet .config .sample_size , self .unet .config .sample_size )
126
- input_dims = self .get_input_dims ()
127
- self .mel .set_resolution (x_res = input_dims [1 ], y_res = input_dims [0 ])
128
111
if noise is None :
129
112
noise = randn_tensor (
130
113
(
@@ -234,7 +217,7 @@ def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
234
217
sample = torch .Tensor (sample ).to (self .device )
235
218
236
219
for t in self .progress_bar (torch .flip (self .scheduler .timesteps , (0 ,))):
237
- prev_timestep = t - self .scheduler .num_train_timesteps // self .scheduler .num_inference_steps
220
+ prev_timestep = t - self .scheduler .config . num_train_timesteps // self .scheduler .num_inference_steps
238
221
alpha_prod_t = self .scheduler .alphas_cumprod [t ]
239
222
alpha_prod_t_prev = (
240
223
self .scheduler .alphas_cumprod [prev_timestep ]
0 commit comments