
Commit d0d5beb

Commit message: [wip] vq diffusion classifier free sampling
1 parent: 57525bb

File tree: 1 file changed (+73, -30 lines)


src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py

Lines changed: 73 additions & 30 deletions
@@ -75,11 +75,69 @@ def __init__(
             scheduler=scheduler,
         )
 
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output:
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            uncond_tokens = [""] * batch_size
+
+            max_length = text_input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes.
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
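
Worth noting for reviewers: when guidance is enabled, the value `_encode_prompt` returns is an `[unconditional, conditional]` batch of L2-normalized CLIP hidden states, duplicated once per generated image. Below is a minimal, self-contained sketch with random stand-in tensors (the shapes are illustrative, not the real CLIP dimensions) showing the layout:

import torch

batch_size, num_images_per_prompt, seq_len, dim = 2, 3, 77, 512

# stand-ins for the CLIP text encoder's last hidden states for the prompt and for ""
text_embeddings = torch.randn(batch_size, seq_len, dim)
uncond_embeddings = torch.randn(batch_size, seq_len, dim)

# VQ-Diffusion-specific L2 normalization of the last hidden state
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# one copy per generated image
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

# unconditional embeddings first, conditional second, in a single batch
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
assert text_embeddings.shape == (2 * batch_size * num_images_per_prompt, seq_len, dim)

The unconditional-first ordering is what lets the sampling loop recover the two halves with `model_output.chunk(2)` further down.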
@@ -137,6 +195,12 @@ def __call__(
 
         batch_size = batch_size * num_images_per_prompt
 
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(
+            prompt, num_images_per_prompt, do_classifier_free_guidance
+        )
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +209,6 @@
                 f" {type(callback_steps)}."
             )
 
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output:
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it
 
         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +233,17 @@
         sample = latents
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(latent_model_input, encoder_hidden_states=text_embeddings, timestep=t).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
 
             model_output = self.truncate(model_output, truncation_rate)
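Unlike the continuous-latent pipelines, `model_output` here is `log_p_x_0`: a log-probability over the codebook classes for every latent pixel. After the usual `uncond + scale * (text - uncond)` combination the result is no longer a normalized distribution, which is why the hunk subtracts the `logsumexp` over the class dimension. A minimal sketch of that step with stand-in tensors (class and pixel counts are illustrative):

import torch

guidance_scale = 5.0
batch, num_classes, num_latent_pixels = 2, 10, 16

# stand-in for the transformer output on the doubled batch: log p(x_0) per latent pixel
model_output = torch.log_softmax(torch.randn(2 * batch, num_classes, num_latent_pixels), dim=1)

# first half was conditioned on "", second half on the prompt
model_output_uncond, model_output_text = model_output.chunk(2)
model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)

# renormalize over classes so every latent pixel is a valid log-distribution again
model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
assert torch.allclose(model_output.exp().sum(dim=1), torch.ones(batch, num_latent_pixels), atol=1e-4)

With guidance_scale at or below 1.0, `do_classifier_free_guidance` is False, the batch is not doubled, and this whole branch is skipped.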

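For context, a sketch of how the new argument would be used once this lands. The checkpoint name, the prompt, and the `images` output attribute are assumptions based on the existing VQ-Diffusion pipeline, not something this commit pins down:

import torch
from diffusers import VQDiffusionPipeline

# assumed checkpoint; any VQ-Diffusion checkpoint with the same components should work
pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq").to("cuda")

torch.manual_seed(0)
output = pipe(
    "teddy bear playing in the pool",
    num_inference_steps=100,
    guidance_scale=5.0,  # > 1.0 takes the new classifier-free guidance path
    truncation_rate=1.0,
    num_images_per_prompt=1,
)
output.images[0].save("teddy_bear.png")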