diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py
index 7af88f627559..d100001f6349 100644
--- a/src/diffusers/pipelines/repaint/pipeline_repaint.py
+++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py
@@ -13,33 +13,61 @@
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 import PIL
-from tqdm.auto import tqdm
 
 from ...models import UNet2DModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import RePaintScheduler
+from ...utils import PIL_INTERPOLATION, deprecate, logging
 
 
-def _preprocess_image(image: PIL.Image.Image):
-    image = np.array(image.convert("RGB"))
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
     return image
 
 
-def _preprocess_mask(mask: PIL.Image.Image):
-    mask = np.array(mask.convert("L"))
-    mask = mask.astype(np.float32) / 255.0
-    mask = mask[None, None]
-    mask[mask < 0.5] = 0
-    mask[mask >= 0.5] = 1
-    mask = torch.from_numpy(mask)
+def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
+    if isinstance(mask, torch.Tensor):
+        return mask
+    elif isinstance(mask, PIL.Image.Image):
+        mask = [mask]
+
+    if isinstance(mask[0], PIL.Image.Image):
+        w, h = mask[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
+        mask = np.concatenate(mask, axis=0)
+        mask = mask.astype(np.float32) / 255.0
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+        mask = torch.from_numpy(mask)
+    elif isinstance(mask[0], torch.Tensor):
+        mask = torch.cat(mask, dim=0)
     return mask
 
 
 class RePaintPipeline(DiffusionPipeline):
@@ -54,8 +82,8 @@ def __init__(self, unet, scheduler):
     @torch.no_grad()
     def __call__(
         self,
-        original_image: Union[torch.FloatTensor, PIL.Image.Image],
-        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        image: Union[torch.Tensor, PIL.Image.Image],
+        mask_image: Union[torch.Tensor, PIL.Image.Image],
         num_inference_steps: int = 250,
         eta: float = 0.0,
         jump_length: int = 10,
@@ -63,10 +91,11 @@ def __call__(
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Args:
-            original_image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
                 The original image to inpaint on.
             mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 The mask_image where 0.0 values define which part of the original image to inpaint (change).
@@ -97,12 +126,14 @@ def __call__(
             generated images.
 
         """
-        if not isinstance(original_image, torch.FloatTensor):
-            original_image = _preprocess_image(original_image)
-        original_image = original_image.to(self.device)
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = _preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
+        message = "Please use `image` instead of `original_image`."
+        original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
+        original_image = original_image or image
+
+        original_image = _preprocess_image(original_image)
+        original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
+        mask_image = _preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
 
         # sample gaussian noise to begin the loop
         image = torch.randn(
@@ -110,14 +141,14 @@ def __call__(
             generator=generator,
             device=self.device,
         )
-        image = image.to(self.device)
+        image = image.to(device=self.device, dtype=self.unet.dtype)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
         self.scheduler.eta = eta
 
         t_last = self.scheduler.timesteps[0] + 1
-        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             if t < t_last:
                 # predict the noise residual
                 model_output = self.unet(image, t).sample
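The pipeline changes above rename `original_image` to `image`, accept lists and tensors via the img2img-style preprocessing, and route the old keyword through `deprecate(..., take_from=kwargs)` so existing callers keep working until 0.15.0. A minimal usage sketch of the new signature; the checkpoint id follows the RePaint docs and the integration tests below, while the local file names are placeholders:

    import PIL.Image
    from diffusers import RePaintPipeline, RePaintScheduler

    scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
    pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)

    original = PIL.Image.open("face.png")  # placeholder input image
    mask = PIL.Image.open("mask.png")      # 0.0 marks the regions to inpaint

    # New keyword: `image`. PIL images, tensors, and lists of either are accepted.
    output = pipe(image=original, mask_image=mask)
    # `pipe(original_image=..., ...)` still works but emits a deprecation warning.
    output.images[0].save("inpainted.png")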
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index f41a41fd49dd..2d4fd8100ded 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -270,9 +270,13 @@ def step(
         # been observed.
 
         # 5. Add noise
-        noise = torch.randn(
-            model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
-        )
+        device = model_output.device
+        if device.type == "mps":
+            # randn does not work reproducibly on mps
+            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+            noise = noise.to(device)
+        else:
+            noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
         std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
 
         variance = 0
@@ -305,7 +309,12 @@ def undo_step(self, sample, timestep, generator=None):
 
         for i in range(n):
             beta = self.betas[timestep + i]
-            noise = torch.randn(sample.shape, generator=generator, device=sample.device)
+            if sample.device.type == "mps":
+                # randn does not work reproducibly on mps
+                noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
+                noise = noise.to(sample.device)
+            else:
+                noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
 
             # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
             sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
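Both scheduler hunks above apply the same workaround: with a `generator`, `torch.randn` is not reproducible on Apple's `mps` backend, so the noise is drawn on CPU and then moved to the device. A sketch of the shared pattern, factored into a helper purely for illustration (this helper is not part of the diff):

    import torch

    def _reproducible_randn(shape, dtype, device, generator=None):
        # randn with a generator is not reproducible on mps, so sample
        # on CPU and move the noise to the target device afterwards.
        if device.type == "mps":
            return torch.randn(shape, dtype=dtype, generator=generator).to(device)
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)

As a side effect, `undo_step` now passes `dtype=sample.dtype` in both branches, where the old call fell back to default float32 noise.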
diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py
index d1ecd3c06ee4..889d54a23e99 100644
--- a/tests/pipelines/repaint/test_repaint.py
+++ b/tests/pipelines/repaint/test_repaint.py
@@ -21,10 +21,67 @@
 from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
 from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
 
+from ...test_pipelines_common import PipelineTesterMixin
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = RePaintPipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        scheduler = RePaintScheduler()
+        components = {"unet": unet, "scheduler": scheduler}
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        image = np.random.RandomState(seed).standard_normal((1, 3, 32, 32))
+        image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
+        mask = (image > 0).to(device=device, dtype=torch.float32)
+        inputs = {
+            "image": image,
+            "mask_image": mask,
+            "generator": generator,
+            "num_inference_steps": 5,
+            "eta": 0.0,
+            "jump_length": 2,
+            "jump_n_sample": 2,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_repaint(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = RePaintPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+
 @slow
 @require_torch_gpu
 class RepaintPipelineIntegrationTests(unittest.TestCase):
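The new fast-test class uses tiny components (a 32x32, two-block UNet), so it runs on CPU in seconds. Assuming the repository's standard pytest setup, a quick way to run only these tests from Python:

    import pytest

    # Run just the new fast tests; -q keeps the output terse.
    pytest.main(["tests/pipelines/repaint/test_repaint.py::RepaintPipelineFastTests", "-q"])

`PipelineTesterMixin` contributes the shared pipeline checks (save/load round-trips, moving between devices, and similar) on top of the `test_repaint` slice comparison, driven by `get_dummy_components` and `get_dummy_inputs`.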