@@ -106,7 +106,10 @@ def __call__(
                 [`schedulers.DDIMScheduler`], will be ignored for others.
             generator (`np.random.RandomState`, *optional*):
                 A np.random.RandomState to make generation deterministic.
-            latents TODO
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
             prompt_embeds (`np.ndarray`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
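For readers skimming the docstring change, here is a minimal usage sketch of the two documented arguments. It assumes the surrounding class is an ONNX Stable Diffusion pipeline loaded through `diffusers`; the pipeline class, model id, latent shape `(1, 4, 64, 64)`, and NumPy dtype are assumptions, not part of this diff.

```python
import numpy as np
from diffusers import OnnxStableDiffusionPipeline  # assumed pipeline class

# Hypothetical model id; any ONNX-exported Stable Diffusion checkpoint would be used the same way.
pipe = OnnxStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="onnx")

# Deterministic generation: a seeded RandomState plus pre-generated Gaussian latents.
rng = np.random.RandomState(0)
latents = rng.standard_normal((1, 4, 64, 64)).astype(np.float32)  # (batch, channels, height/8, width/8)

# Reusing the same latents with different prompts tweaks the content while
# keeping the initial noise pattern (and hence the overall composition) fixed.
image_a = pipe("a photo of an astronaut", generator=rng, latents=latents).images[0]
image_b = pipe("a photo of a cosmonaut", generator=rng, latents=latents).images[0]
```

Whether `latents` should be a NumPy array or a torch tensor depends on the concrete pipeline; the `np.float32` array above is an assumption.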
@@ -271,45 +274,59 @@ def decode_latents(self, latents):
         image = image.transpose((0, 2, 3, 1))
         return image

-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
+    def _encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        device,
+        num_images_per_prompt: Optional[int],
+        do_classifier_free_guidance: bool,
+        negative_prompt: Optional[str],
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]

-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
             )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )

-        # if hasattr(text_inputs, "attention_mask"):
-        #     attention_mask = text_inputs.attention_mask.to(device)
-        # else:
-        #     attention_mask = None
-
-        # no positional arguments to text_encoder
-        text_embeddings = self.text_encoder(
-            input_ids=text_input_ids.int().to(device),
-            # attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
+            # no positional arguments to text_encoder
+            prompt_embeds = self.text_encoder(
+                input_ids=text_input_ids.int().to(device),
+                # attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]

-        bs_embed, seq_len, _ = text_embeddings.shape
+        bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
-        text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+        prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1)

         # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
             uncond_tokens: List[str]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
@@ -349,6 +366,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             )
             uncond_embeddings = uncond_embeddings[0]

+        if do_classifier_free_guidance:
             seq_len = uncond_embeddings.shape[1]
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
@@ -357,6 +375,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
-            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+            prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds])

-        return text_embeddings
+        return prompt_embeds
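The `np.concatenate([uncond_embeddings, prompt_embeds])` return value packs the unconditional and text embeddings into one doubled batch so the UNet needs only one forward pass per denoising step. The consumer side is not shown in this diff, but the conventional classifier-free-guidance combination looks roughly like the self-contained sketch below; the shapes and `guidance_scale` value are placeholder assumptions.

```python
import numpy as np

# Placeholder stand-ins: in the real pipeline, noise_pred comes from running the UNet
# on latents duplicated to match the doubled [uncond, text] embedding batch.
guidance_scale = 7.5
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)

# First half of the batch corresponds to the unconditional embeddings,
# second half to the text embeddings.
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # (1, 4, 64, 64)
```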