Skip to content

Commit 2db180d

Browse files
authored
Make img2img strength 1 behave the same as txt2img (#2895)
* Fix img2img and inpainting code so a strength of 1 behaves the same as txt2img. * Make generated images identical to their txt2img counterparts when strength is 1.
1 parent d232a43 commit 2db180d

File tree

4 files changed

+26
-12
lines changed

4 files changed

+26
-12
lines changed

invokeai/backend/generator/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def generate(
9999
h_symmetry_time_pct=h_symmetry_time_pct,
100100
v_symmetry_time_pct=v_symmetry_time_pct,
101101
attention_maps_callback=attention_maps_callback,
102+
seed=seed,
102103
**kwargs,
103104
)
104105
results = []
@@ -289,9 +290,7 @@ def generate_initial_noise(self, seed, width, height):
289290
if self.variation_amount > 0:
290291
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
291292
seed = random.randrange(0, np.iinfo(np.uint32).max)
292-
return (seed, initial_noise)
293-
else:
294-
return (seed, None)
293+
return (seed, initial_noise)
295294

296295
# returns a tensor filled with random numbers from a normal distribution
297296
def get_noise(self, width, height):

invokeai/backend/generator/img2img.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""
22
invokeai.backend.generator.img2img descends from .generator
33
"""
4+
from typing import Optional
45

56
import torch
7+
from accelerate.utils import set_seed
68
from diffusers import logging
79

810
from ..stable_diffusion import (
@@ -35,6 +37,7 @@ def get_make_image(
3537
h_symmetry_time_pct=None,
3638
v_symmetry_time_pct=None,
3739
attention_maps_callback=None,
40+
seed=None,
3841
**kwargs,
3942
):
4043
"""
@@ -65,6 +68,7 @@ def make_image(x_T):
6568
# FIXME: use x_T for initial seeded noise
6669
# We're not at the moment because the pipeline automatically resizes init_image if
6770
# necessary, which the x_T input might not match.
71+
# In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
6872
logging.set_verbosity_error() # quench safety check warnings
6973
pipeline_output = pipeline.img2img_from_embeddings(
7074
init_image,
@@ -73,6 +77,7 @@ def make_image(x_T):
7377
conditioning_data,
7478
noise_func=self.get_noise_like,
7579
callback=step_callback,
80+
seed=seed
7681
)
7782
if (
7883
pipeline_output.attention_map_saver is not None
@@ -83,7 +88,9 @@ def make_image(x_T):
8388

8489
return make_image
8590

86-
def get_noise_like(self, like: torch.Tensor):
91+
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
92+
if seed is not None:
93+
set_seed(seed)
8794
device = like.device
8895
if device.type == "mps":
8996
x = torch.randn_like(like, device="cpu").to(device)

invokeai/backend/generator/inpaint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def get_make_image(
223223
inpaint_height=None,
224224
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
225225
attention_maps_callback=None,
226+
seed=None,
226227
**kwargs,
227228
):
228229
"""
@@ -319,6 +320,7 @@ def make_image(x_T):
319320
conditioning_data=conditioning_data,
320321
noise_func=self.get_noise_like,
321322
callback=step_callback,
323+
seed=seed
322324
)
323325

324326
if (

invokeai/backend/stable_diffusion/diffusers_pipeline.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,7 @@ def img2img_from_embeddings(
690690
callback: Callable[[PipelineIntermediateState], None] = None,
691691
run_id=None,
692692
noise_func=None,
693+
seed=None,
693694
) -> InvokeAIStableDiffusionPipelineOutput:
694695
if isinstance(init_image, PIL.Image.Image):
695696
init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
@@ -703,7 +704,7 @@ def img2img_from_embeddings(
703704
device=self._model_group.device_for(self.unet),
704705
dtype=self.unet.dtype,
705706
)
706-
noise = noise_func(initial_latents)
707+
noise = noise_func(initial_latents, seed)
707708

708709
return self.img2img_from_latents_and_embeddings(
709710
initial_latents,
@@ -731,9 +732,11 @@ def img2img_from_latents_and_embeddings(
731732
device=self._model_group.device_for(self.unet),
732733
)
733734
result_latents, result_attention_maps = self.latents_from_embeddings(
734-
initial_latents,
735-
num_inference_steps,
736-
conditioning_data,
735+
latents=initial_latents if strength < 1.0 else torch.zeros_like(
736+
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
737+
),
738+
num_inference_steps=num_inference_steps,
739+
conditioning_data=conditioning_data,
737740
timesteps=timesteps,
738741
noise=noise,
739742
run_id=run_id,
@@ -779,6 +782,7 @@ def inpaint_from_embeddings(
779782
callback: Callable[[PipelineIntermediateState], None] = None,
780783
run_id=None,
781784
noise_func=None,
785+
seed=None,
782786
) -> InvokeAIStableDiffusionPipelineOutput:
783787
device = self._model_group.device_for(self.unet)
784788
latents_dtype = self.unet.dtype
@@ -802,7 +806,7 @@ def inpaint_from_embeddings(
802806
init_image_latents = self.non_noised_latents_from_image(
803807
init_image, device=device, dtype=latents_dtype
804808
)
805-
noise = noise_func(init_image_latents)
809+
noise = noise_func(init_image_latents, seed)
806810

807811
if mask.dim() == 3:
808812
mask = mask.unsqueeze(0)
@@ -831,9 +835,11 @@ def inpaint_from_embeddings(
831835

832836
try:
833837
result_latents, result_attention_maps = self.latents_from_embeddings(
834-
init_image_latents,
835-
num_inference_steps,
836-
conditioning_data,
838+
latents=init_image_latents if strength < 1.0 else torch.zeros_like(
839+
init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
840+
),
841+
num_inference_steps=num_inference_steps,
842+
conditioning_data=conditioning_data,
837843
noise=noise,
838844
timesteps=timesteps,
839845
additional_guidance=guidance,

0 commit comments

Comments
 (0)