Skip to content

Commit d01a0d5

Browse files
committed
feat: switch examples to use image arg
1 parent 64e48b7 commit d01a0d5

File tree

4 files changed

+63 additions, -63 deletions (lines changed)

4 files changed

+63 additions, -63 deletions (lines changed)

examples/community/imagic_stable_diffusion.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,7 @@ def disable_attention_slicing(self):
133133
def train(
134134
self,
135135
prompt: Union[str, List[str]],
136-
init_image: Union[torch.FloatTensor, PIL.Image.Image],
136+
image: Union[torch.FloatTensor, PIL.Image.Image],
137137
height: Optional[int] = 512,
138138
width: Optional[int] = 512,
139139
generator: Optional[torch.Generator] = None,
@@ -241,14 +241,14 @@ def train(
241241
lr=embedding_learning_rate,
242242
)
243243

244-
if isinstance(init_image, PIL.Image.Image):
245-
init_image = preprocess(init_image)
244+
if isinstance(image, PIL.Image.Image):
245+
image = preprocess(image)
246246

247247
latents_dtype = text_embeddings.dtype
248-
init_image = init_image.to(device=self.device, dtype=latents_dtype)
249-
init_latent_image_dist = self.vae.encode(init_image).latent_dist
250-
init_image_latents = init_latent_image_dist.sample(generator=generator)
251-
init_image_latents = 0.18215 * init_image_latents
248+
image = image.to(device=self.device, dtype=latents_dtype)
249+
init_latent_image_dist = self.vae.encode(image).latent_dist
250+
image_latents = init_latent_image_dist.sample(generator=generator)
251+
image_latents = 0.18215 * image_latents
252252

253253
progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
254254
progress_bar.set_description("Steps")
@@ -259,12 +259,12 @@ def train(
259259
for _ in range(text_embedding_optimization_steps):
260260
with accelerator.accumulate(text_embeddings):
261261
# Sample noise that we'll add to the latents
262-
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
263-
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
262+
noise = torch.randn(image_latents.shape).to(image_latents.device)
263+
timesteps = torch.randint(1000, (1,), device=image_latents.device)
264264

265265
# Add noise to the latents according to the noise magnitude at each timestep
266266
# (this is the forward diffusion process)
267-
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
267+
noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
268268

269269
# Predict the noise residual
270270
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
@@ -301,12 +301,12 @@ def train(
301301
for _ in range(model_fine_tuning_optimization_steps):
302302
with accelerator.accumulate(self.unet.parameters()):
303303
# Sample noise that we'll add to the latents
304-
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
305-
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
304+
noise = torch.randn(image_latents.shape).to(image_latents.device)
305+
timesteps = torch.randint(1000, (1,), device=image_latents.device)
306306

307307
# Add noise to the latents according to the noise magnitude at each timestep
308308
# (this is the forward diffusion process)
309-
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
309+
noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
310310

311311
# Predict the noise residual
312312
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

examples/community/lpw_stable_diffusion.py

Lines changed: 23 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -555,7 +555,7 @@ def __call__(
555555
self,
556556
prompt: Union[str, List[str]],
557557
negative_prompt: Optional[Union[str, List[str]]] = None,
558-
init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
558+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
559559
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
560560
height: int = 512,
561561
width: int = 512,
@@ -583,11 +583,11 @@ def __call__(
583583
negative_prompt (`str` or `List[str]`, *optional*):
584584
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
585585
if `guidance_scale` is less than `1`).
586-
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
586+
image (`torch.FloatTensor` or `PIL.Image.Image`):
587587
`Image`, or tensor representing an image batch, that will be used as the starting point for the
588588
process.
589589
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
590-
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
590+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
591591
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
592592
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
593593
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -605,11 +605,11 @@ def __call__(
605605
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
606606
usually at the expense of lower image quality.
607607
strength (`float`, *optional*, defaults to 0.8):
608-
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
609-
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
608+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
609+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
610610
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
611611
noise will be maximum and the denoising process will run for the full number of iterations specified in
612-
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
612+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
613613
num_images_per_prompt (`int`, *optional*, defaults to 1):
614614
The number of images to generate per prompt.
615615
eta (`float`, *optional*, defaults to 0.0):
@@ -714,7 +714,7 @@ def __call__(
714714
mask = None
715715
noise = None
716716

717-
if init_image is None:
717+
if image is None:
718718
# get the initial random noise unless the user supplied it
719719

720720
# Unlike in other pipelines, latents need to be generated in the target device
@@ -753,11 +753,11 @@ def __call__(
753753
# scale the initial noise by the standard deviation required by the scheduler
754754
latents = latents * self.scheduler.init_noise_sigma
755755
else:
756-
if isinstance(init_image, PIL.Image.Image):
757-
init_image = preprocess_image(init_image)
756+
if isinstance(image, PIL.Image.Image):
757+
image = preprocess_image(image)
758758
# encode the init image into latents and scale the latents
759-
init_image = init_image.to(device=self.device, dtype=latents_dtype)
760-
init_latent_dist = self.vae.encode(init_image).latent_dist
759+
image = image.to(device=self.device, dtype=latents_dtype)
760+
init_latent_dist = self.vae.encode(image).latent_dist
761761
init_latents = init_latent_dist.sample(generator=generator)
762762
init_latents = 0.18215 * init_latents
763763
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
@@ -772,7 +772,7 @@ def __call__(
772772

773773
# check sizes
774774
if not mask.shape == init_latents.shape:
775-
raise ValueError("The mask and init_image should be the same size!")
775+
raise ValueError("The mask and image should be the same size!")
776776

777777
# get the original timestep using init_timestep
778778
offset = self.scheduler.config.get("steps_offset", 0)
@@ -961,7 +961,7 @@ def text2img(
961961

962962
def img2img(
963963
self,
964-
init_image: Union[torch.FloatTensor, PIL.Image.Image],
964+
image: Union[torch.FloatTensor, PIL.Image.Image],
965965
prompt: Union[str, List[str]],
966966
negative_prompt: Optional[Union[str, List[str]]] = None,
967967
strength: float = 0.8,
@@ -980,7 +980,7 @@ def img2img(
980980
r"""
981981
Function for image-to-image generation.
982982
Args:
983-
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
983+
image (`torch.FloatTensor` or `PIL.Image.Image`):
984984
`Image`, or tensor representing an image batch, that will be used as the starting point for the
985985
process.
986986
prompt (`str` or `List[str]`):
@@ -989,11 +989,11 @@ def img2img(
989989
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
990990
if `guidance_scale` is less than `1`).
991991
strength (`float`, *optional*, defaults to 0.8):
992-
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
993-
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
992+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
993+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
994994
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
995995
noise will be maximum and the denoising process will run for the full number of iterations specified in
996-
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
996+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
997997
num_inference_steps (`int`, *optional*, defaults to 50):
998998
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
999999
expense of slower inference. This parameter will be modulated by `strength`.
@@ -1035,7 +1035,7 @@ def img2img(
10351035
return self.__call__(
10361036
prompt=prompt,
10371037
negative_prompt=negative_prompt,
1038-
init_image=init_image,
1038+
image=image,
10391039
num_inference_steps=num_inference_steps,
10401040
guidance_scale=guidance_scale,
10411041
strength=strength,
@@ -1052,7 +1052,7 @@ def img2img(
10521052

10531053
def inpaint(
10541054
self,
1055-
init_image: Union[torch.FloatTensor, PIL.Image.Image],
1055+
image: Union[torch.FloatTensor, PIL.Image.Image],
10561056
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
10571057
prompt: Union[str, List[str]],
10581058
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -1072,11 +1072,11 @@ def inpaint(
10721072
r"""
10731073
Function for inpaint.
10741074
Args:
1075-
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
1075+
image (`torch.FloatTensor` or `PIL.Image.Image`):
10761076
`Image`, or tensor representing an image batch, that will be used as the starting point for the
10771077
process. This is the image whose masked region will be inpainted.
10781078
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1079-
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
1079+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
10801080
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
10811081
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
10821082
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -1088,7 +1088,7 @@ def inpaint(
10881088
strength (`float`, *optional*, defaults to 0.8):
10891089
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
10901090
is 1, the denoising process will be run on the masked area for the full number of iterations specified
1091-
in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
1091+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
10921092
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
10931093
num_inference_steps (`int`, *optional*, defaults to 50):
10941094
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
@@ -1131,7 +1131,7 @@ def inpaint(
11311131
return self.__call__(
11321132
prompt=prompt,
11331133
negative_prompt=negative_prompt,
1134-
init_image=init_image,
1134+
image=image,
11351135
mask_image=mask_image,
11361136
num_inference_steps=num_inference_steps,
11371137
guidance_scale=guidance_scale,

0 commit comments

Comments
 (0)