
Commit d4f0742

Standardize on using image argument in all pipelines (huggingface#1361)
* feat: switch core pipelines to use image arg
* test: update tests for core pipelines
* feat: switch examples to use image arg
* docs: update docs to use image arg
* style: format code using black and doc-builder
* fix: deprecate use of init_image in all pipelines
1 parent 1f517a5 commit d4f0742

9 files changed: +138 −107 lines
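For callers of the image-to-image pipelines, the rename boils down to passing `image=` where `init_image=` was used before. Below is a minimal sketch (the checkpoint id and input file are placeholders, not taken from this commit); the old keyword keeps working through the deprecation shim added here, but warns until it is removed in 0.12.0:

```python
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

# Placeholder checkpoint and input image, purely for illustration.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
init_image = Image.open("sketch.png").convert("RGB").resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

# New, standardized keyword:
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images

# Still accepted for now, but emits a deprecation warning (removal planned for 0.12.0):
# images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images

images[0].save("fantasy_landscape.png")
```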

pipelines/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -126,7 +126,7 @@ init_image = init_image.resize((768, 512))
 
 prompt = "A fantasy landscape, trending on artstation"
 
-images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
+images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
````

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 22 additions & 17 deletions
```diff
@@ -435,26 +435,26 @@ def get_timesteps(self, num_inference_steps, strength, device):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
-        init_image = init_image.to(device=device, dtype=dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
+    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+        image = image.to(device=device, dtype=dtype)
+        init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
             deprecation_message = (
                 f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
-                " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
+                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                 " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
-                " your script to pass as many init images as text prompts to suppress this warning."
+                " your script to pass as many initial images as text prompts to suppress this warning."
             )
-            deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
+            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
             additional_image_per_prompt = batch_size // init_latents.shape[0]
             init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
         elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
             raise ValueError(
-                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
             init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
@@ -472,7 +472,7 @@ def prepare_latents(self, init_image, timestep, batch_size, num_images_per_promp
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        image: Union[torch.FloatTensor, PIL.Image.Image],
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -484,22 +484,23 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
 
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
-            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
                 process.
             strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
-                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
-                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
-                noise will be maximum and the denoising process will run for the full number of iterations specified in
-                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+                be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference. This parameter will be modulated by `strength`.
@@ -540,6 +541,10 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
+        message = "Please use `image` instead of `init_image`."
+        init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+        image = init_image or image
+
         # 1. Check inputs
         self.check_inputs(prompt, strength, callback_steps)
 
@@ -557,8 +562,8 @@ def __call__(
         )
 
         # 4. Preprocess image
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = preprocess(init_image)
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -567,7 +572,7 @@ def __call__(
 
         # 6. Prepare latent variables
         latents = self.prepare_latents(
-            init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
         )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
```
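The backward-compatibility shim added above follows a common pattern: pop the old keyword out of `**kwargs`, warn, and fall back to the new argument. The sketch below is a simplified stand-in to show the mechanics; it is not the actual `diffusers.utils.deprecate` implementation, whose signature is richer:

```python
import warnings


def deprecate_kwarg(old_name, removal_version, message, kwargs):
    # Pop the deprecated keyword (if it was passed) and warn; return its value
    # so the caller can fall back to it.
    value = kwargs.pop(old_name, None)
    if value is not None:
        warnings.warn(
            f"`{old_name}` is deprecated and will be removed in {removal_version}. {message}",
            FutureWarning,
        )
    return value


def call_pipeline(prompt, image=None, **kwargs):  # simplified stand-in for a pipeline __call__
    init_image = deprecate_kwarg("init_image", "0.12.0", "Please use `image` instead of `init_image`.", kwargs)
    image = init_image or image  # prefer the old value when it was explicitly passed
    return prompt, image
```

With this in place, `call_pipeline(prompt, init_image=img)` and `call_pipeline(prompt, image=img)` behave identically, except that the former emits the warning.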

pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -17,7 +17,7 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION
+from ...utils import PIL_INTERPOLATION, deprecate
 
 
 def preprocess(image):
@@ -66,7 +66,7 @@ def __init__(
     @torch.no_grad()
     def __call__(
         self,
-        init_image: Union[torch.Tensor, PIL.Image.Image],
+        image: Union[torch.Tensor, PIL.Image.Image],
         batch_size: Optional[int] = 1,
         num_inference_steps: Optional[int] = 100,
         eta: Optional[float] = 0.0,
@@ -77,7 +77,7 @@ def __call__(
     ) -> Union[Tuple, ImagePipelineOutput]:
         r"""
         Args:
-            init_image (`torch.Tensor` or `PIL.Image.Image`):
+            image (`torch.Tensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
                 process.
             batch_size (`int`, *optional*, defaults to 1):
@@ -102,20 +102,21 @@ def __call__(
             `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
             generated images.
         """
+        message = "Please use `image` instead of `init_image`."
+        init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+        image = init_image or image
 
-        if isinstance(init_image, PIL.Image.Image):
+        if isinstance(image, PIL.Image.Image):
             batch_size = 1
-        elif isinstance(init_image, torch.Tensor):
-            batch_size = init_image.shape[0]
+        elif isinstance(image, torch.Tensor):
+            batch_size = image.shape[0]
         else:
-            raise ValueError(
-                f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}"
-            )
+            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
 
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = preprocess(init_image)
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess(image)
 
-        height, width = init_image.shape[-2:]
+        height, width = image.shape[-2:]
 
         # in_channels should be 6: 3 for latents, 3 for low resolution image
         latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
@@ -128,7 +129,7 @@ def __call__(
         else:
             latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
 
-        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        image = image.to(device=self.device, dtype=latents_dtype)
 
         # set timesteps and move to the correct device
         self.scheduler.set_timesteps(num_inference_steps, device=self.device)
@@ -148,7 +149,7 @@ def __call__(
 
         for t in self.progress_bar(timesteps_tensor):
             # concat latents and low resolution image in the channel dimension.
-            latents_input = torch.cat([latents, init_image], dim=1)
+            latents_input = torch.cat([latents, image], dim=1)
             latents_input = self.scheduler.scale_model_input(latents_input, t)
             # predict the noise residual
             noise_pred = self.unet(latents_input, t).sample
```
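After the rename, the super-resolution pipeline is likewise called with `image=`. A usage sketch (the checkpoint id and input file are assumptions, not part of this commit):

```python
from PIL import Image
from diffusers import LDMSuperResolutionPipeline

# Assumed checkpoint id; substitute whichever LDM super-resolution weights you use.
pipe = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages")

low_res = Image.open("low_res.png").convert("RGB")  # placeholder input image

# `image` replaces the former `init_image` keyword; the old name is deprecated until 0.12.0.
upscaled = pipe(image=low_res, num_inference_steps=100, eta=1.0).images[0]
upscaled.save("upscaled.png")
```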

pipelines/stable_diffusion/README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -138,7 +138,7 @@ prompt = "An astronaut riding an elephant"
 image = pipe(
     prompt=prompt,
     source_prompt=source_prompt,
-    init_image=init_image,
+    image=init_image,
     num_inference_steps=100,
     eta=0.1,
     strength=0.8,
@@ -164,7 +164,7 @@ torch.manual_seed(0)
 image = pipe(
     prompt=prompt,
     source_prompt=source_prompt,
-    init_image=init_image,
+    image=init_image,
     num_inference_steps=100,
     eta=0.1,
     strength=0.85,
```

pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 22 additions & 17 deletions
```diff
@@ -477,26 +477,26 @@ def get_timesteps(self, num_inference_steps, strength, device):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
-        init_image = init_image.to(device=device, dtype=dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
+    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+        image = image.to(device=device, dtype=dtype)
+        init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
             deprecation_message = (
                 f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
-                " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
+                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                 " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
-                " your script to pass as many init images as text prompts to suppress this warning."
+                " your script to pass as many initial images as text prompts to suppress this warning."
             )
-            deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
+            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
             additional_image_per_prompt = batch_size // init_latents.shape[0]
             init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
         elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
             raise ValueError(
-                f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
             init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
@@ -516,7 +516,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]],
         source_prompt: Union[str, List[str]],
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        image: Union[torch.FloatTensor, PIL.Image.Image],
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -528,22 +528,23 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
 
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
-            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
                 process.
             strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
-                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
-                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
-                noise will be maximum and the denoising process will run for the full number of iterations specified in
-                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+                be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference. This parameter will be modulated by `strength`.
@@ -584,6 +585,10 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
         """
+        message = "Please use `image` instead of `init_image`."
+        init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+        image = init_image or image
+
         # 1. Check inputs
         self.check_inputs(prompt, strength, callback_steps)
 
@@ -602,8 +607,8 @@ def __call__(
         )
 
         # 4. Preprocess image
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = preprocess(init_image)
+        if isinstance(image, PIL.Image.Image):
+            image = preprocess(image)
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -612,7 +617,7 @@ def __call__(
 
         # 6. Prepare latent variables
         latents, clean_latents = self.prepare_latents(
-            init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+            image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
         )
         source_latents = latents
 
```
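For CycleDiffusion the call mirrors the README snippets above, with `image=` carrying the source image. A hedged sketch (checkpoint, scheduler setup, source prompt, and input file are placeholders, not taken from this commit):

```python
import torch
from PIL import Image
from diffusers import CycleDiffusionPipeline, DDIMScheduler

# Placeholder checkpoint; CycleDiffusion is typically run with a DDIM scheduler.
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda")

init_image = Image.open("horse.png").convert("RGB").resize((512, 512))  # placeholder input

torch.manual_seed(0)
image = pipe(
    prompt="An astronaut riding an elephant",
    source_prompt="An astronaut riding a horse",  # placeholder source prompt
    image=init_image,  # was `init_image=` before this commit
    num_inference_steps=100,
    eta=0.1,
    strength=0.8,
).images[0]
image.save("edited.png")
```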
