Commit e4ffadc

Merge branch 'main' of https://github.com/huggingface/diffusers into main
2 parents: ec7c8d3 + c9b3463

9 files changed: +782 -477 lines

docs/source/training/text2image.mdx

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
 
 # Stable Diffusion text-to-image fine-tuning
 
-The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) script shows how to fine-tune the stable diffusion model on your own dataset.
+The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the stable diffusion model on your own dataset.
 
 <Tip warning={true}>
 
```

src/diffusers/models/attention.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -557,6 +557,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         return hidden_states
 
     def _memory_efficient_attention_xformers(self, query, key, value):
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
         hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
```
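The added `.contiguous()` calls are defensive: xformers' fused attention kernels expect densely packed inputs, while the head-splitting reshapes upstream typically hand them strided views. A minimal PyTorch sketch of the behavior being guarded against (the shapes are illustrative, not taken from the commit):

```python
import torch

# Simulate the head-split layout attention code produces before a fused kernel:
# transpose changes strides without moving data, so the result is a
# non-contiguous view.
hidden = torch.randn(2, 16, 4, 40)   # (batch, seq, heads, head_dim)
query = hidden.transpose(1, 2)       # (batch, heads, seq, head_dim) view
print(query.is_contiguous())         # False

# .contiguous() materializes a compact copy that fused kernels can consume.
query = query.contiguous()
print(query.is_contiguous())         # True
```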

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 168 additions & 105 deletions
Large diffs are not rendered by default.

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 96 additions & 68 deletions
```diff
@@ -298,6 +298,73 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
 
         return text_embeddings
 
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        else:
+            has_nsfw_concept = None
+        return image, has_nsfw_concept
+
+    def decode_latents(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(self, prompt, height, width, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // 8, width // 8)
+        if latents is None:
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+            else:
+                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            if latents.shape != shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
     @torch.no_grad()
     def __call__(
         self,
```
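The new `prepare_extra_step_kwargs` helper relies on signature introspection so one pipeline can drive schedulers whose `step()` methods take different keyword arguments. The same feature-detection pattern in isolation, with hypothetical stand-in step functions:

```python
import inspect

def ddim_step(sample, eta, generator=None):  # hypothetical scheduler step
    return sample

def euler_step(sample):                      # hypothetical scheduler step
    return sample

def build_step_kwargs(step_fn, eta, generator):
    # Only pass the kwargs the target step() actually declares.
    params = set(inspect.signature(step_fn).parameters.keys())
    extra = {}
    if "eta" in params:
        extra["eta"] = eta
    if "generator" in params:
        extra["generator"] = generator
    return extra

print(build_step_kwargs(ddim_step, 0.0, None))   # {'eta': 0.0, 'generator': None}
print(build_step_kwargs(euler_step, 0.0, None))  # {}
```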
```diff
@@ -371,75 +438,45 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, callback_steps)
 
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
         device = self._execution_device
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
+        # 3. Encode input prompt
         text_embeddings = self._encode_prompt(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
 
-        # Unlike in other pipelines, latents need to be generated in the target device
-        # for 1-to-1 results reproducibility with the CompVis implementation.
-        # However this currently doesn't work in `mps`.
-
-        # get the initial random noise unless the user supplied it
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
-        latents_dtype = text_embeddings.dtype
-        if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
-            else:
-                latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
-        else:
-            if latents.shape != latents_shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(device)
-
-        # set timesteps and move to the correct device
+        # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps_tensor = self.scheduler.timesteps
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
-
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            text_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
 
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+        # 7. Denoising loop
+        for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
```
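`prepare_latents` carries over the old reproducibility workaround: seeded `torch.randn` is not reliably deterministic on `mps`, so noise is drawn on the CPU and then moved to the target device. A standalone sketch of seeding initial noise the same way (the 4 latent channels and the `// 8` VAE downscaling are assumptions matching the code above):

```python
import torch

def make_initial_noise(batch, channels, height, width, device, seed):
    shape = (batch, channels, height // 8, width // 8)
    generator = torch.Generator(device="cpu").manual_seed(seed)
    # Draw on CPU for deterministic results, then move to the target device,
    # mirroring the mps branch of prepare_latents.
    return torch.randn(shape, generator=generator, device="cpu").to(device)

noise = make_initial_noise(1, 4, 512, 512, torch.device("cpu"), seed=0)
print(noise.shape)  # torch.Size([1, 4, 64, 64])
```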
```diff
@@ -459,22 +496,13 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 callback(i, t, latents)
 
-        latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        # 8. Post-processing
+        image = self.decode_latents(latents)
 
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        # 9. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
 
+        # 10. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
```
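The refactor is behavior-preserving: `__call__` still runs the same steps end to end, now as numbered stages backed by the new helper methods. A typical invocation is unchanged (the checkpoint id is the usual example, not something this commit touches):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# Internally: check_inputs -> _encode_prompt -> prepare_latents ->
# denoising loop -> decode_latents -> run_safety_checker.
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```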
