@@ -1,7 +1,8 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch

from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -106,6 +107,43 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

@torch.no_grad()
def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray:
r"""
Scale and decode the latent representations into images using the VAE.

Args:
latents (`torch.FloatTensor`):
Latent representations to decode into images.

Returns:
`np.ndarray`: Decoded images.
"""
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
return image
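
A quick usage sketch of the new `decode_latents` helper (illustrative only: `pipe` is an assumed, already-loaded `StableDiffusionPipeline` running in float32, not something defined in this diff). The `1 / 0.18215` factor undoes the scaling Stable Diffusion applies to the VAE latents when encoding.

```python
# Hedged sketch: decoding latents with the new helper (assumes `pipe` exists).
import torch

latents = torch.randn(1, 4, 64, 64, device=pipe.device)  # (batch, 4, height / 8, width / 8)
images = pipe.decode_latents(latents)                     # np.ndarray, NHWC, values in [0, 1]
print(images.shape)                                       # (1, 512, 512, 3)
```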

@torch.no_grad()
def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]:
r"""
Run the safety checker on the generated images. If potential NSFW content is detected, a warning will be
raised and a black image will be returned instead.

Args:
image (`np.ndarray`):
Images to run the safety checker on.

Returns:
`Tuple[np.ndarray, List[bool]]`: The first element contains the images that have been processed by the
safety checker. The second element is a list of booleans indicating, for each image, whether it contains NSFW content.
"""
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
return image, has_nsfw_concept
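
A matching sketch for `run_safety_checker` (again assuming the hypothetical loaded pipeline `pipe`, with `images` as the array returned by `decode_latents`):

```python
# Hedged sketch: the second return value flags, per image, whether the
# safety checker replaced it with a black image.
checked_images, has_nsfw_concept = pipe.run_safety_checker(images)
for idx, flagged in enumerate(has_nsfw_concept):
    if flagged:
        print(f"image {idx} was flagged as NSFW and blacked out")
```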

@torch.no_grad()
def __call__(
self,
@@ -119,6 +157,8 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
@@ -156,6 +196,13 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: np.ndarray, latents:
torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -187,6 +234,14 @@ def __call__(
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# get prompt text embeddings
text_input = self.tokenizer(
prompt,
@@ -270,16 +325,13 @@ def __call__(
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
image = self.decode_latents(latents)

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
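
As a usage sketch of the new `callback` / `callback_steps` arguments (the model id and prompt here are illustrative assumptions, not part of this diff; the argument names follow the code above):

```python
# Hedged sketch: logging intermediate latents during text-to-image sampling.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

def log_progress(step: int, timestep, latents: torch.FloatTensor):
    # Called every `callback_steps` denoising steps with the current latents.
    print(f"step {step}, timestep {timestep}, latents shape {tuple(latents.shape)}")

result = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=50,
    callback=log_progress,
    callback_steps=5,  # fire on steps 0, 5, 10, ...
)
image = result.images[0]
```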
@@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -118,6 +118,43 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

@torch.no_grad()
def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray:
r"""
Scale and decode the latent representations into images using the VAE.

Args:
latents (`torch.FloatTensor`):
Latent representations to decode into images.

Returns:
`np.ndarray`: Decoded images.
"""
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
return image

@torch.no_grad()
def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]:
r"""
Run the safety checker on the generated images. If potential NSFW content is detected, a warning will be
raised and a black image will be returned instead.

Args:
image (`np.ndarray`):
Images to run the safety checker on.

Returns:
`Tuple[np.ndarray, List[bool]]`: The first element contains the images that have been processed by the
safety checker. The second element is a list of booleans indicating, for each image, whether it contains NSFW content.
"""
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
return image, has_nsfw_concept

@torch.no_grad()
def __call__(
self,
@@ -130,6 +167,9 @@ def __call__(
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -167,6 +207,13 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: np.ndarray, latents:
torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -185,6 +232,14 @@ def __call__(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -254,6 +309,7 @@ def __call__(
latents = init_latents

t_start = max(num_inference_steps - init_timestep + offset, 0)

for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
t_index = t_start + i

@@ -280,16 +336,13 @@ def __call__(
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
image = self.decode_latents(latents)

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
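
For the image-to-image pipeline, the same arguments can drive intermediate previews by combining the callback with the new `decode_latents` helper (hedged sketch: `pipe` is an assumed, already-loaded img2img pipeline and `init_image` an existing PIL image, neither defined in this diff; decoding previews is relatively slow, which is what `callback_steps` is for):

```python
# Hedged sketch: saving a decoded preview every few denoising steps.
import torch

def save_preview(step: int, timestep, latents: torch.FloatTensor):
    images = pipe.decode_latents(latents)            # np.ndarray in [0, 1], NHWC
    pipe.numpy_to_pil(images)[0].save(f"preview_{step:03d}.png")

output = pipe(
    prompt="a fantasy landscape",
    init_image=init_image,  # assumed PIL.Image input for img2img
    strength=0.75,
    callback=save_preview,
    callback_steps=10,
)
```

Note that the loop iterates over `self.scheduler.timesteps[t_start:]`, so `step` counts executed steps from 0 and the callback can fire fewer than `num_inference_steps` times.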
@@ -1,6 +1,6 @@
import inspect
import warnings
from typing import List, Optional, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -137,6 +137,43 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slicing(None)

@torch.no_grad()
def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray:
r"""
Scale and decode the latent representations into images using the VAE.

Args:
latents (`torch.FloatTensor`):
Latent representations to decode into images.

Returns:
`np.ndarray`: Decoded images.
"""
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
return image

@torch.no_grad()
def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]:
r"""
Run the safety checker on the generated images. If potential NSFW content is detected, a warning will be
raised and a black image will be returned instead.

Args:
image (`np.ndarray`):
Images to run the safety checker on.

Returns:
`Tuple[np.ndarray, List[bool]]`: The first element contains the images that have been processed by the
safety checker. The second element is a list of booleans indicating, for each image, whether it contains NSFW content.
"""
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
return image, has_nsfw_concept

@torch.no_grad()
def __call__(
self,
@@ -150,6 +187,9 @@ def __call__(
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, np.ndarray, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -191,6 +231,13 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: np.ndarray, latents:
torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -209,6 +256,14 @@ def __call__(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -290,7 +345,9 @@ def __call__(
extra_step_kwargs["eta"] = eta

latents = init_latents

t_start = max(num_inference_steps - init_timestep + offset, 0)

for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance
@@ -320,16 +377,13 @@ def __call__(

latents = (init_latents_proper * mask) + (latents * (1 - mask))

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
image = self.decode_latents(latents)

# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
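
Across all three pipelines the check `i % callback_steps == 0` means the callback fires on step 0 and then every `callback_steps` executed steps; a small worked example of which steps trigger it:

```python
# Which steps fire the callback for a 50-step run with callback_steps=7.
num_steps = 50
callback_steps = 7
fired = [i for i in range(num_steps) if i % callback_steps == 0]
print(fired)  # [0, 7, 14, 21, 28, 35, 42, 49]
```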