diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
new file mode 100644
index 000000000000..de6543800b2d
--- /dev/null
+++ b/src/diffusers/image_processor.py
@@ -0,0 +1,177 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Union
+
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+
+from .configuration_utils import ConfigMixin, register_to_config
+from .utils import CONFIG_NAME, PIL_INTERPOLATION
+
+
+class VaeImageProcessor(ConfigMixin):
+    """
+    Image processor for VAE-based pipelines.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image will be automatically resized to multiples of this
+            factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1].
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+    ):
+        super().__init__()
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a list of PIL images.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+        else:
+            pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
+
+    @staticmethod
+    def numpy_to_pt(images):
+        """
+        Convert a numpy image to a pytorch tensor.
+        """
+        if images.ndim == 3:
+            images = images[..., None]
+
+        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
+        return images
+
+    @staticmethod
+    def pt_to_numpy(images):
+        """
+        Convert a pytorch tensor to a numpy image.
+        """
+        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+        return images
+
+    @staticmethod
+    def normalize(images):
+        """
+        Normalize an image array to [-1,1].
+        """
+        return 2.0 * images - 1.0
+
+    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Resize a PIL image. Both height and width are rounded down to the nearest integer multiple of
+        `vae_scale_factor`.
+        """
+        w, h = images.size
+        w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h))  # resize to integer multiple of vae_scale_factor
+        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
+        return images
+
+    def preprocess(
+        self,
+        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+    ) -> torch.Tensor:
+        """
+        Preprocess the image input. Accepted formats are PIL images, numpy arrays or pytorch tensors.
+        """
+        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
+        if isinstance(image, supported_formats):
+            image = [image]
+        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
+            raise ValueError(
+                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
+            )
+
+        if isinstance(image[0], PIL.Image.Image):
+            if self.do_resize:
+                image = [self.resize(i) for i in image]
+            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
+            image = np.stack(image, axis=0)  # to np
+            image = self.numpy_to_pt(image)  # to pt
+
+        elif isinstance(image[0], np.ndarray):
+            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+            image = self.numpy_to_pt(image)
+            _, _, height, width = image.shape
+            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+                raise ValueError(
+                    f"Currently we only support resizing for PIL images - please resize your numpy array to be divisible by {self.vae_scale_factor}. "
+                    f"Currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
+                )
+
+        elif isinstance(image[0], torch.Tensor):
+            image = torch.cat(image, dim=0) if image[0].ndim == 4 else torch.stack(image, dim=0)
+            _, _, height, width = image.shape
+            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
+                raise ValueError(
+                    f"Currently we only support resizing for PIL images - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}. "
+                    f"Currently the sizes are {height} and {width}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
+                )
+
+        # expected range [0,1], normalize to [-1,1]
+        do_normalize = self.do_normalize
+        if image.min() < 0:
+            warnings.warn(
+                "Passing `image` as a torch tensor with value range in [-1,1] is deprecated. The expected value range for an image tensor is [0,1] "
+                f"when passed as a pytorch tensor or numpy array. You passed `image` with value range [{image.min()},{image.max()}]",
+                FutureWarning,
+            )
+            do_normalize = False
+
+        if do_normalize:
+            image = self.normalize(image)
+
+        return image
+
+    def postprocess(
+        self,
+        image,
+        output_type: str = "pil",
+    ):
+        if isinstance(image, torch.Tensor) and output_type == "pt":
+            return image
+
+        image = self.pt_to_numpy(image)
+
+        if output_type == "np":
+            return image
+        elif output_type == "pil":
+            return self.numpy_to_pil(image)
+        else:
+            raise ValueError(f"Unsupported output_type {output_type}.")
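A minimal round-trip sketch of the new processor (illustrative, not part of the diff): with `do_normalize=False`, `preprocess` followed by `postprocess` should be a near-identity mapping, up to the uint8 quantization of the PIL input.

```python
import numpy as np
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_resize=True, vae_scale_factor=8, do_normalize=False)

# 100x75 is deliberately not a multiple of 8, so resize() rounds both sides down (to 96x72).
pil_image = Image.fromarray(np.random.randint(0, 256, (75, 100, 3), dtype=np.uint8))

tensor = processor.preprocess(pil_image)  # torch.Tensor, shape (1, 3, 72, 96), values in [0, 1]
pil_out = processor.postprocess(tensor, output_type="pil")[0]  # back to a PIL image
```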
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 1e7872e3b081..05138c86f246 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -24,6 +24,7 @@
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 
 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -192,7 +193,6 @@ def __init__(
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -203,7 +203,11 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            requires_safety_checker=requires_safety_checker,
+        )
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -415,21 +419,17 @@ def _encode_prompt(
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+        )
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -663,7 +663,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        image = self.image_processor.preprocess(image) if False else self.image_processor.preprocess(image)
+        image = self.image_processor.preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -703,15 +703,26 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to "
+                "one of these instead: `pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+        else:
+            image = self.decode_latents(latents)
 
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            if self.safety_checker is not None:
+                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            else:
+                has_nsfw_concept = False
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+            image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
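With the processor wired in, `output_type` now accepts `pt` and `latent` alongside `pil` and `np`; legacy strings such as `numpy` still work but emit a deprecation warning. A usage sketch — the checkpoint id and image URL are illustrative assumptions, not prescribed by this diff:

```python
import torch
from diffusers import AltDiffusionImg2ImgPipeline
from diffusers.utils import load_image

pipe = AltDiffusionImg2ImgPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16).to("cuda")
init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
).resize((768, 512))

# "pt" returns the decoded images as a torch tensor in [0, 1] (NCHW).
images = pipe("a fantasy landscape", image=init_image, output_type="pt").images

# "latent" skips VAE decoding (and the safety checker) and returns the raw latents.
latents = pipe("a fantasy landscape", image=init_image, output_type="latent").images
```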
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 172ab15a757e..8b3a7944def1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -22,6 +22,7 @@
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -119,7 +120,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
     """
 
     _optional_components = ["safety_checker", "feature_extractor"]
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -196,7 +196,6 @@ def __init__(
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -207,7 +206,11 @@ def __init__(
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            requires_safety_checker=requires_safety_checker,
+        )
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -422,24 +425,18 @@ def _encode_prompt(
         return prompt_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-            )
-        else:
-            has_nsfw_concept = None
+        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+        )
         return image, has_nsfw_concept
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -674,7 +671,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -714,15 +711,26 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        # 9. Post-processing
-        image = self.decode_latents(latents)
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to "
+                "one of these instead: `pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+        else:
+            image = self.decode_latents(latents)
 
-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            if self.safety_checker is not None:
+                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            else:
+                has_nsfw_concept = False
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+            image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
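Since `decode_latents` in both pipelines now returns a `[0, 1]` torch tensor rather than a numpy array, deferred decoding composes naturally with the processor. A sketch, reusing the hypothetical `pipe` and `init_image` from the previous example:

```python
# Decode deferred latents into any supported output format.
latents = pipe("a fantasy landscape", image=init_image, output_type="latent").images

with torch.no_grad():
    decoded = pipe.decode_latents(latents)  # torch tensor, NCHW, values in [0, 1]

np_images = pipe.image_processor.postprocess(decoded, output_type="np")
pil_images = pipe.image_processor.postprocess(decoded, output_type="pil")
```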
diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
index d2745115af1c..939632943405 100644
--- a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
+++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
@@ -21,7 +21,13 @@
 import torch
 from transformers import XLMRobertaTokenizer
 
-from diffusers import AltDiffusionImg2ImgPipeline, AutoencoderKL, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AltDiffusionImg2ImgPipeline,
+    AutoencoderKL,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
     RobertaSeriesConfig,
     RobertaSeriesModelWithTransformation,
@@ -128,6 +134,7 @@ def test_stable_diffusion_img2img_default_case(self):
             safety_checker=None,
             feature_extractor=self.dummy_extractor,
         )
+        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
         alt_pipe = alt_pipe.to(device)
         alt_pipe.set_progress_bar_config(disable=None)
 
@@ -191,6 +198,7 @@ def test_stable_diffusion_img2img_fp16(self):
             safety_checker=None,
             feature_extractor=self.dummy_extractor,
         )
+        alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
         alt_pipe = alt_pipe.to(torch_device)
         alt_pipe.set_progress_bar_config(disable=None)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 77dfa9be1d1e..e27f83fc04fe 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -30,6 +30,7 @@
     StableDiffusionImg2ImgPipeline,
     UNet2DConditionModel,
 )
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
 
@@ -94,19 +95,33 @@ def get_dummy_components(self):
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0):
+    def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
+
+        if input_image_type == "pt":
+            input_image = image
+        elif input_image_type == "np":
+            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
+        elif input_image_type == "pil":
+            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
+            input_image = VaeImageProcessor.numpy_to_pil(input_image)
+        else:
+            raise ValueError(f"unsupported input_image_type {input_image_type}.")
+
+        if output_type not in ["pt", "np", "pil"]:
+            raise ValueError(f"unsupported output_type {output_type}")
+
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "image": image,
+            "image": input_image,
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": output_type,
         }
         return inputs
 
@@ -114,6 +129,7 @@ def test_stable_diffusion_img2img_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -130,6 +146,7 @@ def test_stable_diffusion_img2img_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -148,6 +165,7 @@ def test_stable_diffusion_img2img_multiple_init_images(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -169,6 +187,7 @@ def test_stable_diffusion_img2img_k_lms(self):
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -197,6 +216,36 @@ def test_save_load_optional_components(self):
     def test_attention_slicing_forward_pass(self):
         return super().test_attention_slicing_forward_pass()
 
+    @skip_mps
+    def test_pt_np_pil_outputs_equivalent(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0]
+        output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0]
+        output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0]
+
+        assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4
+        assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4
+
+    @skip_mps
+    def test_image_types_consistent(self):
+        device = "cpu"
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0]
+        output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0]
+        output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0]
+
+        assert np.abs(output_pt - output_np).max() <= 1e-4
+        assert np.abs(output_pil - output_np).max() <= 1e-2
+
 
 @slow
 @require_torch_gpu
@@ -219,7 +268,7 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
             "num_inference_steps": 3,
             "strength": 0.75,
             "guidance_scale": 7.5,
-            "output_type": "numpy",
+            "output_type": "np",
         }
         return inputs
 
@@ -426,7 +475,7 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
             "num_inference_steps": 50,
             "strength": 0.75,
             "guidance_scale": 7.5,
-            "output_type": "numpy",
+            "output_type": "np",
         }
         return inputs
diff --git a/tests/test_image_processor.py b/tests/test_image_processor.py
new file mode 100644
index 000000000000..4f0e2c5aecfd
--- /dev/null
+++ b/tests/test_image_processor.py
@@ -0,0 +1,149 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import PIL
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+
+
+class ImageProcessorTest(unittest.TestCase):
+    @property
+    def dummy_sample(self):
+        batch_size = 1
+        num_channels = 3
+        height = 8
+        width = 8
+
+        sample = torch.rand((batch_size, num_channels, height, width))
+
+        return sample
+
+    def to_np(self, image):
+        if isinstance(image[0], PIL.Image.Image):
+            return np.stack([np.array(i) for i in image], axis=0)
+        elif isinstance(image, torch.Tensor):
+            return image.cpu().numpy().transpose(0, 2, 3, 1)
+        return image
+
+    def test_vae_image_processor_pt(self):
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+
+        input_pt = self.dummy_sample
+        input_np = self.to_np(input_pt)
+
+        for output_type in ["pt", "np", "pil"]:
+            out = image_processor.postprocess(
+                image_processor.preprocess(input_pt),
+                output_type=output_type,
+            )
+            out_np = self.to_np(out)
+            in_np = (input_np * 255).round() if output_type == "pil" else input_np
+            assert (
+                np.abs(in_np - out_np).max() < 1e-6
+            ), f"decoded output does not match input for output_type {output_type}"
+
+    def test_vae_image_processor_np(self):
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
+
+        for output_type in ["pt", "np", "pil"]:
+            out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type)
+
+            out_np = self.to_np(out)
+            in_np = (input_np * 255).round() if output_type == "pil" else input_np
+            assert (
+                np.abs(in_np - out_np).max() < 1e-6
+            ), f"decoded output does not match input for output_type {output_type}"
+
+    def test_vae_image_processor_pil(self):
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+
+        input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
+        input_pil = image_processor.numpy_to_pil(input_np)
+
+        for output_type in ["pt", "np", "pil"]:
+            out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type)
+            for i, o in zip(input_pil, out):
+                in_np = np.array(i)
+                out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
+                assert (
+                    np.abs(in_np - out_np).max() < 1e-6
+                ), f"decoded output does not match input for output_type {output_type}"
+
+    def test_preprocess_input_3d(self):
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+
+        input_pt_4d = self.dummy_sample
+        input_pt_3d = input_pt_4d.squeeze(0)
+
+        out_pt_4d = image_processor.postprocess(
+            image_processor.preprocess(input_pt_4d),
+            output_type="np",
+        )
+        out_pt_3d = image_processor.postprocess(
+            image_processor.preprocess(input_pt_3d),
+            output_type="np",
+        )
+
+        input_np_4d = self.to_np(self.dummy_sample)
+        input_np_3d = input_np_4d.squeeze(0)
+
+        out_np_4d = image_processor.postprocess(
+            image_processor.preprocess(input_np_4d),
+            output_type="np",
+        )
+        out_np_3d = image_processor.postprocess(
+            image_processor.preprocess(input_np_3d),
+            output_type="np",
+        )
+
+        assert np.abs(out_pt_4d - out_pt_3d).max() < 1e-6
+        assert np.abs(out_np_4d - out_np_3d).max() < 1e-6
+
+    def test_preprocess_input_list(self):
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+
+        input_pt_4d = self.dummy_sample
+        input_pt_list = list(input_pt_4d)
+
+        out_pt_4d = image_processor.postprocess(
+            image_processor.preprocess(input_pt_4d),
+            output_type="np",
+        )
+        out_pt_list = image_processor.postprocess(
+            image_processor.preprocess(input_pt_list),
+            output_type="np",
+        )
+
+        input_np_4d = self.to_np(self.dummy_sample)
+        input_np_list = list(input_np_4d)
+
+        out_np_4d = image_processor.postprocess(
+            image_processor.preprocess(input_np_4d),
+            output_type="np",
+        )
+        out_np_list = image_processor.postprocess(
+            image_processor.preprocess(input_np_list),
+            output_type="np",
+        )
+
+        assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
+        assert np.abs(out_np_4d - out_np_list).max() < 1e-6