diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 626367029e33..a2c49652739f 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -118,129 +118,6 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
-def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
-    """
-    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
-    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
-    ``image`` and ``1`` for the ``mask``.
-
-    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
-    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
-
-    Args:
-        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
-            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
-        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
-            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
-
-
-    Raises:
-        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
-        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
-        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
-            (ot the other way around).
-
-    Returns:
-        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
-            dimensions: ``batch x channels x height x width``.
-    """
-    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
-    deprecate(
-        "prepare_mask_and_masked_image",
-        "0.30.0",
-        deprecation_message,
-    )
-    if image is None:
-        raise ValueError("`image` input cannot be undefined.")
-
-    if mask is None:
-        raise ValueError("`mask_image` input cannot be undefined.")
-
-    if isinstance(image, torch.Tensor):
-        if not isinstance(mask, torch.Tensor):
-            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
-
-        # Batch single image
-        if image.ndim == 3:
-            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
-            image = image.unsqueeze(0)
-
-        # Batch and add channel dim for single mask
-        if mask.ndim == 2:
-            mask = mask.unsqueeze(0).unsqueeze(0)
-
-        # Batch single mask or add channel dim
-        if mask.ndim == 3:
-            # Single batched mask, no channel dim or single mask not batched but channel dim
-            if mask.shape[0] == 1:
-                mask = mask.unsqueeze(0)
-
-            # Batched masks no channel dim
-            else:
-                mask = mask.unsqueeze(1)
-
-        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
-        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
-        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
-
-        # Check image is in [-1, 1]
-        if image.min() < -1 or image.max() > 1:
-            raise ValueError("Image should be in [-1, 1] range")
-
-        # Check mask is in [0, 1]
-        if mask.min() < 0 or mask.max() > 1:
-            raise ValueError("Mask should be in [0, 1] range")
-
-        # Binarize mask
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-
-        # Image as float32
-        image = image.to(dtype=torch.float32)
-    elif isinstance(mask, torch.Tensor):
-        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
-    else:
-        # preprocess image
-        if isinstance(image, (PIL.Image.Image, np.ndarray)):
-            image = [image]
-        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
-            # resize all images w.r.t passed height an width
-            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
-            image = [np.array(i.convert("RGB"))[None, :] for i in image]
-            image = np.concatenate(image, axis=0)
-        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
-            image = np.concatenate([i[None, :] for i in image], axis=0)
-
-        image = image.transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
-        # preprocess mask
-        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
-            mask = [mask]
-
-        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
-            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
-            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
-            mask = mask.astype(np.float32) / 255.0
-        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
-            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
-
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-        mask = torch.from_numpy(mask)
-
-    masked_image = image * (mask < 0.5)
-
-    # n.b. ensure backwards compatibility as old function does not return image
-    if return_image:
-        return mask, masked_image, image
-
-    return mask, masked_image
-
-
 class StableDiffusionControlNetInpaintPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index bc94911b7e55..c471efc3b48f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -15,7 +15,6 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 
-import numpy as np
 import PIL.Image
 import torch
 from packaging import version
@@ -38,128 +37,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
-    """
-    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
-    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
-    ``image`` and ``1`` for the ``mask``.
-
-    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
-    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
-
-    Args:
-        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
-            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
-        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
-            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
-
-
-    Raises:
-        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
-        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
-        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
-            (ot the other way around).
-
-    Returns:
-        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
-            dimensions: ``batch x channels x height x width``.
-    """
-    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
-    deprecate(
-        "prepare_mask_and_masked_image",
-        "0.30.0",
-        deprecation_message,
-    )
-    if image is None:
-        raise ValueError("`image` input cannot be undefined.")
-
-    if mask is None:
-        raise ValueError("`mask_image` input cannot be undefined.")
-
-    if isinstance(image, torch.Tensor):
-        if not isinstance(mask, torch.Tensor):
-            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
-
-        # Batch single image
-        if image.ndim == 3:
-            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
-            image = image.unsqueeze(0)
-
-        # Batch and add channel dim for single mask
-        if mask.ndim == 2:
-            mask = mask.unsqueeze(0).unsqueeze(0)
-
-        # Batch single mask or add channel dim
-        if mask.ndim == 3:
-            # Single batched mask, no channel dim or single mask not batched but channel dim
-            if mask.shape[0] == 1:
-                mask = mask.unsqueeze(0)
-
-            # Batched masks no channel dim
-            else:
-                mask = mask.unsqueeze(1)
-
-        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
-        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
-        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
-
-        # Check image is in [-1, 1]
-        if image.min() < -1 or image.max() > 1:
-            raise ValueError("Image should be in [-1, 1] range")
-
-        # Check mask is in [0, 1]
-        if mask.min() < 0 or mask.max() > 1:
-            raise ValueError("Mask should be in [0, 1] range")
-
-        # Binarize mask
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-
-        # Image as float32
-        image = image.to(dtype=torch.float32)
-    elif isinstance(mask, torch.Tensor):
-        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
-    else:
-        # preprocess image
-        if isinstance(image, (PIL.Image.Image, np.ndarray)):
-            image = [image]
-        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
-            # resize all images w.r.t passed height an width
-            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
-            image = [np.array(i.convert("RGB"))[None, :] for i in image]
-            image = np.concatenate(image, axis=0)
-        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
-            image = np.concatenate([i[None, :] for i in image], axis=0)
-
-        image = image.transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
-        # preprocess mask
-        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
-            mask = [mask]
-
-        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
-            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
-            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
-            mask = mask.astype(np.float32) / 255.0
-        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
-            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
-
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-        mask = torch.from_numpy(mask)
-
-    masked_image = image * (mask < 0.5)
-
-    # n.b. ensure backwards compatibility as old function does not return image
-    if return_image:
-        return mask, masked_image, image
-
-    return mask, masked_image
-
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 631e309993b1..6fe9b4670ce8 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -132,124 +132,6 @@ def mask_pil_to_torch(mask, height, width):
     return mask
 
 
-def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
-    """
-    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
-    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
-    ``image`` and ``1`` for the ``mask``.
-
-    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
-    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
-
-    Args:
-        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
-            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
-        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
-            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
-            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
-
-
-    Raises:
-        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
-        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
-        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
-            (ot the other way around).
-
-    Returns:
-        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
-            dimensions: ``batch x channels x height x width``.
-    """
-
-    # checkpoint. TOD(Yiyi) - need to clean this up later
-    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
-    deprecate(
-        "prepare_mask_and_masked_image",
-        "0.30.0",
-        deprecation_message,
-    )
-    if image is None:
-        raise ValueError("`image` input cannot be undefined.")
-
-    if mask is None:
-        raise ValueError("`mask_image` input cannot be undefined.")
-
-    if isinstance(image, torch.Tensor):
-        if not isinstance(mask, torch.Tensor):
-            mask = mask_pil_to_torch(mask, height, width)
-
-        if image.ndim == 3:
-            image = image.unsqueeze(0)
-
-        # Batch and add channel dim for single mask
-        if mask.ndim == 2:
-            mask = mask.unsqueeze(0).unsqueeze(0)
-
-        # Batch single mask or add channel dim
-        if mask.ndim == 3:
-            # Single batched mask, no channel dim or single mask not batched but channel dim
-            if mask.shape[0] == 1:
-                mask = mask.unsqueeze(0)
-
-            # Batched masks no channel dim
-            else:
-                mask = mask.unsqueeze(1)
-
-        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
-        # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
-        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
-
-        # Check image is in [-1, 1]
-        # if image.min() < -1 or image.max() > 1:
-        #     raise ValueError("Image should be in [-1, 1] range")
-
-        # Check mask is in [0, 1]
-        if mask.min() < 0 or mask.max() > 1:
-            raise ValueError("Mask should be in [0, 1] range")
-
-        # Binarize mask
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-
-        # Image as float32
-        image = image.to(dtype=torch.float32)
-    elif isinstance(mask, torch.Tensor):
-        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
-    else:
-        # preprocess image
-        if isinstance(image, (PIL.Image.Image, np.ndarray)):
-            image = [image]
-        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
-            # resize all images w.r.t passed height an width
-            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
-            image = [np.array(i.convert("RGB"))[None, :] for i in image]
-            image = np.concatenate(image, axis=0)
-        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
-            image = np.concatenate([i[None, :] for i in image], axis=0)
-
-        image = image.transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
-        mask = mask_pil_to_torch(mask, height, width)
-        mask[mask < 0.5] = 0
-        mask[mask >= 0.5] = 1
-
-    if image.shape[1] == 4:
-        # images are in latent space and thus can't
-        # be masked set masked_image to None
-        # we assume that the checkpoint is not an inpainting
-        # checkpoint. TOD(Yiyi) - need to clean this up later
-        masked_image = None
-    else:
-        masked_image = image * (mask < 0.5)
-
-    # n.b. ensure backwards compatibility as old function does not return image
-    if return_image:
-        return mask, masked_image, image
-
-    return mask, masked_image
-
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 737f8e90ac6b..dbf2ecb00f53 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -36,7 +36,6 @@
     StableDiffusionInpaintPipeline,
     UNet2DConditionModel,
 )
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
@@ -1105,530 +1104,3 @@ def test_inpaint_dpm(self):
         )
         max_diff = np.abs(expected_image - image).max()
         assert max_diff < 1e-3
-
-
-class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
-    def test_pil_inputs(self):
-        height, width = 32, 32
-        im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
-        im = Image.fromarray(im)
-        mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
-        mask = Image.fromarray((mask * 255).astype(np.uint8))
-
-        t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True)
-
-        self.assertTrue(isinstance(t_mask, torch.Tensor))
-        self.assertTrue(isinstance(t_masked, torch.Tensor))
-        self.assertTrue(isinstance(t_image, torch.Tensor))
-
-        self.assertEqual(t_mask.ndim, 4)
-        self.assertEqual(t_masked.ndim, 4)
-        self.assertEqual(t_image.ndim, 4)
-
-        self.assertEqual(t_mask.shape, (1, 1, height, width))
-        self.assertEqual(t_masked.shape, (1, 3, height, width))
-        self.assertEqual(t_image.shape, (1, 3, height, width))
-
-        self.assertTrue(t_mask.dtype == torch.float32)
-        self.assertTrue(t_masked.dtype == torch.float32)
-        self.assertTrue(t_image.dtype == torch.float32)
-
-        self.assertTrue(t_mask.min() >= 0.0)
-        self.assertTrue(t_mask.max() <= 1.0)
-        self.assertTrue(t_masked.min() >= -1.0)
-        self.assertTrue(t_masked.min() <= 1.0)
-        self.assertTrue(t_image.min() >= -1.0)
-        self.assertTrue(t_image.min() >= -1.0)
-
-        self.assertTrue(t_mask.sum() > 0.0)
-
-    def test_np_inputs(self):
-        height, width = 32, 32
-
-        im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
-        im_pil = Image.fromarray(im_np)
-        mask_np = (
-            np.random.randint(
-                0,
-                255,
-                (
-                    height,
-                    width,
-                ),
-                dtype=np.uint8,
-            )
-            > 127.5
-        )
-        mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
-
-        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
-            im_np, mask_np, height, width, return_image=True
-        )
-        t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(
-            im_pil, mask_pil, height, width, return_image=True
-        )
-
-        self.assertTrue((t_mask_np == t_mask_pil).all())
-        self.assertTrue((t_masked_np == t_masked_pil).all())
-        self.assertTrue((t_image_np == t_image_pil).all())
-
-    def test_torch_3D_2D_inputs(self):
-        height, width = 32, 32
-
-        im_tensor = torch.randint(
-            0,
-            255,
-            (
-                3,
-                height,
-                width,
-            ),
-            dtype=torch.uint8,
-        )
-        mask_tensor = (
-            torch.randint(
-                0,
-                255,
-                (
-                    height,
-                    width,
-                ),
-                dtype=torch.uint8,
-            )
-            > 127.5
-        )
-        im_np = im_tensor.numpy().transpose(1, 2, 0)
-        mask_np = mask_tensor.numpy()
-
-        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
-            im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
-        )
-        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
-            im_np, mask_np, height, width, return_image=True
-        )
-
-        self.assertTrue((t_mask_tensor == t_mask_np).all())
-        self.assertTrue((t_masked_tensor == t_masked_np).all())
-        self.assertTrue((t_image_tensor == t_image_np).all())
-
-    def test_torch_3D_3D_inputs(self):
-        height, width = 32, 32
-
-        im_tensor = torch.randint(
-            0,
-            255,
-            (
-                3,
-                height,
-                width,
-            ),
-            dtype=torch.uint8,
-        )
-        mask_tensor = (
-            torch.randint(
-                0,
-                255,
-                (
-                    1,
-                    height,
-                    width,
-                ),
-                dtype=torch.uint8,
-            )
-            > 127.5
-        )
-        im_np = im_tensor.numpy().transpose(1, 2, 0)
-        mask_np = mask_tensor.numpy()[0]
-
-        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
-            im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
-        )
-        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
-            im_np, mask_np, height, width, return_image=True
-        )
-
-        self.assertTrue((t_mask_tensor == t_mask_np).all())
-        self.assertTrue((t_masked_tensor == t_masked_np).all())
-        self.assertTrue((t_image_tensor == t_image_np).all())
-
-    def test_torch_4D_2D_inputs(self):
-        height, width = 32, 32
-
-        im_tensor = torch.randint(
-            0,
-            255,
-            (
-                1,
-                3,
-                height,
-                width,
-            ),
-            dtype=torch.uint8,
-        )
-        mask_tensor = (
-            torch.randint(
-                0,
-                255,
-                (
-                    height,
-                    width,
-                ),
-                dtype=torch.uint8,
-            )
-            > 127.5
-        )
-        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
-        mask_np = mask_tensor.numpy()
-
-        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
-            im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
-        )
-        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
-            im_np, mask_np, height, width, return_image=True
-        )
-
-        self.assertTrue((t_mask_tensor == t_mask_np).all())
-        self.assertTrue((t_masked_tensor == t_masked_np).all())
-        self.assertTrue((t_image_tensor == t_image_np).all())
-
-    def test_torch_4D_3D_inputs(self):
-        height, width = 32, 32
-
-        im_tensor = torch.randint(
-            0,
-            255,
-            (
-                1,
-                3,
-                height,
-                width,
-            ),
-            dtype=torch.uint8,
-        )
-        mask_tensor = (
-            torch.randint(
-                0,
-                255,
-                (
-                    1,
-                    height,
-                    width,
-                ),
-                dtype=torch.uint8,
-            )
-            > 127.5
-        )
-        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
-        mask_np = mask_tensor.numpy()[0]
-
-        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
-            im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
-        )
-        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
-            im_np, mask_np, height, width, return_image=True
-        )
-
-        self.assertTrue((t_mask_tensor == t_mask_np).all())
-        self.assertTrue((t_masked_tensor == t_masked_np).all())
-        self.assertTrue((t_image_tensor == t_image_np).all())
-
-    def test_torch_4D_4D_inputs(self):
-        height, width = 32, 32
-
-        im_tensor = torch.randint(
-            0,
-            255,
-            (
-                1,
-                3,
-                height,
-                width,
-            ),
-            dtype=torch.uint8,
-        )
-        mask_tensor = (
-            torch.randint(
-                0,
-                255,
-                (
-                    1,
-                    1,
-                    height,
-                    width,
-                ),
-                dtype=torch.uint8,
-            )
-            > 127.5
-        )
-        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
-        mask_np = mask_tensor.numpy()[0][0]
-
-        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
-            im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
-        )
-        t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(
-            im_np, mask_np, height, width, return_image=True
-        )
-
-        self.assertTrue((t_mask_tensor == t_mask_np).all())
-        self.assertTrue((t_masked_tensor == t_masked_np).all())
-        self.assertTrue((t_image_tensor == t_image_np).all())
-
-    def test_torch_batch_4D_3D(self):
-        height, width = 32, 32
-
-        im_tensor = torch.randint(
-            0,
-            255,
-            (
-                2,
-                3,
-                height,
-                width,
-            ),
-            dtype=torch.uint8,
-        )
-        mask_tensor = (
-            torch.randint(
-                0,
-                255,
-                (
-                    2,
-                    height,
-                    width,
-                ),
-                dtype=torch.uint8,
-            )
-            > 127.5
-        )
-
-        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
-        mask_nps = [mask.numpy() for mask in mask_tensor]
-
-        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
-            im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
-        )
-        nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
-        t_mask_np = torch.cat([n[0] for n in nps])
-        t_masked_np = torch.cat([n[1] for n in nps])
-        t_image_np = torch.cat([n[2] for n in nps])
-
-        self.assertTrue((t_mask_tensor == t_mask_np).all())
-        self.assertTrue((t_masked_tensor == t_masked_np).all())
-        self.assertTrue((t_image_tensor == t_image_np).all())
-
-    def test_torch_batch_4D_4D(self):
-        height, width = 32, 32
-
-        im_tensor = torch.randint(
-            0,
-            255,
-            (
-                2,
-                3,
-                height,
-                width,
-            ),
-            dtype=torch.uint8,
-        )
-        mask_tensor = (
-            torch.randint(
-                0,
-                255,
-                (
-                    2,
-                    1,
-                    height,
-                    width,
-                ),
-                dtype=torch.uint8,
-            )
-            > 127.5
-        )
-
-        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
-        mask_nps = [mask.numpy()[0] for mask in mask_tensor]
-
-        t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image(
-            im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True
-        )
-        nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)]
-        t_mask_np = torch.cat([n[0] for n in nps])
-        t_masked_np = torch.cat([n[1] for n in nps])
-        t_image_np = torch.cat([n[2] for n in nps])
-
-        self.assertTrue((t_mask_tensor == t_mask_np).all())
-        self.assertTrue((t_masked_tensor == t_masked_np).all())
-        self.assertTrue((t_image_tensor == t_image_np).all())
-
-    def test_shape_mismatch(self):
-        height, width = 32, 32
-
-        # test height and width
-        with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(
-                torch.randn(
-                    3,
-                    height,
-                    width,
-                ),
-                torch.randn(64, 64),
-                height,
-                width,
-                return_image=True,
-            )
-        # test batch dim
-        with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(
-                torch.randn(
-                    2,
-                    3,
-                    height,
-                    width,
-                ),
-                torch.randn(4, 64, 64),
-                height,
-                width,
-                return_image=True,
-            )
-        # test batch dim
-        with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(
-                torch.randn(
-                    2,
-                    3,
-                    height,
-                    width,
-                ),
-                torch.randn(4, 1, 64, 64),
-                height,
-                width,
-                return_image=True,
-            )
-
-    def test_type_mismatch(self):
-        height, width = 32, 32
-
-        # test tensors-only
-        with self.assertRaises(TypeError):
-            prepare_mask_and_masked_image(
-                torch.rand(
-                    3,
-                    height,
-                    width,
-                ),
-                torch.rand(
-                    3,
-                    height,
-                    width,
-                ).numpy(),
-                height,
-                width,
-                return_image=True,
-            )
-        # test tensors-only
-        with self.assertRaises(TypeError):
-            prepare_mask_and_masked_image(
-                torch.rand(
-                    3,
-                    height,
-                    width,
-                ).numpy(),
-                torch.rand(
-                    3,
-                    height,
-                    width,
-                ),
-                height,
-                width,
-                return_image=True,
-            )
-
-    def test_channels_first(self):
-        height, width = 32, 32
-
-        # test channels first for 3D tensors
-        with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(
-                torch.rand(height, width, 3),
-                torch.rand(
-                    3,
-                    height,
-                    width,
-                ),
-                height,
-                width,
-                return_image=True,
-            )
-
-    def test_tensor_range(self):
-        height, width = 32, 32
-
-        # test im <= 1
-        with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(
-                torch.ones(
-                    3,
-                    height,
-                    width,
-                )
-                * 2,
-                torch.rand(
-                    height,
-                    width,
-                ),
-                height,
-                width,
-                return_image=True,
-            )
-        # test im >= -1
-        with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(
-                torch.ones(
-                    3,
-                    height,
-                    width,
-                )
-                * (-2),
-                torch.rand(
-                    height,
-                    width,
-                ),
-                height,
-                width,
-                return_image=True,
-            )
-        # test mask <= 1
-        with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(
-                torch.rand(
-                    3,
-                    height,
-                    width,
-                ),
-                torch.ones(
-                    height,
-                    width,
-                )
-                * 2,
-                height,
-                width,
-                return_image=True,
-            )
-        # test mask >= 0
-        with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(
-                torch.rand(
-                    3,
-                    height,
-                    width,
-                ),
-                torch.ones(
-                    height,
-                    width,
-                )
-                * -1,
-                height,
-                width,
-                return_image=True,
-            )
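
Note on migration: all three deprecation messages point callers at VaeImageProcessor.preprocess. The snippet below
is a minimal sketch of that migration, not part of this patch. It assumes a diffusers release where
VaeImageProcessor accepts the do_normalize, do_binarize, and do_convert_grayscale options (the configuration the
inpaint pipelines use for their mask_processor); pil_image and pil_mask are hypothetical placeholder PIL inputs.

    # Rough stand-in for the removed prepare_mask_and_masked_image() helper.
    from diffusers.image_processor import VaeImageProcessor

    image_processor = VaeImageProcessor(vae_scale_factor=8)  # 8 is the usual SD VAE scale factor
    mask_processor = VaeImageProcessor(
        vae_scale_factor=8,
        do_normalize=False,         # keep mask values in [0, 1] instead of [-1, 1]
        do_binarize=True,           # threshold at 0.5, matching the old helper
        do_convert_grayscale=True,  # single-channel mask
    )

    # pil_image / pil_mask: placeholder PIL.Image inputs (RGB image, grayscale mask).
    init_image = image_processor.preprocess(pil_image, height=512, width=512)  # (1, 3, H, W), float32 in [-1, 1]
    mask = mask_processor.preprocess(pil_mask, height=512, width=512)          # (1, 1, H, W), values in {0, 1}
    masked_image = init_image * (mask < 0.5)  # same masking rule the old helper applied

The removed helper resized with PIL.Image.LANCZOS, which is also VaeImageProcessor's default resample filter, so
this path should stay close to the old behavior; exact numerical parity is not guaranteed.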