
Commit d36933b

Authored by sanchit-gandhi, patrickvonplaten, and williamberman
Add AudioLDM (huggingface#2232)
* Add AudioLDM
* up
* add vocoder
* start unet
* unconditional unet
* clap, vocoder and vae
* clean-up: conversion scripts
* fix: conversion script token_type_ids
* clean-up: pipeline docstring
* tests: from SD
* clean-up: cpu offload vocoder instead of safety checker
* feat: adapt tests to audioldm
* feat: add docs
* clean-up: amend pipeline docstrings
* clean-up: make style
* clean-up: make fix-copies
* fix: add doc path to toctree
* clean-up: args for conversion script
* clean-up: paths to checkpoints
* fix: use conditional unet
* clean-up: make style
* fix: type hints for UNet
* clean-up: docstring for UNet
* clean-up: make style
* clean-up: remove duplicate in docstring
* clean-up: make style
* clean-up: make fix-copies
* clean-up: move imports to start in code snippet
* fix: pass cross_attention_dim as a list/tuple to unet
* clean-up: make fix-copies
* fix: update checkpoint path
* fix: unet cross_attention_dim in tests
* film embeddings -> class embeddings
* Apply suggestions from code review
  Co-authored-by: Will Berman <[email protected]>
* fix: unet film embed to use existing args
* fix: unet tests to use existing args
* fix: make style
* fix: transformers import and version in init
* clean-up: make style
* Revert "clean-up: make style"
  This reverts commit 5d6d1f8.
* clean-up: make style
* clean-up: use pipeline tester mixin tests where poss
* clean-up: skip attn slicing test
* fix: add torch dtype to docs
* fix: remove conversion script out of src
* fix: remove .detach from 1d waveform
* fix: reduce default num inf steps
* fix: swap height/width -> audio_length_in_s
* clean-up: make style
* fix: remove nightly tests
* fix: imports in conversion script
* clean-up: slim-down to two slow tests
* clean-up: slim-down fast tests
* fix: batch consistent tests
* clean-up: make style
* clean-up: remove vae slicing fast test
* clean-up: propagate changes to doc
* fix: increase test tol to 1e-2
* clean-up: finish docs
* clean-up: make style
* feat: vocoder / VAE compatibility check
* feat: possibly expand / cut audio waveform
* fix: pipeline call signature test
* fix: slow tests output len
* clean-up: make style
* make style

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: William Berman <[email protected]>
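For context, a minimal usage sketch of the pipeline this commit adds. The checkpoint id `cvssp/audioldm` and the prompt are assumptions rather than taken from this diff; `audio_length_in_s` (which replaces height/width) and the reduced default number of inference steps are referenced in the commit message above.

```python
import torch

from diffusers import AudioLDMPipeline

# Assumed checkpoint id; any diffusers-format AudioLDM checkpoint works here.
pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
# The pipeline is parameterized in seconds of audio rather than image height/width,
# and decodes the denoised latents to a waveform with the vocoder.
audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
# `audio` is a 1-D NumPy waveform (AudioLDM's vocoder runs at 16 kHz).
```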
1 parent 544dc2f commit d36933b

File tree

7 files changed: +722, -24 lines changed

__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -112,6 +112,7 @@
     from .pipelines import (
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
+        AudioLDMPipeline,
         CycleDiffusionPipeline,
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
```

models/unet_2d_condition.py

Lines changed: 43 additions & 12 deletions

```diff
@@ -86,13 +86,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, it will skip the normalization and activation layers in post-processing
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
-        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
-            `"timestep"`, `"identity"`, or `"projection"`.
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
@@ -106,6 +107,8 @@ class conditioning with `class_embed_type` equal to `None`.
         conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
+        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+            embeddings with the class embeddings.
     """

     _supports_gradient_checkpointing = True
@@ -135,7 +138,7 @@ def __init__(
         act_fn: str = "silu",
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
-        cross_attention_dim: int = 1280,
+        cross_attention_dim: Union[int, Tuple[int]] = 1280,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -149,6 +152,7 @@ def __init__(
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        class_embeddings_concat: bool = False,
     ):
         super().__init__()

@@ -175,6 +179,11 @@ def __init__(
                 f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
             )

+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
         # input
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in = nn.Conv2d(
@@ -228,6 +237,12 @@ def __init__(
             # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
             # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
             self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif class_embed_type == "simple_projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
         else:
             self.class_embedding = None

@@ -240,6 +255,17 @@ def __init__(
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if class_embeddings_concat:
+            # The time embeddings are concatenated with the class embeddings. The dimension of the
+            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+            # regular time embeddings
+            blocks_time_embed_dim = time_embed_dim * 2
+        else:
+            blocks_time_embed_dim = time_embed_dim
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -252,12 +278,12 @@ def __init__(
                 num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 add_downsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[i],
                 attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
@@ -272,12 +298,12 @@ def __init__(
         if mid_block_type == "UNetMidBlock2DCrossAttn":
             self.mid_block = UNetMidBlock2DCrossAttn(
                 in_channels=block_out_channels[-1],
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
                 resnet_time_scale_shift=resnet_time_scale_shift,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=dual_cross_attention,
@@ -287,11 +313,11 @@ def __init__(
         elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
             self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                 in_channels=block_out_channels[-1],
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
@@ -307,6 +333,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
         only_cross_attention = list(reversed(only_cross_attention))

         output_channel = reversed_block_out_channels[0]
@@ -330,12 +357,12 @@ def __init__(
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=reversed_cross_attention_dim[i],
                 attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -571,7 +598,11 @@ def forward(
                 class_labels = self.time_proj(class_labels)

             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
-            emb = emb + class_emb
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb

         # 2. pre-process
         sample = self.conv_in(sample)
```
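A minimal sketch of the two UNet changes in this file, with all sizes hypothetical: `cross_attention_dim` may now be a tuple with one entry per down block, and `class_embed_type="simple_projection"` together with `class_embeddings_concat=True` makes the class embedding a plain `nn.Linear` whose output is concatenated with (rather than added to) the time embedding, doubling `temb_channels` for every block.

```python
import torch

from diffusers import UNet2DConditionModel

# Hypothetical sizes throughout; only the new arguments matter here.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=(64, 64),          # new: one entry per down block
    class_embed_type="simple_projection",  # new: plain nn.Linear class embedding
    projection_class_embeddings_input_dim=16,
    class_embeddings_concat=True,          # new: concat with the time embedding
)

sample = torch.randn(1, 4, 16, 16)              # latent batch
encoder_hidden_states = torch.randn(1, 77, 64)  # last dim matches cross_attention_dim
class_labels = torch.randn(1, 16)               # projected, then concatenated with the time embedding
out = unet(sample, 10, encoder_hidden_states, class_labels=class_labels).sample
print(out.shape)  # torch.Size([1, 4, 16, 16])
```

Note that because each attention block builds its key/value projections from its own `cross_attention_dim[i]`, a single `encoder_hidden_states` tensor only works when the entries agree, as they do in this sketch.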

pipelines/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -44,6 +44,7 @@
     from ..utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
+    from .audioldm import AudioLDMPipeline
     from .latent_diffusion import LDMTextToImagePipeline
     from .paint_by_example import PaintByExamplePipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
```
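Together with the `__init__.py` change above, this re-export means both import paths resolve to the same class; a quick sketch:

```python
# Enabled by the top-level __init__.py change:
from diffusers import AudioLDMPipeline

# Enabled by this re-export (same class):
from diffusers.pipelines import AudioLDMPipeline  # noqa: F811
```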

pipelines/audioldm/__init__.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -0,0 +1,17 @@
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    is_torch_available,
+    is_transformers_available,
+    is_transformers_version,
+)
+
+
+try:
+    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils.dummy_torch_and_transformers_objects import (
+        AudioLDMPipeline,
+    )
+else:
+    from .pipeline_audioldm import AudioLDMPipeline
```
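This guard follows diffusers' usual optional-dependency pattern: with `torch` and `transformers>=4.27.0` installed the real pipeline is imported, otherwise a dummy object with the same name is bound and raises only when actually used. A sketch of the failure mode (assuming that dummy-object convention; the checkpoint id is hypothetical):

```python
try:
    from diffusers.pipelines.audioldm import AudioLDMPipeline

    pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")  # assumed checkpoint id
except ImportError as err:
    # With transformers < 4.27.0 the import of the name succeeds, but the dummy
    # object's backend check raises here, telling the user what to install.
    print(err)
```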
