Skip to content

RePaint fast tests and API conforming #1701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 55 additions & 24 deletions src/diffusers/pipelines/repaint/pipeline_repaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,61 @@
# limitations under the License.


from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

import PIL
from tqdm.auto import tqdm

from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging


def _preprocess_image(image: PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a Copied from here from the img2img pipeline? so that it stays in sync?

if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32

image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image


def _preprocess_mask(mask: PIL.Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
if isinstance(mask, torch.Tensor):
return mask
elif isinstance(mask, PIL.Image.Image):
mask = [mask]

if isinstance(mask[0], PIL.Image.Image):
w, h = mask[0].size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
mask = np.concatenate(mask, axis=0)
mask = mask.astype(np.float32) / 255.0
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
elif isinstance(mask[0], torch.Tensor):
mask = torch.cat(mask, dim=0)
return mask


Expand All @@ -54,19 +82,20 @@ def __init__(self, unet, scheduler):
@torch.no_grad()
def __call__(
self,
original_image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
image: Union[torch.Tensor, PIL.Image.Image],
mask_image: Union[torch.Tensor, PIL.Image.Image],
num_inference_steps: int = 250,
eta: float = 0.0,
jump_length: int = 10,
jump_n_sample: int = 10,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
original_image (`torch.FloatTensor` or `PIL.Image.Image`):
image (`torch.FloatTensor` or `PIL.Image.Image`):
The original image to inpaint on.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
The mask_image where 0.0 values define which part of the original image to inpaint (change).
Expand Down Expand Up @@ -97,27 +126,29 @@ def __call__(
generated images.
"""

if not isinstance(original_image, torch.FloatTensor):
original_image = _preprocess_image(original_image)
original_image = original_image.to(self.device)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = _preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
message = "Please use `image` instead of `original_image`."
original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
original_image = original_image or image

original_image = _preprocess_image(original_image)
original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
mask_image = _preprocess_mask(mask_image)
mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)

# sample gaussian noise to begin the loop
image = torch.randn(
original_image.shape,
generator=generator,
device=self.device,
)
image = image.to(self.device)
image = image.to(device=self.device, dtype=self.unet.dtype)

# set step values
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
self.scheduler.eta = eta

t_last = self.scheduler.timesteps[0] + 1
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
if t < t_last:
# predict the noise residual
model_output = self.unet(image, t).sample
Expand Down
17 changes: 13 additions & 4 deletions src/diffusers/schedulers/scheduling_repaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,13 @@ def step(
# been observed.

# 5. Add noise
noise = torch.randn(
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
)
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
noise = noise.to(device)
else:
noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5

variance = 0
Expand Down Expand Up @@ -305,7 +309,12 @@ def undo_step(self, sample, timestep, generator=None):

for i in range(n):
beta = self.betas[timestep + i]
noise = torch.randn(sample.shape, generator=generator, device=sample.device)
if sample.device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
noise = noise.to(sample.device)
else:
noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)

# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
Expand Down
58 changes: 58 additions & 0 deletions tests/pipelines/repaint/test_repaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,68 @@
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = RePaintPipeline
test_cpu_offload = False

def get_dummy_components(self):
torch.manual_seed(0)
torch.manual_seed(0)
unet = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = RePaintScheduler()
components = {"unet": unet, "scheduler": scheduler}
return components

def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = np.random.RandomState(seed).standard_normal((1, 3, 32, 32))
image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
mask = (image > 0).to(device=device, dtype=torch.float32)
inputs = {
"image": image,
"mask_image": mask,
"generator": generator,
"num_inference_steps": 5,
"eta": 0.0,
"jump_length": 2,
"jump_n_sample": 2,
"output_type": "numpy",
}
return inputs

def test_repaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = RePaintPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3


@slow
@require_torch_gpu
class RepaintPipelineIntegrationTests(unittest.TestCase):
Expand Down