Commit 3e0b8f1

fix up prompt embeds param for SD upscaling ONNX pipeline
1 parent d60e02a commit 3e0b8f1

File tree

1 file changed (+53, -35 lines)
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

Lines changed: 53 additions & 35 deletions
@@ -106,7 +106,10 @@ def __call__(
                 [`schedulers.DDIMScheduler`], will be ignored for others.
             generator (`np.random.RandomState`, *optional*):
                 A np.random.RandomState to make generation deterministic.
-            latents TODO
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
             prompt_embeds (`np.ndarray`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -271,45 +274,59 @@ def decode_latents(self, latents):
         image = image.transpose((0, 2, 3, 1))
         return image

-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
+    def _encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device,
+        num_images_per_prompt: Optional[int],
+        do_classifier_free_guidance: bool,
+        negative_prompt: Optional[str],
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]

-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
             )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )

-        # if hasattr(text_inputs, "attention_mask"):
-        #     attention_mask = text_inputs.attention_mask.to(device)
-        # else:
-        #     attention_mask = None
-
-        # no positional arguments to text_encoder
-        text_embeddings = self.text_encoder(
-            input_ids=text_input_ids.int().to(device),
-            # attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
+            # no positional arguments to text_encoder
+            prompt_embeds = self.text_encoder(
+                input_ids=text_input_ids.int().to(device),
+                # attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]

-        bs_embed, seq_len, _ = text_embeddings.shape
+        bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
-        text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+        prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1)

         # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -349,6 +366,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             )
             uncond_embeddings = uncond_embeddings[0]

+        if do_classifier_free_guidance:
             seq_len = uncond_embeddings.shape[1]
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
@@ -357,6 +375,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+            prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds])

-        return text_embeddings
+        return prompt_embeds
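
For reference, a minimal usage sketch of what this change enables (not part of the commit): pre-computed text embeddings can be passed to the ONNX upscaling pipeline via `prompt_embeds` instead of a string prompt, per the updated docstring. The model path, input image, and the `prompt=None` call pattern below are illustrative assumptions; the tokenizer and text-encoder calls mirror what `_encode_prompt` does internally.

# Sketch only: assumes a local ONNX export of the x4 upscaler at this path.
from diffusers import OnnxStableDiffusionUpscalePipeline
from PIL import Image

pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained("./sd-x4-upscaler-onnx")  # placeholder path
low_res = Image.open("low_res.png").convert("RGB").resize((128, 128))

# Tokenize and encode the prompt the same way `_encode_prompt` does internally.
text_inputs = pipe.tokenizer(
    "a white cat",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
prompt_embeds = pipe.text_encoder(input_ids=text_inputs.input_ids.int())[0]

# With this commit, the embeddings can be supplied in place of a string prompt
# (assuming __call__ accepts prompt=None together with prompt_embeds).
image = pipe(prompt=None, image=low_res, prompt_embeds=prompt_embeds).images[0]
image.save("upscaled.png")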
