diff --git a/scripts/convert_notes2audio.py b/scripts/convert_notes2audio.py
new file mode 100644
index 000000000000..b989247b17e4
--- /dev/null
+++ b/scripts/convert_notes2audio.py
@@ -0,0 +1,71 @@
+import math
+
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import tensorflow_hub as hub
+
+
+module = hub.KerasLayer("https://tfhub.dev/google/soundstream/mel/decoder/music/1")
+
+# 1. Convert the TF weights of SoundStream to PyTorch.
+#    This will give us the necessary vocoder (a rough weight-mapping sketch is given after this patch).
+
+
+# 2. Convert the JAX T5 weights to PyTorch using the transformers script.
+#    This will give us the necessary encoder and decoder:
+#    the encoder corresponds to the note encoder and the decoder part is the spectrogram decoder.
+
+# 3. Convert the Context Encoder weights to PyTorch.
+#    The context encoder should be pretty straightforward to convert.
+
+# 4. Implement tests to make sure that the models work properly.
+
+
+SAMPLE_RATE = 16000
+N_FFT = 1024
+HOP_LENGTH = 320  # 320 samples at 16 kHz -> 20 ms per spectrogram frame (50 frames per second)
+WIN_LENGTH = 640
+N_MEL_CHANNELS = 128
+MEL_FMIN = 0.0
+MEL_FMAX = int(SAMPLE_RATE // 2)
+CLIP_VALUE_MIN = 1e-5
+CLIP_VALUE_MAX = 1e8
+
+MEL_BASIS = tf.signal.linear_to_mel_weight_matrix(
+    num_mel_bins=N_MEL_CHANNELS,
+    num_spectrogram_bins=N_FFT // 2 + 1,
+    sample_rate=SAMPLE_RATE,
+    lower_edge_hertz=MEL_FMIN,
+    upper_edge_hertz=MEL_FMAX,
+)
+
+
+def calculate_spectrogram(samples):
+    """Calculate the mel spectrogram using the parameters the model expects."""
+    fft = tf.signal.stft(
+        samples,
+        frame_length=WIN_LENGTH,
+        frame_step=HOP_LENGTH,
+        fft_length=N_FFT,
+        window_fn=tf.signal.hann_window,
+        pad_end=True,
+    )
+    fft_modulus = tf.abs(fft)
+
+    output = tf.matmul(fft_modulus, MEL_BASIS)
+
+    output = tf.clip_by_value(output, clip_value_min=CLIP_VALUE_MIN, clip_value_max=CLIP_VALUE_MAX)
+    output = tf.math.log(output)
+    return output
+
+
+# Load a music sample from the GTZAN dataset.
+gtzan = tfds.load("gtzan", split="train")
+# Convert an example from int to float.
+samples = tf.cast(next(iter(gtzan))["audio"] / 32768, dtype=tf.float32)
+# Add batch dimension.
+samples = tf.expand_dims(samples, axis=0)
+# Compute a mel-spectrogram.
+spectrogram = calculate_spectrogram(samples)
+# Reconstruct the audio from a mel-spectrogram using a SoundStream decoder.
+reconstructed_samples = module(spectrogram)
diff --git a/src/diffusers/models/music_transformer.py b/src/diffusers/models/music_transformer.py
new file mode 100644
index 000000000000..6768bfdb5ca7
--- /dev/null
+++ b/src/diffusers/models/music_transformer.py
@@ -0,0 +1,167 @@
+# This file will contain the classes needed to build the notes2audio pipeline:
+# Note Encoder, Spectrogram Decoder and Context Encoder.
+
+
+import torch
+import torch.nn as nn
+
+from transformers import T5Config
+from transformers.models.t5.modeling_t5 import T5Block, T5Stack
+
+
+class FiLMLayer(nn.Module):
+    """A simple FiLM layer for conditioning on the diffusion time embedding."""
+
+    def __init__(self, in_channels, out_channels) -> None:
+        super().__init__()
+        self.gamma = nn.Linear(in_channels, out_channels)  # scale
+        self.beta = nn.Linear(in_channels, out_channels)  # shift
+
+    def forward(self, hidden_states, conditioning_emb):
+        """Updates the hidden states based on the conditioning embeddings.
+
+        Args:
+            hidden_states (`torch.Tensor`): feature map to modulate, of shape `(batch, channels, ...)`.
+            conditioning_emb (`torch.Tensor`): diffusion time embedding used to predict the scale and shift.
+
+        Returns:
+            `torch.Tensor`: the modulated hidden states.
+        """
+
+        beta = self.beta(conditioning_emb).unsqueeze(-1).unsqueeze(-1)
+        gamma = self.gamma(conditioning_emb).unsqueeze(-1).unsqueeze(-1)
+
+        hidden_states = hidden_states * (gamma + 1.0) + beta
+        return hidden_states
+
+
+class ContextEncoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+
+class NoteTokenizer(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+
+class NoteEncoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+
+class SpectrogramDecoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# NOTE: the two classes below are still in the JAX/Flax style of the upstream
+# music-spectrogram-diffusion codebase and have to be ported to PyTorch.
+# `jnp`, `layers`, `position_encoding_layer`, `EncoderLayer`, `ContinuousEncoder` and `Decoder`
+# refer to modules of that codebase and are not importable here yet.
+class TokenEncoder(nn.Module):
+    """A stack of encoder layers."""
+
+    config: T5Config
+
+    def __call__(self, encoder_input_tokens, encoder_inputs_mask, deterministic):
+        cfg = self.config
+
+        assert encoder_input_tokens.ndim == 2  # [batch, length]
+
+        seq_length = encoder_input_tokens.shape[1]
+        inputs_positions = jnp.arange(seq_length)[None, :]
+
+        # [batch, length] -> [batch, length, emb_dim]
+        x = layers.Embed(
+            num_embeddings=cfg.vocab_size,
+            features=cfg.emb_dim,
+            dtype=cfg.dtype,
+            embedding_init=nn.initializers.normal(stddev=1.0),
+            one_hot=True,
+            name="token_embedder",
+        )(encoder_input_tokens.astype("int32"))
+
+        x += position_encoding_layer(config=cfg, max_length=seq_length)(inputs_positions)
+        x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(x, deterministic=deterministic)
+        x = x.astype(cfg.dtype)
+
+        for lyr in range(cfg.num_encoder_layers):
+            # [batch, length, emb_dim] -> [batch, length, emb_dim]
+            x = EncoderLayer(config=cfg, name=f"layers_{lyr}")(
+                inputs=x, encoder_inputs_mask=encoder_inputs_mask, deterministic=deterministic
+            )
+        x = layers.LayerNorm(dtype=cfg.dtype, name="encoder_norm")(x)
+        x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
+        return x, encoder_inputs_mask
+
+
+class ContinuousContextTransformer(nn.Module):
+    """An encoder-decoder Transformer model with a second, continuous audio-context encoder."""
+
+    config: T5Config
+
+    def setup(self):
+        cfg = self.config
+
+        self.token_encoder = TokenEncoder(config=cfg)
+        self.continuous_encoder = ContinuousEncoder(config=cfg)
+        self.decoder = Decoder(config=cfg)
+
+    def encode(self, input_tokens, continuous_inputs, continuous_mask, enable_dropout=True):
+        """Applies the Transformer encoder branch on the inputs."""
+        assert input_tokens.ndim == 2  # (batch, length)
+        assert continuous_inputs.ndim == 3  # (batch, length, input_dims)
+
+        tokens_mask = input_tokens > 0
+
+        tokens_encoded, tokens_mask = self.token_encoder(
+            encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask, deterministic=not enable_dropout
+        )
+
+        continuous_encoded, continuous_mask = self.continuous_encoder(
+            encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask, deterministic=not enable_dropout
+        )
+
+        return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]
+
+    def decode(self, encodings_and_masks, input_tokens, noise_time, enable_dropout=True):
+        """Applies the Transformer decoder branch on the encoded inputs and the target."""
+        logits = self.decoder(
+            encodings_and_masks=encodings_and_masks,
+            decoder_input_tokens=input_tokens,
+            decoder_noise_time=noise_time,
+            deterministic=not enable_dropout,
+        )
+        return logits.astype(self.config.dtype)
+
+    def __call__(
+        self,
+        encoder_input_tokens,
+        encoder_continuous_inputs,
+        encoder_continuous_mask,
+        decoder_input_tokens,
+        decoder_noise_time,
+        *,
+        enable_dropout: bool = True,
+    ):
+        """Applies the full Transformer model on the inputs.
+
+        Args:
+            encoder_input_tokens: input data to the encoder.
+            encoder_continuous_inputs: continuous inputs for the second encoder.
+            encoder_continuous_mask: mask for the continuous inputs.
+            decoder_input_tokens: input tokens to the decoder.
+            decoder_noise_time: continuous noise time for diffusion.
+            enable_dropout: enables dropout if set to True.
+
+        Returns:
+            logits array from the full transformer.
+        """
+        encodings_and_masks = self.encode(
+            input_tokens=encoder_input_tokens,
+            continuous_inputs=encoder_continuous_inputs,
+            continuous_mask=encoder_continuous_mask,
+            enable_dropout=enable_dropout,
+        )
+
+        return self.decode(
+            encodings_and_masks=encodings_and_masks,
+            input_tokens=decoder_input_tokens,
+            noise_time=decoder_noise_time,
+            enable_dropout=enable_dropout,
+        )
diff --git a/src/diffusers/models/vocoders.py b/src/diffusers/models/vocoders.py
new file mode 100644
index 000000000000..ac44c853b73a
--- /dev/null
+++ b/src/diffusers/models/vocoders.py
@@ -0,0 +1,132 @@
+# All the vocoders used in diffusion pipelines will be implemented here.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
+
+
+# DiffSound uses MelGAN.
+class MelGAN(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+
+class CausalConv1d(nn.Conv1d):
+    """1D convolution that only pads on the left, so an output at time `t` never depends on future samples."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)
+
+    def forward(self, x):
+        return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias)
+
+
+class CausalConvTranspose1d(nn.ConvTranspose1d):
+    """Transposed 1D convolution whose trailing (future-looking) outputs are trimmed to keep it causal."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.causal_padding = (
+            self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1 - self.stride[0]
+        )
+
+    def forward(self, x, output_size=None):
+        if self.padding_mode != "zeros":
+            raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d")
+
+        assert isinstance(self.padding, tuple)
+        output_padding = self._output_padding(
+            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation
+        )
+        return F.conv_transpose1d(
+            x, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation
+        )[..., : -self.causal_padding]
+
+
+class SoundStreamResNet(nn.Module):
+    """Residual block: a dilated causal convolution followed by a 1x1 convolution."""
+
+    def __init__(self, in_channels, out_channels, dilation):
+        super().__init__()
+        self.dilation = dilation
+        self.causal_conv = CausalConv1d(in_channels, out_channels, kernel_size=7, dilation=dilation)
+        self.conv_1d = nn.Conv1d(in_channels, out_channels, kernel_size=1)
+        self.act = nn.ELU()
+
+    def forward(self, hidden_states):
+        residuals = hidden_states
+        hidden_states = self.causal_conv(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.conv_1d(hidden_states)
+        return residuals + hidden_states
+
+
+class SoundStreamDecoderBlock(nn.Module):
+    """Decoder block: a transposed causal convolution that upsamples by `stride`, followed by three residual blocks."""
+
+    def __init__(self, out_channels, stride):
+        super().__init__()
+        self.project_in = CausalConvTranspose1d(
+            in_channels=2 * out_channels, out_channels=out_channels, kernel_size=2 * stride, stride=stride
+        )
+        self.act = nn.ELU()
+
+        # Dilations of 1, 3 and 9, as in the SoundStream decoder.
+        self.resnet_blocks = nn.ModuleList(
+            [SoundStreamResNet(out_channels, out_channels, dilation=3**rate) for rate in range(3)]
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.project_in(hidden_states)
+        hidden_states = self.act(hidden_states)
+        for resnet in self.resnet_blocks:
+            hidden_states = resnet(hidden_states)
+        return hidden_states
+
+
+# notes2audio uses SoundStream.
+class SoundStreamVocoder(ModelMixin, ConfigMixin):
+    """Residual VQ-VAE model from `SoundStream: An End-to-End Neural Audio Codec`.
+
+    Only the decoder part is implemented here: it converts mel-spectrogram features back into raw audio.
+
+    Args:
+        in_channels (`int`): number of input channels. It corresponds to the number of spectrogram features
+            that are passed to the decoder to compute the raw audio.
+        out_channels (`int`): number of output audio channels.
+        strides (`List[int]`): upsampling factor of each decoder block.
+        channel_factors (`List[int]`): channel multiplier of each decoder block.
+    """
+
+    def __init__(self, in_channels=8, out_channels=1, strides=[8, 5, 4, 2], channel_factors=[8, 4, 2, 1]):
+        super().__init__()
+        self.act = nn.ELU()
+        self.bottleneck = CausalConv1d(in_channels=in_channels, out_channels=16 * out_channels, kernel_size=7)
+        self.decoder_blocks = nn.ModuleList(
+            SoundStreamDecoderBlock(out_channels=out_channels * channel_factors[i], stride=strides[i])
+            for i in range(4)
+        )
+        self.last_layer_conv = CausalConv1d(in_channels=out_channels, out_channels=1, kernel_size=7)
+        # Amount of dithering noise optionally added to the features before decoding.
+        self._decode_dither_amount = 0.0
+
+    def decode(self, features):
+        """Decodes features to audio.
+
+        Args:
+            features: Mel spectrograms, shape [batch, n_frames, n_dims].
+        Returns:
+            audio: Shape [batch, n_frames * hop_size].
+        """
+        if self._decode_dither_amount > 0:
+            features = features + torch.randn_like(features) * self._decode_dither_amount
+
+        hidden_states = self.bottleneck(features)
+        hidden_states = self.act(hidden_states)
+        for layer in self.decoder_blocks:
+            hidden_states = layer(hidden_states)
+            hidden_states = self.act(hidden_states)
+
+        audio = self.last_layer_conv(hidden_states)
+
+        return audio
+
+
+# TODO @Arthur: DiffSinger uses this as its vocoder.
+class HiFiGAN(nn.Module):
+    def __init__(self):
+        super().__init__()
diff --git a/src/diffusers/pipelines/notes2audio/README.md b/src/diffusers/pipelines/notes2audio/README.md
new file mode 100644
index 000000000000..294c867beaf1
--- /dev/null
+++ b/src/diffusers/pipelines/notes2audio/README.md
@@ -0,0 +1,17 @@
+# TODO Follow the stable diffusion pipeline card
+
+
+Goal of the implementation:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("magenta/notes2audio_base_with_context")
+
+midi_setup_file = "path/to/midi_file.midi"
+pipeline(midi_setup_file).sample[0]
+```
diff --git a/src/diffusers/pipelines/notes2audio/__init__.py b/src/diffusers/pipelines/notes2audio/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
new file mode 100644
index 000000000000..9703b9532588
--- /dev/null
+++ b/src/diffusers/pipelines/notes2audio/pipeline_notes2audio.py
@@ -0,0 +1,79 @@
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import T5Model
+
+from ...models import Notes2AudioModel, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+
+
+class Notes2AudioPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for notes (MIDI)-to-audio generation using the music-spectrogram diffusion model introduced by Magenta
+    in notes2audio.
+
+    Args:
+        spectrogram_decoder ([` `]):
+            Decoder model used to convert the hidden states to a mel spectrogram. Takes the encoder hidden states as
+            well as the diffusion noise as input. Should be the SoundStream/MelGAN-style decoder.
+        context_encoder ([` `]):
+            Encoder used to create the context that smooths the transitions between adjacent audio frames.
+        note_encoder ([` `]):
+            Model used to encode the note (MIDI) tokens.
+        vocoder ([` `]):
+            SoundStream-style vocoder used to convert the generated mel spectrogram into raw audio.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `spectrogram_decoder` to denoise the spectrogram latents.
+            Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+    """
+
+    def __init__(self, spectrogram_decoder, context_encoder, note_encoder, vocoder, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(
+            spectrogram_decoder=spectrogram_decoder,
+            context_encoder=context_encoder,
+            note_encoder=note_encoder,
+            scheduler=scheduler,
+            vocoder=vocoder,
+        )
+
+    def generation_step(self):
+        """
+        Generates a single frame of audio, which corresponds to 5 seconds.
+
+        Args:
+            encoder_continuous_inputs (`torch.Tensor`): continuous (spectrogram) inputs for the context encoder.
+            encoder_continuous_mask (`torch.Tensor`): mask for the continuous inputs.
+            encoder_input_tokens (`torch.Tensor`): note tokens passed to the note encoder.
+            decoder_target_tokens (`torch.Tensor`): target tokens passed to the spectrogram decoder.
+            diffusion_noise (`torch.Tensor`): noise used by the diffusion process.
+            diffusion_noise_mask (`torch.Tensor`): mask applied to the diffusion noise.
+            deterministic (`bool`): disables dropout when set to True.
+            **kwargs (`dict`): additional keyword arguments.
+
+        Returns:
+            `torch.Tensor`: The generated audio.
+        """
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        midi: Union[str, List[str]],
+        audio_length: int,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",  # TODO: replace the image default with an audio output type
+        return_dict: bool = True,
+        **kwargs,
+    ):
+        # TODO: implement the full generation loop (a rough sketch is given after this patch).
+        return
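
For step 1 of `scripts/convert_notes2audio.py` (porting the TF SoundStream decoder weights to PyTorch), a minimal sketch could dump the hub module's variables into a PyTorch `state_dict` along the following lines. The kernel transposition is the usual TF-to-PyTorch convention for `Conv1D`; the name-mapping rule is a hypothetical placeholder that has to be matched against the actual module names of `SoundStreamVocoder`, and transposed convolutions may need a different permutation:

```python
import numpy as np
import tensorflow_hub as hub
import torch

# Load the pretrained SoundStream mel decoder from TF Hub (same URL as in the script above).
module = hub.KerasLayer("https://tfhub.dev/google/soundstream/mel/decoder/music/1")

# Collect every TF variable as a numpy array, keyed by its TF name.
tf_weights = {variable.name: variable.numpy() for variable in module.variables}

state_dict = {}
for tf_name, array in tf_weights.items():
    # TF Conv1D kernels are [kernel_size, in_channels, out_channels];
    # PyTorch Conv1d weights are [out_channels, in_channels, kernel_size].
    # NOTE: ConvTranspose1d kernels may need a different permutation; check per layer.
    if array.ndim == 3:
        array = np.ascontiguousarray(np.transpose(array, (2, 1, 0)))
    # Hypothetical renaming rule; the real mapping depends on how SoundStreamVocoder names its sub-modules.
    pt_name = tf_name.split(":")[0].replace("/", ".")
    state_dict[pt_name] = torch.from_numpy(array)

# vocoder = SoundStreamVocoder()
# missing, unexpected = vocoder.load_state_dict(state_dict, strict=False)
```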
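
For the `Notes2AudioPipeline.__call__` stub above, one generation step might eventually be wired roughly as sketched below. This is only a sketch under the assumption that `note_encoder`, `context_encoder` and `spectrogram_decoder` expose call signatures similar to the ones hinted at in `music_transformer.py`; none of these interfaces are final:

```python
import torch


def generation_step_sketch(pipe, note_tokens, previous_spectrogram, num_inference_steps=50, generator=None):
    """Denoise one ~5 second spectrogram segment and vocode it (all interfaces below are assumptions)."""
    # Encode the MIDI note tokens and the previous segment used as audio context.
    note_hidden_states = pipe.note_encoder(note_tokens)  # assumed interface
    context_hidden_states = pipe.context_encoder(previous_spectrogram)  # assumed interface

    # Start from Gaussian noise in spectrogram space and denoise it step by step with the scheduler.
    spectrogram = torch.randn(previous_spectrogram.shape, generator=generator)
    pipe.scheduler.set_timesteps(num_inference_steps)
    for t in pipe.scheduler.timesteps:
        model_output = pipe.spectrogram_decoder(
            spectrogram, t, note_hidden_states, context_hidden_states  # assumed call signature
        )
        spectrogram = pipe.scheduler.step(model_output, t, spectrogram).prev_sample

    # Convert the denoised mel spectrogram into raw audio with the SoundStream vocoder.
    return pipe.vocoder.decode(spectrogram)
```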