Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions src/diffusers/models/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@ def __init__(self, adapters: List["T2IAdapter"]):
self.num_adapter = len(adapters)
self.adapters = nn.ModuleList(adapters)

if len(adapters) == 0:
raise ValueError("Expecting at least one adapter")

if len(adapters) == 1:
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

# The outputs from each adapter are added together with a weight
# This means that the change in dimensions from downsampling must
# be the same for all adapters. Inductively, it also means the total
# downscale factor must also be the same for all adapters.

first_adapter_total_downscale_factor = adapters[0].total_downscale_factor

for idx in range(1, len(adapters)):
adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor

if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
raise ValueError(
f"Expecting all adapters to have the same total_downscale_factor, "
f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
)

self.total_downscale_factor = adapters[0].total_downscale_factor

def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
Args:
Expand All @@ -56,14 +81,8 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non
else:
adapter_weights = torch.tensor(adapter_weights)

if xs.shape[1] % self.num_adapter != 0:
raise ValueError(
f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
)
x_list = torch.chunk(xs, self.num_adapter, dim=1)
Comment on lines -59 to -64
Copy link
Contributor Author

@williamberman williamberman Aug 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously the pipeline was squashing all the different images for the different adapters into one tensor and then re-splitting here. Sorry, I should have caught this in code review.

accume_state = None
for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
for x, w, adapter in zip(xs, adapter_weights, self.adapters):
features = adapter(x)
if accume_state is None:
accume_state = features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,13 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
image,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
Expand Down Expand Up @@ -501,6 +501,17 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

if isinstance(self.adapter, MultiAdapter):
if not isinstance(image, list):
raise ValueError(
"MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`."
)

if len(image) != len(self.adapter.adapters):
raise ValueError(
f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters."
)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
Expand Down Expand Up @@ -653,17 +664,19 @@ def __call__(

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
)

is_multi_adapter = isinstance(self.adapter, MultiAdapter)
if is_multi_adapter:
adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image]
n, c, h, w = adapter_input[0].shape
adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input])
if isinstance(self.adapter, MultiAdapter):
adapter_input = []

for one_image in image:
one_image = _preprocess_adapter_image(one_image, height, width)
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
adapter_input.append(one_image)
else:
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
adapter_input = adapter_input.to(self.adapter.dtype)
adapter_input = _preprocess_adapter_image(image, height, width)
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down
Loading