Merged

Commits (56)
4b15eb4
add warning and add back copy from for img2img
May 25, 2023
7cd26ef
style
May 25, 2023
d58276c
fix copies
May 25, 2023
4c75754
add import warning
May 25, 2023
db25d47
update
May 25, 2023
f779977
handle image as latents
May 26, 2023
c9ef024
add a test for latents input
May 26, 2023
bc66a2a
add
May 26, 2023
bf0331d
copies
May 26, 2023
5c597c5
Merge remote-tracking branch 'origin/main' into vae-preprocessor
May 26, 2023
91423d7
add
May 26, 2023
19747b7
add
May 26, 2023
96a2a05
refactor img2img test
May 27, 2023
f7be498
refactor cycle diffusion
May 27, 2023
df73ada
refactor depth pipeline preprocess + handle np input for prepare_dept…
May 28, 2023
eb23e59
refactor latent tester
May 28, 2023
1655b1b
pix2pix0
May 29, 2023
076afc8
add a pt_np_pil input tests to pix2pix zero
May 30, 2023
706a53e
refactor postprocess for pix2pix0 invert method
May 30, 2023
d63c444
remove the unused image parameter from pix2pix0 call method
May 30, 2023
cc7fffc
add test for postprocess for pix2pix0 invert
May 30, 2023
007deda
alt
May 30, 2023
188de72
style
May 30, 2023
eae044f
fix-copies
May 30, 2023
0f01f9c
instruction pix2pix
May 30, 2023
3a95cc6
fix
May 30, 2023
d410a1b
add
May 30, 2023
70904ac
refactor image_processor, add convert_to_rgb
May 30, 2023
48896b3
refactor latent input tests
May 30, 2023
a412fac
remove _accept_image_latents variable
May 30, 2023
6c44410
add height weight optional arguments to image_processor.preprocess
May 31, 2023
1ad9c35
refactor controlnet img2img
May 31, 2023
7bfeee8
style
May 31, 2023
0d10b3b
fix
May 31, 2023
e0b6d96
refactor controlnet
May 31, 2023
b1353fb
refactor controlnet inpaint (only control_image)
May 31, 2023
4a35707
copy
May 31, 2023
53fb437
style
May 31, 2023
958b527
remove _default_height_width
May 31, 2023
02688a5
fix
May 31, 2023
7394a66
Update src/diffusers/pipelines/controlnet/pipeline_controlnet.py
yiyixuxu May 31, 2023
5a7cae0
pass width and height to controlnet pipelines
May 31, 2023
13671d3
style
May 31, 2023
b0f9e4c
Update src/diffusers/image_processor.py
yiyixuxu May 31, 2023
a24e41e
Update src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
yiyixuxu May 31, 2023
90110ec
Update src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
yiyixuxu May 31, 2023
1c3457e
Update src/diffusers/pipelines/controlnet/pipeline_controlnet.py
yiyixuxu May 31, 2023
76bc363
add doc string, images -> image
May 31, 2023
f8a0705
add types and doc strings for image input arguments that's refactored
May 31, 2023
7af8a1f
alt copy
May 31, 2023
0ca3473
refacotor 4x upscale
Jun 1, 2023
c822a27
refactor latent_upscaler
Jun 1, 2023
ee38476
Apply suggestions from code review
yiyixuxu Jun 1, 2023
2f13a14
add typing
Jun 1, 2023
b02ca59
Revert "refacotor 4x upscale"
Jun 2, 2023
9a5dd1a
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
yiyixuxu Jun 5, 2023
69 changes: 56 additions & 13 deletions src/diffusers/image_processor.py
@@ -30,14 +30,17 @@ class VaeImageProcessor(ConfigMixin):
 
     Args:
         do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from the `preprocess` method.
         vae_scale_factor (`int`, *optional*, defaults to `8`):
             VAE scale factor. If `do_resize` is `True`, the image will be automatically resized to multiples of this
             factor.
         resample (`str`, *optional*, defaults to `lanczos`):
             Resampling filter to use when resizing the image.
         do_normalize (`bool`, *optional*, defaults to `True`):
             Whether to normalize the image to [-1,1].
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
     """
 
     config_name = CONFIG_NAME
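Example (illustrative, not part of the diff; assumes a diffusers version that includes this change). The processor can now be configured with all four documented options, including the new do_convert_rgb:

from diffusers.image_processor import VaeImageProcessor

# All defaults except do_convert_rgb (the option this PR adds),
# enabled here for illustration.
processor = VaeImageProcessor(
    do_resize=True,
    vae_scale_factor=8,
    resample="lanczos",
    do_normalize=True,
    do_convert_rgb=True,
)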
@@ -49,11 +52,12 @@ def __init__(
         vae_scale_factor: int = 8,
         resample: str = "lanczos",
         do_normalize: bool = True,
+        do_convert_rgb: bool = False,
     ):
         super().__init__()
 
     @staticmethod
-    def numpy_to_pil(images):
+    def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
         """
         Convert a numpy image or a batch of images to a PIL image.
         """
@@ -69,7 +73,19 @@ def numpy_to_pil(images):
         return pil_images
 
     @staticmethod
-    def numpy_to_pt(images):
+    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+        """
+        Convert a PIL image or a list of PIL images to numpy arrays.
+        """
+        if not isinstance(images, list):
Review comment (Member): Is there a need to also check if image is of type PIL.Image.Image?

+            images = [images]
+        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
+        images = np.stack(images, axis=0)
+
+        return images

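A hypothetical variant answering the review question above, adding an explicit type check before conversion (illustrative only; the helper name is made up and not part of the PR):

from typing import List, Union

import numpy as np
import PIL.Image

def pil_to_numpy_checked(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
    """Like pil_to_numpy, but rejects non-PIL inputs with a clear error."""
    if not isinstance(images, list):
        images = [images]
    if not all(isinstance(i, PIL.Image.Image) for i in images):
        raise TypeError("Expected PIL.Image.Image inputs")
    images = [np.array(image).astype(np.float32) / 255.0 for image in images]
    return np.stack(images, axis=0)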
+    @staticmethod
+    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
         """
         Convert a numpy image to a pytorch tensor
         """
@@ -80,7 +96,7 @@ def numpy_to_pt(images):
         return images
 
     @staticmethod
-    def pt_to_numpy(images):
+    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
         """
         Convert a pytorch tensor to a numpy image
         """
@@ -101,18 +117,39 @@ def denormalize(images):
"""
         return (images / 2 + 0.5).clamp(0, 1)

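Example (illustrative, not part of the diff): denormalize inverts the [-1, 1] normalization applied by preprocess when do_normalize is True, mapping values back to [0, 1]. Assumes denormalize remains a static method:

import torch

from diffusers.image_processor import VaeImageProcessor

x = torch.tensor([-1.0, 0.0, 1.0])
print(VaeImageProcessor.denormalize(x))  # tensor([0.0000, 0.5000, 1.0000])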
-    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
+    @staticmethod
+    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts an image to RGB format.
+        """
+        image = image.convert("RGB")
+        return image

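A quick sketch of the new helper on an RGBA input (illustrative only):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

rgba = PIL.Image.new("RGBA", (64, 64))
print(VaeImageProcessor.convert_to_rgb(rgba).mode)  # "RGB"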
+    def resize(
+        self,
+        image: PIL.Image.Image,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> PIL.Image.Image:
         """
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`.
         """
-        w, h = images.size
-        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
-        return images
+        if height is None:
+            height = image.height
+        if width is None:
+            width = image.width
+
+        width, height = (
+            x - x % self.config.vae_scale_factor for x in (width, height)
+        )  # resize to integer multiple of vae_scale_factor
+        image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
Review comment (Member) on lines +142 to +145: 👌
+        return image
 
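How the new height/width arguments behave, as a sketch: explicit sizes are rounded down to the nearest multiple of vae_scale_factor, and missing ones fall back to the image's own dimensions (shapes are illustrative):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
image = PIL.Image.new("RGB", (513, 769))

print(processor.resize(image, height=600, width=500).size)  # (496, 600)
print(processor.resize(image).size)                         # (512, 768)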
     def preprocess(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
     ) -> torch.Tensor:
"""
         Preprocess the image input. Accepted formats are PIL images, numpy arrays or pytorch tensors.
@@ -126,10 +163,11 @@ def preprocess(
             )
 
         if isinstance(image[0], PIL.Image.Image):
+            if self.config.do_convert_rgb:
+                image = [self.convert_to_rgb(i) for i in image]
             if self.config.do_resize:
-                image = [self.resize(i) for i in image]
-            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
-            image = np.stack(image, axis=0)  # to np
+                image = [self.resize(i, height, width) for i in image]
+            image = self.pil_to_numpy(image)  # to np
             image = self.numpy_to_pt(image)  # to pt
 
         elif isinstance(image[0], np.ndarray):
@@ -146,7 +184,12 @@

         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-            _, _, height, width = image.shape
+            _, channel, height, width = image.shape
+
+            # don't need any preprocess if the image is latents
+            if channel == 4:
Review comment (Contributor): nice! Good for me!

+                return image
 
         if self.config.do_resize and (
             height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
         ):
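Taken together, preprocess now gives one entry point for PIL, numpy, and torch inputs, and passes 4-channel tensors through untouched as latents. A minimal sketch, assuming the installed diffusers includes this PR:

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8, do_convert_rgb=True)

pil_image = PIL.Image.new("RGB", (512, 512))
np_image = np.zeros((512, 512, 3), dtype=np.float32)
pt_image = torch.zeros(1, 3, 512, 512)

for inp in (pil_image, np_image, pt_image):
    out = processor.preprocess(inp)
    print(type(inp).__name__, "->", tuple(out.shape))  # (1, 3, 512, 512) each time

# Per the early return above, a 4-channel tensor is treated as latents and
# skips resizing and normalization entirely.
latents = torch.randn(1, 4, 64, 64)
assert torch.equal(processor.preprocess(latents), latents)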
(second changed file)
@@ -69,6 +69,11 @@

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
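The replacement the warning points to, sketched (vae_scale_factor=8 is the Stable Diffusion default and an assumption here):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)
image = PIL.Image.new("RGB", (512, 512))

tensor = image_processor.preprocess(image)  # replaces: tensor = preprocess(image)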
@@ -538,21 +543,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
         image = image.to(device=device, dtype=dtype)
 
         batch_size = batch_size * num_images_per_prompt
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
 
-        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+        if image.shape[1] == 4:
+            init_latents = image
 
-        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+        else:
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective"
+                    f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
 
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = self.vae.config.scaling_factor * init_latents
+            init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
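The generator check kept inside the new else branch still requires a generator list to match the effective batch size, i.e. batch_size * num_images_per_prompt. A hedged sketch against the stock img2img pipeline, which received the same change via fix-copies; the model id is illustrative, and the list input also exercises the widened image types below:

import torch
from PIL import Image

from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Four prompts, four init images, num_images_per_prompt=1 by default:
# effective batch size 4, so the generator list must also have length 4.
init_images = [Image.new("RGB", (512, 512)) for _ in range(4)]
generators = [torch.Generator("cuda").manual_seed(seed) for seed in range(4)]
images = pipe(
    ["a fantasy landscape"] * 4,
    image=init_images,
    generator=generators,
).images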
@@ -586,7 +596,14 @@ def __call__(
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -609,9 +626,10 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can also accept image latents as `image`; if passing latents directly, they will not be
+                encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
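And the latents path described above, end to end: a (batch, 4, height/8, width/8) tensor passed as image skips VAE encoding (see prepare_latents earlier). A sketch with an illustrative model id and shapes:

import torch

from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# 4-channel input is treated as latents and is not encoded by the VAE again.
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
result = pipe("a fantasy landscape", image=latents, strength=0.75).images[0]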