Add Note2audio model #544

Closed · wants to merge 8 commits
71 changes: 71 additions & 0 deletions scripts/convert_notes2audio.py
@@ -0,0 +1,71 @@
import math

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub


module = hub.KerasLayer("https://tfhub.dev/google/soundstream/mel/decoder/music/1")

# 1. Convert the TF weights of SoundStream to PyTorch
# This will give us the necessary vocoder; a rough sketch of the variable extraction is included below


# 2. Convert the JAX T5 weights to PyTorch using the transformers conversion script
# This will give us the necessary encoder and decoder:
# the encoder corresponds to the note encoder and the decoder is the spectrogram decoder

# 3. Convert the Context Encoder weights to PyTorch
# The context encoder should be pretty straightforward to convert

# 4. Implement tests to make sure that the models work properly

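# Rough sketch of step 1 (an assumption-laden helper, not the final converter): collect
# the TF-Hub SoundStream variables as torch tensors. The dict keys are the raw TF variable
# names; an explicit renaming map onto the PyTorch module parameters is still needed.
def tf_soundstream_to_torch_state_dict(tf_module):
    import torch  # local import so the TF-only parts of this script keep working without torch

    state_dict = {}
    for variable in tf_module.variables:
        value = torch.tensor(variable.numpy())
        # TF stores 1D conv kernels as (width, in, out) (transposed convs as (width, out, in));
        # PyTorch expects the kernel width as the last dimension, hence the permute.
        if value.ndim == 3:
            value = value.permute(2, 1, 0).contiguous()
        state_dict[variable.name] = value
    return state_dict
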

SAMPLE_RATE = 16000
N_FFT = 1024
HOP_LENGTH = 320
WIN_LENGTH = 640
N_MEL_CHANNELS = 128
MEL_FMIN = 0.0
MEL_FMAX = int(SAMPLE_RATE // 2)
CLIP_VALUE_MIN = 1e-5
CLIP_VALUE_MAX = 1e8

MEL_BASIS = tf.signal.linear_to_mel_weight_matrix(
num_mel_bins=N_MEL_CHANNELS,
num_spectrogram_bins=N_FFT // 2 + 1,
sample_rate=SAMPLE_RATE,
lower_edge_hertz=MEL_FMIN,
upper_edge_hertz=MEL_FMAX,
)
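# MEL_BASIS has shape (N_FFT // 2 + 1, N_MEL_CHANNELS) = (513, 128), so multiplying a
# (..., n_frames, 513) magnitude spectrogram by it yields (..., n_frames, 128) mel frames.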


def calculate_spectrogram(samples):
"""Calculate mel spectrogram using the parameters the model expects."""
fft = tf.signal.stft(
samples,
frame_length=WIN_LENGTH,
frame_step=HOP_LENGTH,
fft_length=N_FFT,
window_fn=tf.signal.hann_window,
pad_end=True,
)
fft_modulus = tf.abs(fft)

output = tf.matmul(fft_modulus, MEL_BASIS)

output = tf.clip_by_value(output, clip_value_min=CLIP_VALUE_MIN, clip_value_max=CLIP_VALUE_MAX)
output = tf.math.log(output)
return output

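# With a 320-sample hop at 16 kHz, this produces 16000 / 320 = 50 mel frames per second,
# each holding N_MEL_CHANNELS = 128 log-mel values.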

# Load a music sample from the GTZAN dataset.
gtzan = tfds.load("gtzan", split="train")
# Convert an example from int to float.
samples = tf.cast(next(iter(gtzan))["audio"] / 32768, dtype=tf.float32)
# Add batch dimension.
samples = tf.expand_dims(samples, axis=0)
# Compute a mel-spectrogram.
spectrogram = calculate_spectrogram(samples)
# Reconstruct the audio from a mel-spectrogram using a SoundStream decoder.
reconstructed_samples = module(spectrogram)
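

# Hedged PyTorch counterpart of `calculate_spectrogram` (an assumption-flagged sketch, not
# used by this script): handy for checking the converted pipeline against the TF reference.
# Framing differs slightly (tf.signal.stft pads at the end, torchaudio centers frames by
# default), so expect small differences at the boundaries.
def calculate_spectrogram_torch(samples):
    """samples: float tensor of shape (batch, num_samples), values in [-1, 1]."""
    import torchaudio  # local import so the TF-only path above keeps working without torchaudio

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=N_FFT,
        win_length=WIN_LENGTH,
        hop_length=HOP_LENGTH,
        f_min=MEL_FMIN,
        f_max=MEL_FMAX,
        n_mels=N_MEL_CHANNELS,
        power=1.0,  # magnitude, matching tf.abs(fft)
        norm=None,
        mel_scale="htk",  # tf.signal.linear_to_mel_weight_matrix uses the HTK mel scale
    )
    mel = mel_transform(samples)  # (batch, n_mels, n_frames)
    mel = mel.clamp(CLIP_VALUE_MIN, CLIP_VALUE_MAX).log()
    return mel.transpose(1, 2)  # (batch, n_frames, n_mels), matching the TF function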
167 changes: 167 additions & 0 deletions src/diffusers/models/music_transformer.py
@@ -0,0 +1,167 @@
# This file will contain the classes needed to build the notes2audio pipeline:
# Note Encoder, Spectrogram Decoder and Context Encoder.


import torch
import torch.nn as nn

from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Block, T5Stack


class FiLMLayer(nn.Module):
"""A simple FiLM layer for conditioning on the diffusion time embedding."""

def __init__(self, in_channels, out_channels) -> None:
super().__init__()
        self.gamma = nn.Linear(in_channels, out_channels)  # scale
        self.beta = nn.Linear(in_channels, out_channels)  # shift

    def forward(self, hidden_states, conditioning_emb):
        """Scales and shifts the hidden states based on the conditioning embedding.

        Args:
            hidden_states (`torch.Tensor`): feature map to modulate, of shape `(batch, out_channels, height, width)`.
            conditioning_emb (`torch.Tensor`): conditioning embedding of shape `(batch, in_channels)`.

        Returns:
            `torch.Tensor`: the modulated hidden states, `hidden_states * (gamma + 1) + beta`.
        """

beta = self.beta(conditioning_emb).unsqueeze(-1).unsqueeze(-1)
gamma = self.gamma(conditioning_emb).unsqueeze(-1).unsqueeze(-1)

hidden_states = hidden_states * (gamma + 1.0) + beta
return hidden_states

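# Illustrative usage (assumed shapes, shown as a comment so nothing runs at import time):
# modulate a (batch, channels, height, width) feature map with a (batch, emb_dim)
# diffusion-time embedding.
#
#   film = FiLMLayer(in_channels=512, out_channels=256)
#   hidden = torch.randn(2, 256, 16, 16)
#   time_emb = torch.randn(2, 512)
#   hidden = film(hidden, time_emb)  # hidden * (gamma + 1) + beta, broadcast over height and width
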

class ContextEncoder(nn.Module):
def __init__(self) -> None:
super().__init__()


class NoteTokenizer(nn.Module):
def __init__(self) -> None:
super().__init__()


class NoteEncoder(nn.Module):
def __init__(self) -> None:
super().__init__()

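# Hedged sketch (not the final design): one way the empty `NoteEncoder` stub above could
# wrap a T5 encoder stack, per step 2 of the conversion plan in
# `scripts/convert_notes2audio.py`. The class name and hyper-parameters are placeholders,
# not the real notes2audio configuration.
class _NoteEncoderSketch(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        import copy

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False  # encoder-only stack
        encoder_config.use_cache = False
        self.token_embedder = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = T5Stack(encoder_config, embed_tokens=self.token_embedder)

    def forward(self, note_tokens, note_mask=None):
        # Returns contextualized note embeddings of shape (batch, length, d_model).
        return self.encoder(input_ids=note_tokens, attention_mask=note_mask).last_hidden_state
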

class SpectrogramDecoder(nn.Module):
def __init__(self) -> None:
super().__init__()


class TokenEncoder(nn.Module):
    """A stack of encoder layers.

    NOTE: this class and `ContinuousContextTransformer` below are still written in the
    Flax/JAX style of the original codebase (`jnp`, Flax `layers`, `deterministic` flags)
    and have not yet been ported to PyTorch.
    """

    config: T5Config

def __call__(self, encoder_input_tokens, encoder_inputs_mask, deterministic):
cfg = self.config

assert encoder_input_tokens.ndim == 2 # [batch, length]

seq_length = encoder_input_tokens.shape[1]
inputs_positions = jnp.arange(seq_length)[None, :]

# [batch, length] -> [batch, length, emb_dim]
x = layers.Embed(
num_embeddings=cfg.vocab_size,
features=cfg.emb_dim,
dtype=cfg.dtype,
embedding_init=nn.initializers.normal(stddev=1.0),
one_hot=True,
name="token_embedder",
)(encoder_input_tokens.astype("int32"))

x += position_encoding_layer(config=cfg, max_length=seq_length)(inputs_positions)
x = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(x, deterministic=deterministic)
x = x.astype(cfg.dtype)

for lyr in range(cfg.num_encoder_layers):
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = EncoderLayer(config=cfg, name=f"layers_{lyr}")(
inputs=x, encoder_inputs_mask=encoder_inputs_mask, deterministic=deterministic
)
x = layers.LayerNorm(dtype=cfg.dtype, name="encoder_norm")(x)
x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
return x, encoder_inputs_mask


class ContinuousContextTransformer(nn.Module):
"""An encoder-decoder Transformer model with a second audio context encoder."""

config: T5Config

def setup(self):
cfg = self.config

self.token_encoder = TokenEncoder(config=cfg)
self.continuous_encoder = ContinuousEncoder(config=cfg)
self.decoder = Decoder(config=cfg)

def encode(self, input_tokens, continuous_inputs, continuous_mask, enable_dropout=True):
"""Applies Transformer encoder-branch on the inputs."""
assert input_tokens.ndim == 2 # (batch, length)
assert continuous_inputs.ndim == 3 # (batch, length, input_dims)

tokens_mask = input_tokens > 0

tokens_encoded, tokens_mask = self.token_encoder(
encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask, deterministic=not enable_dropout
)

continuous_encoded, continuous_mask = self.continuous_encoder(
encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask, deterministic=not enable_dropout
)

return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]

def decode(self, encodings_and_masks, input_tokens, noise_time, enable_dropout=True):
"""Applies Transformer decoder-branch on encoded-input and target."""
logits = self.decoder(
encodings_and_masks=encodings_and_masks,
decoder_input_tokens=input_tokens,
decoder_noise_time=noise_time,
deterministic=not enable_dropout,
)
return logits.astype(self.config.dtype)

def __call__(
self,
encoder_input_tokens,
encoder_continuous_inputs,
encoder_continuous_mask,
decoder_input_tokens,
decoder_noise_time,
*,
enable_dropout: bool = True,
):
"""Applies Transformer model on the inputs.
Args:
encoder_input_tokens: input data to the encoder.
encoder_continuous_inputs: continuous inputs for the second encoder.
encoder_continuous_mask: mask for continuous inputs.
          decoder_input_tokens: input tokens to the decoder.
          decoder_noise_time: continuous diffusion noise time.
          enable_dropout: enables dropout if set to True.
Returns:
logits array from full transformer.
"""
encodings_and_masks = self.encode(
input_tokens=encoder_input_tokens,
continuous_inputs=encoder_continuous_inputs,
continuous_mask=encoder_continuous_mask,
enable_dropout=enable_dropout,
)

return self.decode(
encodings_and_masks=encodings_and_masks,
input_tokens=decoder_input_tokens,
noise_time=decoder_noise_time,
enable_dropout=enable_dropout,
)
132 changes: 132 additions & 0 deletions src/diffusers/models/vocoders.py
@@ -0,0 +1,132 @@
# All the vocoders used in diffusion pipelines will be implemented here.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


# DiffSound uses MelGAN as its vocoder
class MelGAN(nn.Module):
def __init__(
self,
):
super().__init__()
return


class CausalConv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)

def forward(self, x):
return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias)

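# With stride 1, left-padding by dilation * (kernel_size - 1) keeps the output length equal
# to the input length, and no output frame ever depends on future samples. For example,
# CausalConv1d(1, 4, kernel_size=7, dilation=2) maps an input of shape (1, 1, 100) to (1, 4, 100).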

class CausalConvTranspose1d(nn.ConvTranspose1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.causal_padding = (
self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1 - self.stride[0]
)

def forward(self, x, output_size=None):
if self.padding_mode != "zeros":
raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d")

assert isinstance(self.padding, tuple)
output_padding = self._output_padding(
x, output_size, self.stride, self.padding, self.kernel_size, self.dilation
)
return F.conv_transpose1d(
x, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation
)[..., : -self.causal_padding]

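# With kernel_size = 2 * stride (as used by the decoder blocks below), the transposed
# convolution upsamples the time axis by `stride`; trimming `causal_padding` frames from the
# end keeps the layer causal and the output length exactly stride * input_length.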

class SoundStreamResNet(nn.Module):
def __init__(self, in_channels, out_channels, dilation):
super().__init__()
self.dilation = dilation
        self.causal_conv = CausalConv1d(in_channels, out_channels, kernel_size=7, dilation=dilation)
        self.conv_1d = nn.Conv1d(out_channels, out_channels, kernel_size=1)
self.act = nn.ELU()

def forward(self, hidden_states):
residuals = hidden_states
hidden_states = self.causal_conv(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.conv_1d(hidden_states)
return residuals + hidden_states


class SoundStreamDecoderBlock(nn.Module):
def __init__(self, out_channels, stride):
super().__init__()
self.project_in = CausalConvTranspose1d(
in_channels=2 * out_channels, out_channels=out_channels, kernel_size=2 * stride, stride=stride
)
self.act = nn.ELU()

self.resnet_blocks = nn.ModuleList(
            [SoundStreamResNet(out_channels, out_channels, dilation=3**rate) for rate in range(3)]  # dilations 1, 3, 9
)

def forward(self, hidden_states):
hidden_states = self.project_in(hidden_states)
hidden_states = self.act(hidden_states)
for resnet in self.resnet_blocks:
hidden_states = resnet(hidden_states)
return hidden_states


# notes2audio uses SoundStream
class SoundStreamVocoder(ModelMixin, ConfigMixin):
"""Residual VQ VAE model from `SoundStream: An End-to-End Neural Audio Codec`

Args:
in_channels (`int`): number of input channels. It corresponds to the number of spectrogram features
that are passed to the decoder to compute the raw audio.
ConfigMixin (_type_): _description_
"""

def __init__(self, in_channels=8, out_channels=1, strides=[8, 5, 4, 2], channel_factors=[8, 4, 2, 1]):
super().__init__()
self.act = nn.ELU()
self.bottleneck = CausalConv1d(in_channels=in_channels, out_channels=16 * out_channels, kernel_size=7)
self.decoder_blocks = nn.ModuleList(
SoundStreamDecoderBlock(out_channels=out_channels * channel_factors[i], stride=strides[i])
for i in range(4)
)
self.last_layer_conv = CausalConv1d(in_channels=out_channels, out_channels=1, kernel_size=7)
        # Amount of dither noise added to the features before decoding; 0.0 disables it.
        self._decode_dither_amount = 0.0

    def decode(self, features):
        """Decodes features to audio.
        Args:
            features: Mel spectrograms, shape [batch, n_frames, n_dims].
        Returns:
            audio: Shape [batch, n_frames * hop_size]
        """
        if self._decode_dither_amount > 0:
            features = features + torch.randn_like(features) * self._decode_dither_amount

        # Conv1d layers expect channels-first inputs: [batch, n_dims, n_frames].
        hidden_states = self.bottleneck(features.transpose(1, 2))
        hidden_states = self.act(hidden_states)
        for layer in self.decoder_blocks:
            hidden_states = layer(hidden_states)
            hidden_states = self.act(hidden_states)

        audio = self.last_layer_conv(hidden_states)

        # Drop the singleton channel dimension: [batch, 1, T] -> [batch, T].
        return audio.squeeze(1)

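# Hedged usage sketch (as comments; the default `in_channels=8` does not match 128-dim mel
# features, and the real checkpoint configuration may differ):
#
#   vocoder = SoundStreamVocoder(in_channels=128)
#   mel = torch.randn(1, 200, 128)  # (batch, n_frames, n_mel_channels)
#   audio = vocoder.decode(mel)     # (1, 200 * 320): strides 8 * 5 * 4 * 2 give a hop size of 320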

# TODO @Arthur: DiffSinger uses HiFi-GAN as its vocoder
class HiFiGAN(nn.Module):
def __init__(
self,
):
super().__init__()
return
17 changes: 17 additions & 0 deletions src/diffusers/pipelines/notes2audio/README.md
@@ -0,0 +1,17 @@
# TODO: Follow the Stable Diffusion pipeline card


Goal of the implementation:

```python

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("magenta/notes2audio_base_with_context")

midi_setup_file = "path/to/midi_file.midi"
audio = pipeline(midi_setup_file).sample[0]
```
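
Assuming the pipeline returns a raw waveform at SoundStream's 16 kHz sample rate (as in the conversion script), the result could be written to disk roughly as below; the `.sample` attribute and the exact output layout are still assumptions at this stage.

```python
import numpy as np
from scipy.io import wavfile

audio = pipeline(midi_setup_file).sample[0]
wavfile.write("notes2audio_sample.wav", 16000, np.asarray(audio, dtype=np.float32))
```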