Skip to content

Postprocessing refactor all others #3337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import torch
Expand All @@ -22,6 +23,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
Expand Down Expand Up @@ -174,6 +176,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_vae_slicing(self):
Expand Down Expand Up @@ -426,16 +429,27 @@ def _encode_prompt(
return prompt_embeds

def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept

def decode_latents(self, latents):
warnings.warn(
(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead"
),
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
Expand Down Expand Up @@ -700,24 +714,19 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

if output_type == "latent":
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)

# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

# 10. Convert to PIL
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
# 8. Post-processing
image = self.decode_latents(latents)
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device
image = [np.array(i).astype(np.float32) / 255.0 for i in image]

image = np.stack(image, axis=0) # to np
torch.from_numpy(image.transpose(0, 3, 1, 2))
image = torch.from_numpy(image.transpose(0, 3, 1, 2))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

elif isinstance(image[0], np.ndarray):
image = np.stack(image, axis=0) # to np
if image.ndim == 5:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device
image = [np.array(i).astype(np.float32) / 255.0 for i in image]

image = np.stack(image, axis=0) # to np
torch.from_numpy(image.transpose(0, 3, 1, 2))
image = torch.from_numpy(image.transpose(0, 3, 1, 2))
elif isinstance(image[0], np.ndarray):
image = np.stack(image, axis=0) # to np
if image.ndim == 5:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
Expand All @@ -22,6 +23,7 @@

from diffusers.utils import is_accelerate_available

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging, randn_tensor
Expand Down Expand Up @@ -184,6 +186,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_sequential_cpu_offload(self, gpu_id=0):
Expand Down Expand Up @@ -226,13 +229,17 @@ def _execution_device(self):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
Expand All @@ -255,6 +262,11 @@ def prepare_extra_step_kwargs(self, generator, eta):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
Expand Down Expand Up @@ -560,15 +572,19 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 11. Post-processing
image = self.decode_latents(latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
else:
image = latents
has_nsfw_concept = None

# 12. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

# 13. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import inspect
import warnings
from itertools import repeat
from typing import Callable, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
Expand Down Expand Up @@ -129,10 +131,31 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
Expand Down Expand Up @@ -681,20 +704,19 @@ def __call__(
callback(i, t, latents)

# 8. Post-processing
image = self.decode_latents(latents)

if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
else:
image = latents
has_nsfw_concept = None

if output_type == "pil":
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
Expand All @@ -24,6 +25,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
Expand Down Expand Up @@ -220,6 +222,8 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
Expand Down Expand Up @@ -504,17 +508,26 @@ def prepare_extra_step_kwargs(self, generator, eta):

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor instead",
FutureWarning,
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
Expand Down Expand Up @@ -770,14 +783,19 @@ def __call__(
callback(i, t, latents)

# 9. Post-processing
image = self.decode_latents(latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None

# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
Loading