File tree Expand file tree Collapse file tree 7 files changed +18
-14
lines changed
src/diffusers/pipelines/stable_diffusion Expand file tree Collapse file tree 7 files changed +18
-14
lines changed Original file line number Diff line number Diff line change @@ -278,7 +278,7 @@ def __call__(
278
278
if do_classifier_free_guidance :
279
279
uncond_tokens : List [str ]
280
280
if negative_prompt is None :
281
- uncond_tokens = ["" ]
281
+ uncond_tokens = ["" ] * batch_size
282
282
elif type (prompt ) is not type (negative_prompt ):
283
283
raise TypeError (
284
284
f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -307,7 +307,7 @@ def __call__(
307
307
308
308
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
309
309
seq_len = uncond_embeddings .shape [1 ]
310
- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
310
+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
311
311
uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
312
312
313
313
# For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -148,7 +148,7 @@ def __call__(
148
148
if do_classifier_free_guidance :
149
149
uncond_tokens : List [str ]
150
150
if negative_prompt is None :
151
- uncond_tokens = ["" ]
151
+ uncond_tokens = ["" ] * batch_size
152
152
elif type (prompt ) is not type (negative_prompt ):
153
153
raise TypeError (
154
154
f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -177,7 +177,7 @@ def __call__(
177
177
178
178
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
179
179
seq_len = uncond_embeddings .shape [1 ]
180
- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
180
+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
181
181
uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
182
182
183
183
# For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -295,7 +295,7 @@ def __call__(
295
295
if do_classifier_free_guidance :
296
296
uncond_tokens : List [str ]
297
297
if negative_prompt is None :
298
- uncond_tokens = ["" ]
298
+ uncond_tokens = ["" ] * batch_size
299
299
elif type (prompt ) is not type (negative_prompt ):
300
300
raise TypeError (
301
301
f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -324,7 +324,7 @@ def __call__(
324
324
325
325
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
326
326
seq_len = uncond_embeddings .shape [1 ]
327
- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
327
+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
328
328
uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
329
329
330
330
# For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -297,7 +297,7 @@ def __call__(
297
297
if do_classifier_free_guidance :
298
298
uncond_tokens : List [str ]
299
299
if negative_prompt is None :
300
- uncond_tokens = ["" ]
300
+ uncond_tokens = ["" ] * batch_size
301
301
elif type (prompt ) is not type (negative_prompt ):
302
302
raise TypeError (
303
303
f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -326,7 +326,7 @@ def __call__(
326
326
327
327
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
328
328
seq_len = uncond_embeddings .shape [1 ]
329
- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
329
+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
330
330
uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
331
331
332
332
# For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -295,7 +295,7 @@ def __call__(
295
295
if do_classifier_free_guidance :
296
296
uncond_tokens : List [str ]
297
297
if negative_prompt is None :
298
- uncond_tokens = ["" ]
298
+ uncond_tokens = ["" ] * batch_size
299
299
elif type (prompt ) is not type (negative_prompt ):
300
300
raise TypeError (
301
301
f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -319,7 +319,9 @@ def __call__(
319
319
uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (self .device ))[0 ]
320
320
321
321
# duplicate unconditional embeddings for each generation per prompt
322
- uncond_embeddings = uncond_embeddings .repeat_interleave (batch_size * num_images_per_prompt , dim = 0 )
322
+ seq_len = uncond_embeddings .shape [1 ]
323
+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
324
+ uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
323
325
324
326
# For classifier free guidance, we need to do two forward passes.
325
327
# Here we concatenate the unconditional and text embeddings into a single batch
Original file line number Diff line number Diff line change @@ -302,7 +302,7 @@ def __call__(
302
302
if do_classifier_free_guidance :
303
303
uncond_tokens : List [str ]
304
304
if negative_prompt is None :
305
- uncond_tokens = ["" ]
305
+ uncond_tokens = ["" ] * batch_size
306
306
elif type (prompt ) is not type (negative_prompt ):
307
307
raise TypeError (
308
308
f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -331,7 +331,7 @@ def __call__(
331
331
332
332
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
333
333
seq_len = uncond_embeddings .shape [1 ]
334
- uncond_embeddings = uncond_embeddings .repeat (batch_size , num_images_per_prompt , 1 )
334
+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
335
335
uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
336
336
337
337
# For classifier free guidance, we need to do two forward passes.
Original file line number Diff line number Diff line change @@ -284,7 +284,7 @@ def __call__(
284
284
if do_classifier_free_guidance :
285
285
uncond_tokens : List [str ]
286
286
if negative_prompt is None :
287
- uncond_tokens = ["" ]
287
+ uncond_tokens = ["" ] * batch_size
288
288
elif type (prompt ) is not type (negative_prompt ):
289
289
raise TypeError (
290
290
f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
@@ -312,7 +312,9 @@ def __call__(
312
312
uncond_embeddings = self .text_encoder (uncond_input .input_ids .to (self .device ))[0 ]
313
313
314
314
# duplicate unconditional embeddings for each generation per prompt
315
- uncond_embeddings = uncond_embeddings .repeat_interleave (batch_size * num_images_per_prompt , dim = 0 )
315
+ seq_len = uncond_embeddings .shape [1 ]
316
+ uncond_embeddings = uncond_embeddings .repeat (1 , num_images_per_prompt , 1 )
317
+ uncond_embeddings = uncond_embeddings .view (batch_size * num_images_per_prompt , seq_len , - 1 )
316
318
317
319
# For classifier free guidance, we need to do two forward passes.
318
320
# Here we concatenate the unconditional and text embeddings into a single batch
You can’t perform that action at this time.
0 commit comments