Skip to content

Commit b63419a

Browse files
authored
AudioDiffusionPipeline - fix encode method after config changes (#3114)
* config fixes
* deprecate get_input_dims
1 parent eb29dba commit b63419a

File tree

1 file changed

+1
-18
lines changed

1 file changed

+1
-18
lines changed

src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,6 @@ def __init__(
5151
super().__init__()
5252
self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
5353

54-
def get_input_dims(self) -> Tuple:
55-
"""Returns dimension of input image
56-
57-
Returns:
58-
`Tuple`: (height, width)
59-
"""
60-
input_module = self.vqvae if self.vqvae is not None else self.unet
61-
# For backwards compatibility
62-
sample_size = (
63-
(input_module.config.sample_size, input_module.config.sample_size)
64-
if type(input_module.config.sample_size) == int
65-
else input_module.config.sample_size
66-
)
67-
return sample_size
68-
6954
def get_default_steps(self) -> int:
7055
"""Returns default number of steps recommended for inference
7156
@@ -123,8 +108,6 @@ def __call__(
123108
# For backwards compatibility
124109
if type(self.unet.config.sample_size) == int:
125110
self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
126-
input_dims = self.get_input_dims()
127-
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
128111
if noise is None:
129112
noise = randn_tensor(
130113
(
@@ -234,7 +217,7 @@ def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
234217
sample = torch.Tensor(sample).to(self.device)
235218

236219
for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
237-
prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps
220+
prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
238221
alpha_prod_t = self.scheduler.alphas_cumprod[t]
239222
alpha_prod_t_prev = (
240223
self.scheduler.alphas_cumprod[prev_timestep]

0 commit comments

Comments (0)