@@ -175,6 +175,7 @@ def __call__(
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
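The new `num_images_per_prompt` argument multiplies the effective batch. A minimal usage sketch — `pipe` as a loaded instance of this pipeline and an `images` list on the returned output object are assumptions for illustration, not part of this commit:

```python
# Hypothetical sketch: one prompt, four generations of it in a single call.
output = pipe(
    prompt="a photograph of an astronaut riding a horse",
    num_inference_steps=50,
    guidance_scale=7.5,
    num_images_per_prompt=4,  # new in this commit; defaults to 1
)
assert len(output.images) == 4  # len(prompts) * num_images_per_prompt
```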
@@ -203,6 +204,8 @@ def __call__(
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

         if clip_guidance_scale > 0:
             if clip_prompt is not None:
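For context on what `repeat_interleave` does here, a self-contained sketch with toy shapes (not from the commit):

```python
import torch

# Toy text embeddings for two prompts: (batch, seq_len, hidden_dim)
text_embeddings = torch.randn(2, 77, 768)
num_images_per_prompt = 3

out = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
print(out.shape)  # torch.Size([6, 77, 768])
# Rows stay grouped by prompt: [p0, p0, p0, p1, p1, p1].
# tensor.repeat(3, 1, 1) would instead cycle them: [p0, p1, p0, p1, p0, p1].
```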
@@ -217,6 +220,8 @@ def __call__(
                 clip_text_input = text_input.input_ids.to(self.device)
             text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
             text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+            # duplicate CLIP text embeddings for each generation per prompt
+            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
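The `guidance_scale` referenced in that comment enters later in the denoising loop via the standard classifier-free guidance combination; roughly (a sketch of the standard formula, not code added by this hunk):

```python
# The doubled batch yields an unconditional and a conditional noise
# prediction, blended by guidance_scale (w in the Imagen paper, eq. 2).
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```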
@@ -225,10 +230,10 @@ def __call__(
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
             max_length = text_input.input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-            )
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)

         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
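Note the refactor: the empty prompt is now tokenized and encoded once instead of `batch_size` times, and the single unconditional embedding is repeated to `batch_size * num_images_per_prompt` rows so it lines up with the conditional half. A sketch of the shape constraint behind the concatenation mentioned in the comment:

```python
# Both halves of the classifier-free-guidance batch must match before
# they are concatenated into one forward pass through the UNet.
assert uncond_embeddings.shape == text_embeddings.shape
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
```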
@@ -240,18 +245,20 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)

         # set timesteps
         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
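The `mps` branch exists because seeded `torch.randn` was not supported on that backend at the time; drawing the noise on CPU with the same generator and moving it over keeps results reproducible across devices. A standalone sketch of the same logic (the helper name is illustrative):

```python
import torch

def make_initial_latents(shape, generator, device, dtype):
    # Sketch of the branch above: sample seeded noise on CPU when the
    # target is mps, then transfer; otherwise sample directly on device.
    if device.type == "mps":
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)
```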
@@ -261,17 +268,17 @@ def __call__(

         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to the correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)

-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
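The `isinstance(self.scheduler, LMSDiscreteScheduler)` special cases can go because the scheduler API now exposes the scaling uniformly: `init_noise_sigma` and `scale_model_input` are effectively no-ops (1.0 and identity) for DDIM/PNDM, and apply the sigma scaling for K-LMS. A rough sketch, assuming a diffusers version with this API and default scheduler configs:

```python
import torch
from diffusers import DDIMScheduler, LMSDiscreteScheduler

ddim = DDIMScheduler()
lms = LMSDiscreteScheduler()
ddim.set_timesteps(50)
lms.set_timesteps(50)

print(ddim.init_noise_sigma)  # 1.0 -> plain Gaussian latents need no rescaling
print(lms.init_noise_sigma)   # ~max sigma -> K-LMS starts from high-variance noise

sample = torch.randn(1, 4, 64, 64)
t = lms.timesteps[0]
# For K-LMS this divides by (sigma**2 + 1) ** 0.5, matching the branch
# removed above; for DDIM the same call returns the sample unchanged.
scaled = lms.scale_model_input(sample, t)
```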
@@ -299,10 +306,7 @@ def __call__(
                 )

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

             # scale and decode the image latents with vae
             latents = 1 / 0.18215 * latents
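The `1 / 0.18215` factor undoes the scaling applied to latents when the SD v1 VAE encoded images; decoding then proceeds roughly as below (a sketch, assuming the `AutoencoderKL` API where `decode` returns an object with a `.sample` field):

```python
# Sketch of what follows the scaling above in the pipeline:
image = self.vae.decode(latents).sample  # (batch, 3, height, width)
image = (image / 2 + 0.5).clamp(0, 1)    # map from roughly [-1, 1] to [0, 1]
```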