
Commit 2ef9bdd

kashif, patrickvonplaten, and anton-l authored
Music Spectrogram diffusion pipeline (#1044)
* initial TokenEncoder and ContinuousEncoder
* initial modules
* added ContinuousContextTransformer
* fix copy paste error
* use numpy for get_sequence_length
* initial terminal relative positional encodings
* fix weights keys
* fix assert
* cross attend style: concat encodings
* make style
* concat once
* fix formatting
* Initial SpectrogramPipeline
* fix input_tokens
* make style
* added mel output
* ignore weights for config
* move mel to numpy
* import pipeline
* fix class names and import
* moved models to models folder
* import ContinuousContextTransformer and SpectrogramDiffusionPipeline
* initial spec diffusion converstion script
* renamed config to t5config
* added weight loading
* use arguments instead of t5config
* broadcast noise time to batch dim
* fix call
* added scale_to_features
* fix weights
* transpose laynorm weight
* scale is a vector
* scale the query outputs
* added comment
* undo scaling
* undo depth_scaling
* inital get_extended_attention_mask
* attention_mask is none in self-attention
* cleanup
* manually invert attention
* nn.linear need bias=False
* added T5LayerFFCond
* remove to fix conflict
* make style and dummy
* remove unsed variables
* remove predict_epsilon
* Move accelerate to a soft-dependency (#1134)
* finish
* finish
* Update src/diffusers/modeling_utils.py
* Update src/diffusers/pipeline_utils.py Co-authored-by: Anton Lozhkov <[email protected]>
* more fixes
* fix Co-authored-by: Anton Lozhkov <[email protected]>
* fix order
* added initial midi to note token data pipeline
* added int to int tokenizer
* remove duplicate
* added logic for segments
* add melgan to pipeline
* move autoregressive gen into pipeline
* added note_representation_processor_chain
* fix dtypes
* remove immutabledict req
* initial doc
* use np.where
* require note_seq
* fix typo
* update dependency
* added note-seq to test
* added is_note_seq_available
* fix import
* added toc
* added example usage
* undo for now
* moved docs
* fix merge
* fix imports
* predict first segment
* avoid un-needed copy to and from cpu
* make style
* Copyright
* fix style
* add test and fix inference steps
* remove bogus files
* reorder models
* up
* remove transformers dependency
* make work with diffusers cross attention
* clean more
* remove @
* improve further
* up
* uP
* Apply suggestions from code review
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py
* loop over all tokens
* make style
* Added a section on the model
* fix formatting
* grammer
* formatting
* make fix-copies
* Update src/diffusers/pipelines/__init__.py Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py Co-authored-by: Patrick von Platen <[email protected]>
* added callback ad optional ionnx
* do not squeeze batch dim
* clean up more
* upload
* convert jax to nnumpy
* make style
* fix warning
* make fix-copies
* fix warning
* add initial fast tests
* add initial pipeline_params
* eval mode due to dropout
* skip batch tests as pipeline runs on a single file
* make style
* fix relative path
* fix doc tests
* Update src/diffusers/models/t5_film_transformer.py Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/models/t5_film_transformer.py Co-authored-by: Patrick von Platen <[email protected]>
* Update docs/source/en/api/pipelines/spectrogram_diffusion.mdx Co-authored-by: Patrick von Platen <[email protected]>
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py Co-authored-by: Patrick von Platen <[email protected]>
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py Co-authored-by: Patrick von Platen <[email protected]>
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py Co-authored-by: Patrick von Platen <[email protected]>
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py Co-authored-by: Patrick von Platen <[email protected]>
* add MidiProcessor
* format
* fix org
* Apply suggestions from code review
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py
* make style
* pin protobuf to <4
* fix formatting
* white space
* tensorboard needs protobuf

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 14e3a28 commit 2ef9bdd

24 files changed: +2003 −1 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -158,6 +158,8 @@
     title: Score SDE VE
   - local: api/pipelines/semantic_stable_diffusion
     title: Semantic Guidance
+  - local: api/pipelines/spectrogram_diffusion
+    title: "Spectrogram Diffusion"
   - sections:
     - local: api/pipelines/stable_diffusion/overview
       title: Overview

docs/source/en/api/pipelines/spectrogram_diffusion.mdx

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Multi-instrument Music Synthesis with Spectrogram Diffusion

## Overview

[Spectrogram Diffusion](https://arxiv.org/abs/2206.05408) by Curtis Hawthorne, Ian Simon, Adam Roberts, Neil Zeghidour, Josh Gardner, Ethan Manilow, and Jesse Engel.

An ideal music synthesizer should be both interactive and expressive, generating high-fidelity audio in realtime for arbitrary combinations of instruments and notes. Recent neural synthesizers have exhibited a tradeoff between domain-specific models that offer detailed control of only specific instruments, or raw waveform models that can train on any music but with minimal control and slow generation. In this work, we focus on a middle ground of neural synthesizers that can generate audio from MIDI sequences with arbitrary combinations of instruments in realtime. This enables training on a wide range of transcription datasets with a single model, which in turn offers note-level control of composition and instrumentation across a wide range of instruments. We use a simple two-stage process: MIDI to spectrograms with an encoder-decoder Transformer, then spectrograms to audio with a generative adversarial network (GAN) spectrogram inverter. We compare training the decoder as an autoregressive model and as a Denoising Diffusion Probabilistic Model (DDPM) and find that the DDPM approach is superior both qualitatively and as measured by audio reconstruction and Fréchet distance metrics. Given the interactivity and generality of this approach, we find this to be a promising first step towards interactive and expressive neural synthesis for arbitrary combinations of instruments and notes.

The original codebase of this implementation can be found at [magenta/music-spectrogram-diffusion](https://github.com/magenta/music-spectrogram-diffusion).

## Model

![img](https://storage.googleapis.com/music-synthesis-with-spectrogram-diffusion/architecture.png)

As depicted above, the model takes a MIDI file as input and tokenizes it into a sequence of 5-second intervals. Each tokenized interval, together with positional encodings, is passed through the Note Encoder, and its representation is concatenated with the representation of the previous window's generated spectrogram obtained via the Context Encoder. For the initial 5-second window this context is set to zero. The resulting context is then used as conditioning to sample the denoised spectrogram for the current MIDI window; this spectrogram is appended to the final output and also serves as the context for the next MIDI window. The process repeats until all MIDI inputs have been processed. Finally, a MelGAN decoder converts the potentially long spectrogram into audio, which is the final output of this pipeline.
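
The windowed, autoregressive structure of this loop can be sketched roughly as follows. This is only an illustration of the data flow, not the pipeline's actual implementation: `encode_notes`, `encode_context`, `denoise_spectrogram`, and `melgan_decode` are hypothetical stand-ins for the internal encoder, diffusion decoder, and vocoder calls, and the shapes are placeholders.

```python
import numpy as np


def generate(midi_segments, encode_notes, encode_context, denoise_spectrogram, melgan_decode,
             segment_frames=256, n_mel_bins=128):
    """Illustrative only: one denoised spectrogram window per ~5-second MIDI segment."""
    spectrograms = []
    # The context for the very first window is all zeros.
    previous_window = np.zeros((segment_frames, n_mel_bins), dtype=np.float32)

    for segment_tokens in midi_segments:
        note_encoding = encode_notes(segment_tokens)        # note tokens + positional encodings
        context_encoding = encode_context(previous_window)  # previous window's spectrogram
        conditioning = np.concatenate([note_encoding, context_encoding], axis=0)

        # The diffusion decoder denoises a spectrogram for this window given the conditioning.
        previous_window = denoise_spectrogram(conditioning)
        spectrograms.append(previous_window)

    # A MelGAN vocoder turns the concatenated spectrogram into the final waveform.
    return melgan_decode(np.concatenate(spectrograms, axis=0))
```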

## Available Pipelines:

| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_spectrogram_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py) | *Unconditional Audio Generation* | - |

## Example usage

```python
from diffusers import SpectrogramDiffusionPipeline, MidiProcessor

pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
pipe = pipe.to("cuda")
processor = MidiProcessor()

# Download the MIDI file first, e.g.: wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid
output = pipe(processor("beethoven_hammerklavier_2.mid"))

audio = output.audios[0]
```
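
The returned `audio` is a NumPy waveform. One way to listen to it is to write it to a WAV file, for example with `scipy`; the 16 kHz sample rate below is an assumption about this model's output rate.

```python
import scipy.io.wavfile

# Save the generated waveform (16 kHz assumed for this model).
scipy.io.wavfile.write("beethoven_hammerklavier_2.wav", rate=16000, data=audio)
```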

## SpectrogramDiffusionPipeline
[[autodoc]] SpectrogramDiffusionPipeline
    - all
    - __call__
Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
#!/usr/bin/env python3
import argparse
import os

import jax as jnp
import numpy as onp
import torch
import torch.nn as nn
from music_spectrogram_diffusion import inference
from t5x import checkpoints

from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline
from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder


MODEL = "base_with_context"


def load_notes_encoder(weights, model):
    model.token_embedder.weight = nn.Parameter(torch.FloatTensor(weights["token_embedder"]["embedding"]))
    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )
    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        attention_weights = ly_weight["attention"]
        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))
    return model


def load_continuous_encoder(weights, model):
    model.input_proj.weight = nn.Parameter(torch.FloatTensor(weights["input_proj"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    for lyr_num, lyr in enumerate(model.encoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        attention_weights = ly_weight["attention"]

        lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_attention_layer_norm"]["scale"])
        )

        lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))

    model.layer_norm.weight = nn.Parameter(torch.FloatTensor(weights["encoder_norm"]["scale"]))

    return model


def load_decoder(weights, model):
    model.conditioning_emb[0].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense0"]["kernel"].T))
    model.conditioning_emb[2].weight = nn.Parameter(torch.FloatTensor(weights["time_emb_dense1"]["kernel"].T))

    model.position_encoding.weight = nn.Parameter(
        torch.FloatTensor(weights["Embed_0"]["embedding"]), requires_grad=False
    )

    model.continuous_inputs_projection.weight = nn.Parameter(
        torch.FloatTensor(weights["continuous_inputs_projection"]["kernel"].T)
    )

    for lyr_num, lyr in enumerate(model.decoders):
        ly_weight = weights[f"layers_{lyr_num}"]
        lyr.layer[0].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
        )

        lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
        )

        attention_weights = ly_weight["self_attention"]
        lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))

        attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
        lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.FloatTensor(attention_weights["query"]["kernel"].T))
        lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.FloatTensor(attention_weights["key"]["kernel"].T))
        lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.FloatTensor(attention_weights["value"]["kernel"].T))
        lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.FloatTensor(attention_weights["out"]["kernel"].T))
        lyr.layer[1].layer_norm.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
        )

        lyr.layer[2].layer_norm.weight = nn.Parameter(torch.FloatTensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
        lyr.layer[2].film.scale_bias.weight = nn.Parameter(
            torch.FloatTensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
        )
        lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
        lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.FloatTensor(ly_weight["mlp"]["wo"]["kernel"].T))

    model.decoder_norm.weight = nn.Parameter(torch.FloatTensor(weights["decoder_norm"]["scale"]))

    model.spec_out.weight = nn.Parameter(torch.FloatTensor(weights["spec_out_dense"]["kernel"].T))

    return model
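
# Note on the .T transposes above: Flax/T5X Dense kernels are stored as
# [in_features, out_features], while torch.nn.Linear.weight expects
# [out_features, in_features], so every kernel is transposed when copied.
# LayerNorm "scale" vectors and embedding tables are copied without transposing.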


def main(args):
    t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path)
    t5_checkpoint = jnp.tree_util.tree_map(onp.array, t5_checkpoint)

    gin_overrides = [
        "from __gin__ import dynamic_registration",
        "from music_spectrogram_diffusion.models.diffusion import diffusion_utils",
        "diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0",
        "diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()",
    ]

    gin_file = os.path.join(args.checkpoint_path, "..", "config.gin")
    gin_config = inference.parse_training_gin_file(gin_file, gin_overrides)
    synth_model = inference.InferenceModel(args.checkpoint_path, gin_config)

    scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large")

    notes_encoder = SpectrogramNotesEncoder(
        max_length=synth_model.sequence_length["inputs"],
        vocab_size=synth_model.model.module.config.vocab_size,
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    continuous_encoder = SpectrogramContEncoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_context_length=synth_model.sequence_length["targets_context"],
        d_model=synth_model.model.module.config.emb_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
        num_layers=synth_model.model.module.config.num_encoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        feed_forward_proj="gated-gelu",
    )

    decoder = T5FilmDecoder(
        input_dims=synth_model.audio_codec.n_dims,
        targets_length=synth_model.sequence_length["targets_context"],
        max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time,
        d_model=synth_model.model.module.config.emb_dim,
        num_layers=synth_model.model.module.config.num_decoder_layers,
        num_heads=synth_model.model.module.config.num_heads,
        d_kv=synth_model.model.module.config.head_dim,
        d_ff=synth_model.model.module.config.mlp_dim,
        dropout_rate=synth_model.model.module.config.dropout_rate,
    )

    notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder)
    continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder)
    decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder)

    melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder")

    pipe = SpectrogramDiffusionPipeline(
        notes_encoder=notes_encoder,
        continuous_encoder=continuous_encoder,
        decoder=decoder,
        scheduler=scheduler,
        melgan=melgan,
    )
    if args.save:
        pipe.save_pretrained(args.output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.")
    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
    )
    parser.add_argument(
        "--checkpoint_path",
        default=f"{MODEL}/checkpoint_500000",
        type=str,
        required=False,
        help="Path to the original jax model checkpoint.",
    )
    args = parser.parse_args()

    main(args)
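
Once converted and saved, the resulting directory can be loaded back like any other diffusers pipeline. A quick sanity-check sketch (the path is simply whatever was passed as `--output_path`; moving to CUDA is optional):

```python
from diffusers import SpectrogramDiffusionPipeline

# Load the pipeline produced by the conversion script above (example path).
pipe = SpectrogramDiffusionPipeline.from_pretrained("./music_spectrogram_diffusion")
pipe = pipe.to("cuda")
```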

setup.py

Lines changed: 4 additions & 1 deletion
@@ -95,8 +95,10 @@
     "Jinja2",
     "k-diffusion>=0.0.12",
     "librosa",
+    "note-seq",
     "numpy",
     "parameterized",
+    "protobuf>=3.20.3,<4",
     "pytest",
     "pytest-timeout",
     "pytest-xdist",
@@ -182,13 +184,14 @@ def run(self):
 extras = {}
 extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder")
 extras["docs"] = deps_list("hf-doc-builder")
-extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
+extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
 extras["test"] = deps_list(
     "compel",
     "datasets",
     "Jinja2",
     "k-diffusion",
     "librosa",
+    "note-seq",
     "parameterized",
     "pytest",
     "pytest-timeout",

src/diffusers/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -8,6 +8,7 @@
     is_k_diffusion_available,
     is_k_diffusion_version,
     is_librosa_available,
+    is_note_seq_available,
     is_onnx_available,
     is_scipy_available,
     is_torch_available,
@@ -37,6 +38,7 @@
         ControlNetModel,
         ModelMixin,
         PriorTransformer,
+        T5FilmDecoder,
         Transformer2DModel,
         UNet1DModel,
         UNet2DConditionModel,
@@ -172,6 +174,14 @@
 else:
     from .pipelines import AudioDiffusionPipeline, Mel

+try:
+    if not (is_torch_available() and is_note_seq_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_torch_and_note_seq_objects import *  # noqa F403
+else:
+    from .pipelines import SpectrogramDiffusionPipeline
+
 try:
     if not is_flax_available():
         raise OptionalDependencyNotAvailable()
@@ -205,3 +215,11 @@
         FlaxStableDiffusionInpaintPipeline,
         FlaxStableDiffusionPipeline,
     )
+
+try:
+    if not (is_note_seq_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_note_seq_objects import *  # noqa F403
+else:
+    from .pipelines import MidiProcessor

src/diffusers/dependency_versions_table.py

Lines changed: 2 additions & 0 deletions
@@ -19,8 +19,10 @@
     "Jinja2": "Jinja2",
     "k-diffusion": "k-diffusion>=0.0.12",
     "librosa": "librosa",
+    "note-seq": "note-seq",
     "numpy": "numpy",
     "parameterized": "parameterized",
+    "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from .dual_transformer_2d import DualTransformer2DModel
 from .modeling_utils import ModelMixin
 from .prior_transformer import PriorTransformer
+from .t5_film_transformer import T5FilmDecoder
 from .transformer_2d import Transformer2DModel
 from .unet_1d import UNet1DModel
 from .unet_2d import UNet2DModel
