
Commit a5eb7f4

[Examples] add speech to image pipeline example (#897)
* First draft
* created the SpeechToImagePipeline class
* Corrected speech_to_image_diffusion.py style
* Added safety checker
* Corrected style
* Adding examples to README
1 parent ce7d966 commit a5eb7f4

2 files changed: +309 -0 lines changed
examples/community/README.md

Lines changed: 48 additions & 0 deletions
@@ -13,6 +13,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
| Speech to Image | Using automatic speech recognition to transcribe speech to text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) |

To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.

```py

@@ -216,3 +217,50 @@ pipe.text2img(prompt,negative_prompt=neg_prompt, width=512, height=512, max_embe
```

If you see `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`, do not worry, it is normal.

### Speech to Image

The following code can generate an image from an audio sample using the pre-trained OpenAI whisper-small model and Stable Diffusion.

```Python
import torch

import matplotlib.pyplot as plt
from datasets import load_dataset
from diffusers import DiffusionPipeline
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
)


device = "cuda" if torch.cuda.is_available() else "cpu"

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

audio_sample = ds[3]

text = audio_sample["text"].lower()
speech_data = audio_sample["audio"]["array"]

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

diffuser_pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="speech_to_image_diffusion",
    speech_model=model,
    speech_processor=processor,
    revision="fp16",
    torch_dtype=torch.float16,
)

diffuser_pipeline.enable_attention_slicing()
diffuser_pipeline = diffuser_pipeline.to(device)

output = diffuser_pipeline(speech_data)
plt.imshow(output.images[0])
```
This example produces the following image:

![image](https://user-images.githubusercontent.com/45072645/196901736-77d9c6fc-63ee-4072-90b0-dc8b903d63e3.png)
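
If you want to run the pipeline on your own recording instead of the LibriSpeech dummy sample, something along the following lines should work. This is a minimal sketch rather than part of the commit: it assumes `librosa` is installed, uses `my_recording.wav` as a placeholder path to a mono audio file, and reuses the `diffuser_pipeline` object created above.

```Python
import librosa

# Load the recording as a mono float array resampled to 16 kHz,
# the rate the Whisper feature extractor expects by default.
speech_data, sampling_rate = librosa.load("my_recording.wav", sr=16_000)

# The pipeline's __call__ accepts the raw waveform and its sampling rate.
output = diffuser_pipeline(speech_data, sampling_rate=sampling_rate)
output.images[0].save("speech_to_image_output.png")
```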
examples/community/speech_to_image_diffusion.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
```python
import inspect
from typing import Callable, List, Optional, Union

import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import logging
from transformers import (
    CLIPFeatureExtractor,
    CLIPTextModel,
    CLIPTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SpeechToImagePipeline(DiffusionPipeline):
    def __init__(
        self,
        speech_model: WhisperForConditionalGeneration,
        speech_processor: WhisperProcessor,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()

        if safety_checker is None:
            logger.warn(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            speech_model=speech_model,
            speech_processor=speech_processor,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        # Compute attention in slices to reduce peak memory usage, at some speed cost.
        if slice_size == "auto":
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        audio,
        sampling_rate=16_000,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # transcribe the raw audio with Whisper; the transcription becomes the Stable Diffusion prompt
        inputs = self.speech_processor.feature_extractor(
            audio, return_tensors="pt", sampling_rate=sampling_rate
        ).input_features.to(self.device)
        predicted_ids = self.speech_model.generate(inputs, max_length=480_000)

        prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[
            0
        ]

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""]
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not exist on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return image

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
```
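
For readers who want to check what prompt their audio will produce before running the diffusion loop, the transcription step at the start of `__call__` can be reproduced on its own. This is a minimal sketch under the same assumptions as the README example (the `openai/whisper-small` checkpoint and the LibriSpeech dummy dataset), not additional code from this commit:

```python
# Sketch: run only the Whisper transcription that SpeechToImagePipeline performs
# at the start of __call__, and print the resulting Stable Diffusion prompt.
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = ds[3]["audio"]

# Same calls as in the pipeline: feature extraction, generation, decoding.
inputs = processor.feature_extractor(
    audio["array"], return_tensors="pt", sampling_rate=audio["sampling_rate"]
).input_features.to(device)
predicted_ids = model.generate(inputs, max_length=480_000)
prompt = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
print(prompt)
```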
