From 49d8702bebd61453276db2e311f96960ca15bc66 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Thu, 4 May 2023 10:22:46 +0100 Subject: [PATCH 1/3] StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy. --- .../pipeline_stable_diffusion_inpaint.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 859a34677317..adef00352eee 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -34,7 +34,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_mask_and_masked_image(image, mask): +def prepare_mask_and_masked_image(image, mask, height, width): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -62,6 +62,13 @@ def prepare_mask_and_masked_image(image, mask): tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") @@ -109,8 +116,9 @@ def prepare_mask_and_masked_image(image, mask): # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] - if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -124,6 +132,7 @@ def prepare_mask_and_masked_image(image, mask): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): @@ -787,12 +796,6 @@ def __call__( negative_prompt_embeds, ) - if image is None: - raise ValueError("`image` input cannot be undefined.") - - if mask_image is None: - raise ValueError("`mask_image` input cannot be undefined.") - # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -818,8 +821,8 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Preprocess mask and image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) From c4f32109954f64fb820bf6a44cb4f601d456413c Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Thu, 4 May 2023 11:00:25 +0100 Subject: [PATCH 2/3] Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution --- .../test_stable_diffusion_inpaint.py | 119 +++++++++++------- 1 file changed, 72 insertions(+), 47 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 20977c346ecc..b6293bc46acb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -397,12 +397,13 @@ def test_inpaint_dpm(self): class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): def test_pil_inputs(self): - im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + height, width = 32, 32 + im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) im = Image.fromarray(im) - mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5 mask = Image.fromarray((mask * 255).astype(np.uint8)) - t_mask, t_masked = prepare_mask_and_masked_image(im, mask) + t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width) self.assertTrue(isinstance(t_mask, torch.Tensor)) self.assertTrue(isinstance(t_masked, torch.Tensor)) @@ -410,8 +411,8 @@ def test_pil_inputs(self): self.assertEqual(t_mask.ndim, 4) self.assertEqual(t_masked.ndim, 4) - self.assertEqual(t_mask.shape, (1, 1, 32, 32)) - self.assertEqual(t_masked.shape, (1, 3, 32, 32)) + self.assertEqual(t_mask.shape, (1, 1, height, width)) + self.assertEqual(t_masked.shape, (1, 3, height, width)) self.assertTrue(t_mask.dtype == torch.float32) self.assertTrue(t_masked.dtype == torch.float32) @@ -424,86 +425,100 @@ def test_pil_inputs(self): self.assertTrue(t_mask.sum() > 0.0) def test_np_inputs(self): - im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + height, width = 32, 32 + + im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) im_pil = Image.fromarray(im_np) - mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5 mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) - t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width) self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all()) def test_torch_3D_2D_inputs(self): - im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_3D_3D_inputs(self): - im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_4D_2D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_4D_3D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_4D_4D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5 im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0][0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_batch_4D_3D(self): - im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5 im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy() for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) @@ -511,14 +526,16 @@ def test_torch_batch_4D_3D(self): self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_torch_batch_4D_4D(self): - im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5 im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy()[0] for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width) + nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) @@ -526,39 +543,47 @@ def test_torch_batch_4D_4D(self): self.assertTrue((t_masked_tensor == t_masked_np).all()) def test_shape_mismatch(self): + height, width = 32, 32 + # test height and width with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64)) + prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64)) + prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64)) + prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width) def test_type_mismatch(self): + height, width = 32, 32 + # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy()) + prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width) # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32)) + prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width) def test_channels_first(self): + height, width = 32, 32 + # test channels first for 3D tensors with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32)) + prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width) def test_tensor_range(self): + height, width = 32, 32 + # test im <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32)) + prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width) # test im >= -1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32)) + prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width) # test mask <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) + prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width) # test mask >= 0 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) + prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width) From c79de88e8b5877bc1edcf4d0eed943e4672b6981 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Thu, 4 May 2023 18:48:43 +0100 Subject: [PATCH 3/3] Added a resolution test to StableDiffusionInpaintPipelineSlowTests this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width --- .../test_stable_diffusion_inpaint.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index b6293bc46acb..097093f2427e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -300,6 +300,25 @@ def test_inpaint_compile(self): assert np.abs(expected_slice - image_slice).max() < 1e-4 assert np.abs(expected_slice - image_slice).max() < 1e-3 + def test_stable_diffusion_inpaint_pil_input_resolution_test(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input image to a random size (one that would cause a tensor mismatch error) + inputs['image'] = inputs['image'].resize((127,127)) + inputs['mask_image'] = inputs['mask_image'].resize((127,127)) + inputs['height'] = 128 + inputs['width'] = 128 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, inputs['height'], inputs['width'], 3) + @nightly @require_torch_gpu