Skip to content

Commit 3105c71

Browse files
[fix] multi t2i adapter set total_downscale_factor (#4621)
* [fix] multi t2i adapter set total_downscale_factor
* move image checks into check_inputs
* remove copied from
1 parent 58f5f74 commit 3105c71

File tree

3 files changed

+301
-27
lines changed

3 files changed

+301
-27
lines changed

src/diffusers/models/adapter.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,31 @@ def __init__(self, adapters: List["T2IAdapter"]):
4141
self.num_adapter = len(adapters)
4242
self.adapters = nn.ModuleList(adapters)
4343

44+
if len(adapters) == 0:
45+
raise ValueError("Expecting at least one adapter")
46+
47+
if len(adapters) == 1:
48+
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
49+
50+
# The outputs from each adapter are added together with a weight
51+
# This means that the change in dimensions from downsampling must
52+
# be the same for all adapters. Inductively, it also means the total
53+
# downscale factor must also be the same for all adapters.
54+
55+
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
56+
57+
for idx in range(1, len(adapters)):
58+
adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor
59+
60+
if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
61+
raise ValueError(
62+
f"Expecting all adapters to have the same total_downscale_factor, "
63+
f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
64+
f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
65+
)
66+
67+
self.total_downscale_factor = adapters[0].total_downscale_factor
68+
4469
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
4570
r"""
4671
Args:
@@ -56,14 +81,8 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non
5681
else:
5782
adapter_weights = torch.tensor(adapter_weights)
5883

59-
if xs.shape[1] % self.num_adapter != 0:
60-
raise ValueError(
61-
f"Expecting multi-adapter's input have number of channel that can be evenly divisible "
62-
f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
63-
)
64-
x_list = torch.chunk(xs, self.num_adapter, dim=1)
6584
accume_state = None
66-
for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
85+
for x, w, adapter in zip(xs, adapter_weights, self.adapters):
6786
features = adapter(x)
6887
if accume_state is None:
6988
accume_state = features

src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,13 @@ def prepare_extra_step_kwargs(self, generator, eta):
453453
extra_step_kwargs["generator"] = generator
454454
return extra_step_kwargs
455455

456-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
457456
def check_inputs(
458457
self,
459458
prompt,
460459
height,
461460
width,
462461
callback_steps,
462+
image,
463463
negative_prompt=None,
464464
prompt_embeds=None,
465465
negative_prompt_embeds=None,
@@ -501,6 +501,17 @@ def check_inputs(
501501
f" {negative_prompt_embeds.shape}."
502502
)
503503

504+
if isinstance(self.adapter, MultiAdapter):
505+
if not isinstance(image, list):
506+
raise ValueError(
507+
"MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`."
508+
)
509+
510+
if len(image) != len(self.adapter.adapters):
511+
raise ValueError(
512+
f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters."
513+
)
514+
504515
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
505516
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
506517
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -653,17 +664,19 @@ def __call__(
653664

654665
# 1. Check inputs. Raise error if not correct
655666
self.check_inputs(
656-
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
667+
prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
657668
)
658669

659-
is_multi_adapter = isinstance(self.adapter, MultiAdapter)
660-
if is_multi_adapter:
661-
adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image]
662-
n, c, h, w = adapter_input[0].shape
663-
adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input])
670+
if isinstance(self.adapter, MultiAdapter):
671+
adapter_input = []
672+
673+
for one_image in image:
674+
one_image = _preprocess_adapter_image(one_image, height, width)
675+
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
676+
adapter_input.append(one_image)
664677
else:
665-
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
666-
adapter_input = adapter_input.to(self.adapter.dtype)
678+
adapter_input = _preprocess_adapter_image(image, height, width)
679+
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
667680

668681
# 2. Define call parameters
669682
if prompt is not None and isinstance(prompt, str):

0 commit comments

Comments
 (0)