From 939fa59aeffc8d8ec7aebd7645cb172b43f7cf57 Mon Sep 17 00:00:00 2001
From: anton-
Date: Wed, 14 Dec 2022 11:32:55 +0100
Subject: [PATCH 1/5] add fast tests

---
 .../pipelines/repaint/pipeline_repaint.py | 10 ++---
 tests/pipelines/repaint/test_repaint.py   | 44 +++++++++++++++++++
 2 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py
index 7af88f627559..4009686b504a 100644
--- a/src/diffusers/pipelines/repaint/pipeline_repaint.py
+++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -54,8 +54,8 @@ def __init__(self, unet, scheduler):
     @torch.no_grad()
     def __call__(
         self,
-        original_image: Union[torch.FloatTensor, PIL.Image.Image],
-        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        original_image: Union[torch.Tensor, PIL.Image.Image],
+        mask_image: Union[torch.Tensor, PIL.Image.Image],
         num_inference_steps: int = 250,
         eta: float = 0.0,
         jump_length: int = 10,
@@ -97,10 +97,10 @@ def __call__(
             generated images.
         """
 
-        if not isinstance(original_image, torch.FloatTensor):
+        if not isinstance(original_image, torch.Tensor):
             original_image = _preprocess_image(original_image)
         original_image = original_image.to(self.device)
-        if not isinstance(mask_image, torch.FloatTensor):
+        if not isinstance(mask_image, torch.Tensor):
             mask_image = _preprocess_mask(mask_image)
         mask_image = mask_image.to(self.device)
 
@@ -117,7 +117,7 @@
         self.scheduler.eta = eta
 
         t_last = self.scheduler.timesteps[0] + 1
-        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+        for i, t in enumerate((self.scheduler.timesteps)):
             if t < t_last:
                 # predict the noise residual
                 model_output = self.unet(image, t).sample
diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py
index d1ecd3c06ee4..a18795c33ce7 100644
--- a/tests/pipelines/repaint/test_repaint.py
+++ b/tests/pipelines/repaint/test_repaint.py
@@ -22,9 +22,53 @@
 from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
 
 
+from ...test_pipelines_common import PipelineTesterMixin
+
+
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = RePaintPipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        torch.manual_seed(0)
+        unet = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        scheduler = RePaintScheduler()
+        components = {"unet": unet, "scheduler": scheduler}
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        image = np.random.RandomState(seed).standard_normal((1, 3, 32, 32))
+        image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
+        mask = (image > 0).to(device=device, dtype=torch.float32)
+        inputs = {
+            "original_image": image,
+            "mask_image": mask,
+            "generator": generator,
+            "num_inference_steps": 5,
+            "eta": 0.0,
+            "jump_length": 2,
+            "jump_n_sample": 2,
+            "output_type": "numpy",
+        }
+        return inputs
+
 
 @slow
 @require_torch_gpu
 class RepaintPipelineIntegrationTests(unittest.TestCase):
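The fast test added in this patch drives the complete RePaint sampling loop with a tiny UNet, so it finishes in seconds on CPU. As a rough standalone illustration (my own sketch, not part of the patch), the same setup can be replayed outside the `PipelineTesterMixin` harness, assuming a diffusers checkout at the state of PATCH 1/5 where the first argument is still called `original_image`:

```python
import numpy as np
import torch

from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel

# Tiny 32x32 UNet, mirroring get_dummy_components() above.
torch.manual_seed(0)
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
pipe = RePaintPipeline(unet=unet, scheduler=RePaintScheduler())

# Random "original" image and a binary mask, mirroring get_dummy_inputs().
image = torch.from_numpy(np.random.RandomState(0).standard_normal((1, 3, 32, 32))).float()
mask = (image > 0).float()

result = pipe(
    original_image=image,  # renamed to `image` in PATCH 3/5
    mask_image=mask,
    generator=torch.manual_seed(0),
    num_inference_steps=5,
    eta=0.0,
    jump_length=2,
    jump_n_sample=2,
    output_type="numpy",
)
print(result.images.shape)  # expected (1, 32, 32, 3)
```

The short `num_inference_steps` together with `jump_length=2` and `jump_n_sample=2` keeps the RePaint resampling schedule tiny, which is what makes the test fast.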
From 79101eb0f2551290f151b0caad15eaf6fbe9808b Mon Sep 17 00:00:00 2001
From: anton-
Date: Thu, 15 Dec 2022 13:50:20 +0100
Subject: [PATCH 2/5] better tests and fp16

---
 .../pipelines/repaint/pipeline_repaint.py      |  9 ++++-----
 src/diffusers/schedulers/scheduling_repaint.py | 17 +++++++++++++----
 tests/pipelines/repaint/test_repaint.py        | 18 ++++++++++++++++--
 3 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py
index 4009686b504a..a37783b901f9 100644
--- a/src/diffusers/pipelines/repaint/pipeline_repaint.py
+++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -19,7 +19,6 @@
 import torch
 
 import PIL
-from tqdm.auto import tqdm
 
 from ...models import UNet2DModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -99,10 +98,10 @@ def __call__(
 
         if not isinstance(original_image, torch.Tensor):
             original_image = _preprocess_image(original_image)
-        original_image = original_image.to(self.device)
+        original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
         if not isinstance(mask_image, torch.Tensor):
             mask_image = _preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
+        mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
 
         # sample gaussian noise to begin the loop
         image = torch.randn(
@@ -110,14 +109,14 @@ def __call__(
             generator=generator,
             device=self.device,
         )
-        image = image.to(self.device)
+        image = image.to(device=self.device, dtype=self.unet.dtype)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
         self.scheduler.eta = eta
 
         t_last = self.scheduler.timesteps[0] + 1
-        for i, t in enumerate((self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             if t < t_last:
                 # predict the noise residual
                 model_output = self.unet(image, t).sample
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index f41a41fd49dd..2d4fd8100ded 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -270,9 +270,13 @@
         # been observed.
 
         # 5. Add noise
-        noise = torch.randn(
-            model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
-        )
+        device = model_output.device
+        if device.type == "mps":
+            # randn does not work reproducibly on mps
+            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+            noise = noise.to(device)
+        else:
+            noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
         std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
 
         variance = 0
@@ -305,7 +309,12 @@ def undo_step(self, sample, timestep, generator=None):
 
         for i in range(n):
             beta = self.betas[timestep + i]
-            noise = torch.randn(sample.shape, generator=generator, device=sample.device)
+            if sample.device.type == "mps":
+                # randn does not work reproducibly on mps
+                noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
+                noise = noise.to(sample.device)
+            else:
+                noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
 
             # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
             sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py
index a18795c33ce7..f3a06c9c0130 100644
--- a/tests/pipelines/repaint/test_repaint.py
+++ b/tests/pipelines/repaint/test_repaint.py
@@ -21,14 +21,12 @@
 
 from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
 from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
-
 from ...test_pipelines_common import PipelineTesterMixin
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
-
 class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = RePaintPipeline
     test_cpu_offload = False
@@ -69,6 +67,22 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def test_repaint(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = RePaintPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+
 @slow
 @require_torch_gpu
 class RepaintPipelineIntegrationTests(unittest.TestCase):
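Two details carry the weight of PATCH 2/5: the pipeline now casts its inputs and the initial noise to `self.unet.dtype`, which is what makes fp16 checkpoints work, and the scheduler draws generator-seeded noise on the CPU whenever it runs on Apple's MPS backend, because `torch.randn` with a generator is not reproducible there. The MPS branch repeated in `step` and `undo_step` boils down to the helper-style sketch below (the function name is mine, not part of the patch):

```python
import torch


def randn_matching(t: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor:
    """Sample noise with the same shape and dtype as `t`, reproducibly.

    Mirrors the branch added in PATCH 2/5: on MPS the noise is drawn on the
    CPU with the provided generator and then moved to the device; everywhere
    else it is drawn directly on the target device.
    """
    if t.device.type == "mps":
        noise = torch.randn(t.shape, dtype=t.dtype, generator=generator)
        return noise.to(t.device)
    return torch.randn(t.shape, dtype=t.dtype, generator=generator, device=t.device)
```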
From fa0993def2c50108f579d3570f8ed2c2c8819ca1 Mon Sep 17 00:00:00 2001
From: anton-
Date: Thu, 15 Dec 2022 14:11:21 +0100
Subject: [PATCH 3/5] batch fixes

---
 .../pipelines/repaint/pipeline_repaint.py | 56 +++++++++++++------
 tests/pipelines/repaint/test_repaint.py   |  2 +-
 2 files changed, 39 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py
index a37783b901f9..e75dea54a0e5 100644
--- a/src/diffusers/pipelines/repaint/pipeline_repaint.py
+++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -23,22 +23,39 @@
 from ...models import UNet2DModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import RePaintScheduler
+from ...utils import deprecate
 
 
-def _preprocess_image(image: PIL.Image.Image):
-    image = np.array(image.convert("RGB"))
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        image = np.array(image.convert("RGB"))
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
     return image
 
 
-def _preprocess_mask(mask: PIL.Image.Image):
-    mask = np.array(mask.convert("L"))
-    mask = mask.astype(np.float32) / 255.0
-    mask = mask[None, None]
-    mask[mask < 0.5] = 0
-    mask[mask >= 0.5] = 1
-    mask = torch.from_numpy(mask)
+def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
+    if isinstance(mask, torch.Tensor):
+        return mask
+    elif isinstance(mask, PIL.Image.Image):
+        mask = [mask]
+
+    if isinstance(mask[0], PIL.Image.Image):
+        mask = np.array(mask.convert("L"))
+        mask = mask.astype(np.float32) / 255.0
+        mask = mask[None, None]
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+        mask = torch.from_numpy(mask)
+    elif isinstance(mask[0], torch.Tensor):
+        mask = torch.cat(mask, dim=0)
     return mask
 
 
@@ -53,7 +70,7 @@ def __init__(self, unet, scheduler):
     @torch.no_grad()
     def __call__(
         self,
-        original_image: Union[torch.Tensor, PIL.Image.Image],
+        image: Union[torch.Tensor, PIL.Image.Image],
         mask_image: Union[torch.Tensor, PIL.Image.Image],
         num_inference_steps: int = 250,
         eta: float = 0.0,
@@ -62,10 +79,11 @@ def __call__(
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Args:
-            original_image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
                 The original image to inpaint on.
             mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 The mask_image where 0.0 values define which part of the original image to inpaint (change).
@@ -96,11 +114,13 @@ def __call__(
             generated images.
         """
 
-        if not isinstance(original_image, torch.Tensor):
-            original_image = _preprocess_image(original_image)
+        message = "Please use `image` instead of `original_image`."
+        original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
+        original_image = original_image or image
+
+        original_image = _preprocess_image(original_image)
         original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
-        if not isinstance(mask_image, torch.Tensor):
-            mask_image = _preprocess_mask(mask_image)
+        mask_image = _preprocess_mask(mask_image)
         mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
 
         # sample gaussian noise to begin the loop
diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py
index f3a06c9c0130..889d54a23e99 100644
--- a/tests/pipelines/repaint/test_repaint.py
+++ b/tests/pipelines/repaint/test_repaint.py
@@ -56,7 +56,7 @@ def get_dummy_inputs(self, device, seed=0):
         image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
         mask = (image > 0).to(device=device, dtype=torch.float32)
         inputs = {
-            "original_image": image,
+            "image": image,
             "mask_image": mask,
             "generator": generator,
             "num_inference_steps": 5,
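PATCH 3/5 renames the pipeline's first argument from `original_image` to `image` but keeps the old keyword working through diffusers' `deprecate(..., take_from=kwargs)` helper, which, as used here, pops the legacy keyword out of `**kwargs`, warns, and schedules its removal for 0.15.0. The behaviour amounts to the sketch below (a plain-`warnings` stand-in with a hypothetical function name, not the library helper):

```python
import warnings


def resolve_image_kwarg(image=None, **kwargs):
    """Minimal stand-in for the backwards-compatibility shim in PATCH 3/5."""
    original_image = kwargs.pop("original_image", None)
    if original_image is not None:
        warnings.warn(
            "`original_image` is deprecated, please use `image` instead; "
            "the old keyword is scheduled for removal in diffusers 0.15.0.",
            FutureWarning,
        )
    # The patch writes `original_image or image`; an explicit None check is
    # used here to avoid relying on tensor truthiness.
    return image if original_image is None else original_image
```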
         mask = torch.from_numpy(mask)

From 4f567efed0a387783fb8f70ef2d9de4ca0426dc9 Mon Sep 17 00:00:00 2001
From: anton-
Date: Thu, 15 Dec 2022 17:46:35 +0100
Subject: [PATCH 4/5] Reuse preprocessing

---
 .../pipelines/repaint/pipeline_repaint.py | 24 ++++++++++++++------
 tests/pipelines/repaint/test_repaint.py   |  2 +-
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py
index e75dea54a0e5..d100001f6349 100644
--- a/src/diffusers/pipelines/repaint/pipeline_repaint.py
+++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -23,9 +23,13 @@
 from ...models import UNet2DModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import RePaintScheduler
-from ...utils import deprecate
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
     if isinstance(image, torch.Tensor):
         return image
@@ -33,9 +37,15 @@ def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
         image = [image]
 
     if isinstance(image[0], PIL.Image.Image):
-        image = np.array(image.convert("RGB"))
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
     elif isinstance(image[0], torch.Tensor):
         image = torch.cat(image, dim=0)
     return image
@@ -48,9 +58,11 @@ def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
         mask = [mask]
 
     if isinstance(mask[0], PIL.Image.Image):
-        mask = np.array(mask.convert("L"))
+        w, h = mask[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
+        mask = np.concatenate(mask, axis=0)
         mask = mask.astype(np.float32) / 255.0
-        mask = mask[None, None]
         mask[mask < 0.5] = 0
         mask[mask >= 0.5] = 1
         mask = torch.from_numpy(mask)
diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py
index 889d54a23e99..b30283e57495 100644
--- a/tests/pipelines/repaint/test_repaint.py
+++ b/tests/pipelines/repaint/test_repaint.py
@@ -83,7 +83,7 @@ def test_repaint(self):
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
 
-@slow
+# @slow
 @require_torch_gpu
 class RepaintPipelineIntegrationTests(unittest.TestCase):
     def test_celebahq(self):

From dec7bbab1fcb7b9991019f53b8c1f1453e992378 Mon Sep 17 00:00:00 2001
From: anton-
Date: Thu, 15 Dec 2022 17:46:58 +0100
Subject: [PATCH 5/5] quickfix

---
 tests/pipelines/repaint/test_repaint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py
index b30283e57495..889d54a23e99 100644
--- a/tests/pipelines/repaint/test_repaint.py
+++ b/tests/pipelines/repaint/test_repaint.py
@@ -83,7 +83,7 @@ def test_repaint(self):
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
 
-# @slow
+@slow
 @require_torch_gpu
 class RepaintPipelineIntegrationTests(unittest.TestCase):
     def test_celebahq(self):
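PATCH 4/5 replaces the hand-rolled preprocessing with the img2img-style version: PIL inputs are resized down to a multiple of 32 using the `PIL_INTERPOLATION` table, stacked into a batch, images are scaled to [-1, 1], and masks are binarized to {0, 1}; PATCH 5/5 then simply restores the `@slow` marker that had been commented out while debugging. A self-contained approximation of the resulting mask path (my own sketch, using `PIL.Image.NEAREST` directly instead of the library's `PIL_INTERPOLATION` lookup) looks like this:

```python
from typing import List, Union

import numpy as np
import PIL.Image
import torch


def preprocess_mask_sketch(mask: Union[List, PIL.Image.Image, torch.Tensor]) -> torch.Tensor:
    """Rough equivalent of _preprocess_mask after PATCH 4/5."""
    if isinstance(mask, torch.Tensor):
        return mask
    if isinstance(mask, PIL.Image.Image):
        mask = [mask]
    if isinstance(mask[0], PIL.Image.Image):
        w, h = mask[0].size
        w, h = (x - x % 32 for x in (w, h))  # round the target size down to a multiple of 32
        arrs = [np.array(m.convert("L").resize((w, h), resample=PIL.Image.NEAREST))[None, :] for m in mask]
        arr = np.concatenate(arrs, axis=0).astype(np.float32) / 255.0
        arr[arr < 0.5] = 0  # 0.0 marks the region to inpaint, per the pipeline docstring
        arr[arr >= 0.5] = 1
        return torch.from_numpy(arr)
    return torch.cat(mask, dim=0)  # list of tensors -> batch
```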