Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size)

# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
Expand All @@ -221,18 +227,22 @@ def __call__(
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)

# adding noise to the masked areas depending on strength
rand_latents = torch.randn(
init_latents.shape,
generator=generator,
device=self.device,
)
init_latents_noised = init_latents * mask + rand_latents * (1 - mask)
init_latents = init_latents * (1 - strength) + init_latents_noised * strength

# multiply by scale_factor
init_latents = 0.18215 * init_latents

# Expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size)
init_latents_orig = init_latents

# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size)

# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
Expand Down
4 changes: 3 additions & 1 deletion tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,9 @@ def test_stable_diffusion_inpaint(self):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
expected_slice = np.array(
[0.4893303, 0.5381786, 0.46649122, 0.62859786, 0.53987336, 0.39735478, 0.5483682, 0.59601367, 0.5178648]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

Expand Down