diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
index a65a3873b130..b9ffc64d912f 100644
--- a/src/diffusers/models/adapter.py
+++ b/src/diffusers/models/adapter.py
@@ -41,6 +41,31 @@ def __init__(self, adapters: List["T2IAdapter"]):
         self.num_adapter = len(adapters)
         self.adapters = nn.ModuleList(adapters)
 
+        if len(adapters) == 0:
+            raise ValueError("Expecting at least one adapter")
+
+        if len(adapters) == 1:
+            raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
+
+        # The outputs from each adapter are added together with a weight.
+        # This means that the change in dimensions from downsampling must
+        # be the same for all adapters. Inductively, it also means the total
+        # downscale factor must be the same for all adapters.
+
+        first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
+
+        for idx in range(1, len(adapters)):
+            adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor
+
+            if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
+                raise ValueError(
+                    f"Expecting all adapters to have the same total_downscale_factor, "
+                    f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
+                    f"adapters[{idx}].total_downscale_factor={adapter_idx_total_downscale_factor}"
+                )
+
+        self.total_downscale_factor = adapters[0].total_downscale_factor
+
     def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
         r"""
         Args:
@@ -56,14 +81,8 @@ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = Non
         else:
             adapter_weights = torch.tensor(adapter_weights)
 
-        if xs.shape[1] % self.num_adapter != 0:
-            raise ValueError(
-                f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
-                f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
-            )
-        x_list = torch.chunk(xs, self.num_adapter, dim=1)
         accume_state = None
-        for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
+        for x, w, adapter in zip(xs, adapter_weights, self.adapters):
            features = adapter(x)
            if accume_state is None:
                accume_state = features
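
For context (not part of the diff): a minimal sketch of how the reworked `MultiAdapter` is meant to be driven. `forward` now takes a list with one conditioning tensor per adapter instead of a single channel-stacked tensor, and the constructor rejects adapters whose total downscale factors differ. The adapter kwargs mirror the dummy test fixtures below; the input shapes are illustrative assumptions.

import torch

from diffusers import MultiAdapter, T2IAdapter

# Two small adapters with identical total_downscale_factor, mirroring the
# dummy fixtures used in the tests; MultiAdapter now validates this invariant.
adapters = [
    T2IAdapter(in_channels=3, channels=[32, 64], num_res_blocks=2, downscale_factor=2, adapter_type="full_adapter")
    for _ in range(2)
]
multi_adapter = MultiAdapter(adapters)

# One conditioning tensor per adapter (no more torch.chunk along channels).
xs = [torch.randn(1, 3, 64, 64) for _ in range(2)]
features = multi_adapter(xs, adapter_weights=[0.5, 0.5])  # weighted sum of per-adapter feature maps
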
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 4aa911198a2e..1ee6f9296d5a 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -453,13 +453,13 @@ def prepare_extra_step_kwargs(self, generator, eta):
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
         height,
         width,
         callback_steps,
+        image,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -501,6 +501,17 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
+        if isinstance(self.adapter, MultiAdapter):
+            if not isinstance(image, list):
+                raise ValueError(
+                    "MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`."
+                )
+
+            if len(image) != len(self.adapter.adapters):
+                raise ValueError(
+                    f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters."
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -653,17 +664,19 @@ def __call__(
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
         )
 
-        is_multi_adapter = isinstance(self.adapter, MultiAdapter)
-        if is_multi_adapter:
-            adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image]
-            n, c, h, w = adapter_input[0].shape
-            adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input])
+        if isinstance(self.adapter, MultiAdapter):
+            adapter_input = []
+
+            for one_image in image:
+                one_image = _preprocess_adapter_image(one_image, height, width)
+                one_image = one_image.to(device=device, dtype=self.adapter.dtype)
+                adapter_input.append(one_image)
         else:
-            adapter_input = _preprocess_adapter_image(image, height, width).to(device)
-            adapter_input = adapter_input.to(self.adapter.dtype)
+            adapter_input = _preprocess_adapter_image(image, height, width)
+            adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
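
Similarly (again not part of the diff), a sketch of the calling convention the updated `check_inputs` enforces: with a `MultiAdapter`, `image` must be a list containing exactly one conditioning image per adapter, in adapter order. The checkpoint names and image URLs below are assumptions chosen for illustration.

from diffusers import MultiAdapter, StableDiffusionAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Hypothetical conditioning images (URLs are placeholders).
keypose_image = load_image("https://example.com/keypose.png")
depth_image = load_image("https://example.com/depth.png")

adapter = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
    ]
)
pipe = StableDiffusionAdapterPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", adapter=adapter)

# `image` is a list with one entry per adapter; passing a bare image (or a list
# of the wrong length) now raises a ValueError up front in check_inputs.
result = pipe("a person skiing", image=[keypose_image, depth_image]).images[0]
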
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
index 0c1dd1cfe87b..a4f522062e34 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
@@ -21,19 +21,21 @@
 import torch
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
+import diffusers
 from diffusers import (
     AutoencoderKL,
+    MultiAdapter,
     PNDMScheduler,
     StableDiffusionAdapterPipeline,
     T2IAdapter,
     UNet2DConditionModel,
 )
-from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
+from diffusers.utils import floats_tensor, load_image, load_numpy, logging, slow, torch_device
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
 
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
 
 enable_full_determinism()
@@ -82,13 +84,38 @@ def get_dummy_components(self, adapter_type):
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
         torch.manual_seed(0)
-        adapter = T2IAdapter(
-            in_channels=3,
-            channels=[32, 64],
-            num_res_blocks=2,
-            downscale_factor=2,
-            adapter_type=adapter_type,
-        )
+
+        if adapter_type == "full_adapter" or adapter_type == "light_adapter":
+            adapter = T2IAdapter(
+                in_channels=3,
+                channels=[32, 64],
+                num_res_blocks=2,
+                downscale_factor=2,
+                adapter_type=adapter_type,
+            )
+        elif adapter_type == "multi_adapter":
+            adapter = MultiAdapter(
+                [
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=2,
+                        adapter_type="full_adapter",
+                    ),
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=2,
+                        adapter_type="full_adapter",
+                    ),
+                ]
+            )
+        else:
+            raise ValueError(
+                f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter', 'light_adapter', or 'multi_adapter'"
+            )
 
         components = {
             "adapter": adapter,
@@ -102,8 +129,12 @@ def get_dummy_components(self, adapter_type):
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0):
-        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+    def get_dummy_inputs(self, device, seed=0, num_images=1):
+        if num_images == 1:
+            image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+        else:
+            image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
+
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
@@ -172,6 +203,217 @@ def test_stable_diffusion_adapter_default_case(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
 
 
+class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
+    def get_dummy_components(self):
+        return super().get_dummy_components("multi_adapter")
+
+    def get_dummy_inputs(self, device, seed=0):
+        return super().get_dummy_inputs(device, seed, num_images=2)
+
+    def test_stable_diffusion_adapter_default_case(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionAdapterPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.4902, 0.5539, 0.4317, 0.4682, 0.6190, 0.4351, 0.5018, 0.5046, 0.4772])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
+
+    def test_inference_batch_consistent(
+        self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
+    ):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # batchify inputs
+        for batch_size in batch_sizes:
+            batched_inputs = {}
+            for name, value in inputs.items():
+                if name in self.batch_params:
+                    # prompt is string
+                    if name == "prompt":
+                        len_prompt = len(value)
+                        # make unequal batch sizes
+                        batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+
+                        # make last batch super long
+                        batched_inputs[name][-1] = 100 * "very long"
+                    elif name == "image":
+                        batched_images = []
+
+                        for image in value:
+                            batched_images.append(batch_size * [image])
+
+                        batched_inputs[name] = batched_images
+                    else:
+                        batched_inputs[name] = batch_size * [value]
+
+                elif name == "batch_size":
+                    batched_inputs[name] = batch_size
+                else:
+                    batched_inputs[name] = value
+
+            for arg in additional_params_copy_to_batched_inputs:
+                batched_inputs[arg] = inputs[arg]
+
+            batched_inputs["output_type"] = "np"
+
+            if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
+                batched_inputs.pop("output_type")
+
+            output = pipe(**batched_inputs)
+
+            assert len(output[0]) == batch_size
+
+            batched_inputs["output_type"] = "np"
+
+            if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
+                batched_inputs.pop("output_type")
+
+            output = pipe(**batched_inputs)[0]
+
+            assert output.shape[0] == batch_size
+
+        logger.setLevel(level=diffusers.logging.WARNING)
+
+    def test_num_images_per_prompt(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        batch_sizes = [1, 2]
+        num_images_per_prompts = [1, 2]
+
+        for batch_size in batch_sizes:
+            for num_images_per_prompt in num_images_per_prompts:
+                inputs = self.get_dummy_inputs(torch_device)
+
+                for key in inputs.keys():
+                    if key in self.batch_params:
+                        if key == "image":
+                            batched_images = []
+
+                            for image in inputs[key]:
+                                batched_images.append(batch_size * [image])
+
+                            inputs[key] = batched_images
+                        else:
+                            inputs[key] = batch_size * [inputs[key]]
+
+                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
+
+                assert images.shape[0] == batch_size * num_images_per_prompt
+
+    def test_inference_batch_single_identical(
+        self,
+        batch_size=3,
+        test_max_difference=None,
+        test_mean_pixel_difference=None,
+        relax_max_difference=False,
+        expected_max_diff=2e-3,
+        additional_params_copy_to_batched_inputs=["num_inference_steps"],
+    ):
+        if test_max_difference is None:
+            # TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
+            # make sure that batched and non-batched is identical
+            test_max_difference = torch_device != "mps"
+
+        if test_mean_pixel_difference is None:
+            # TODO same as above
+            test_mean_pixel_difference = torch_device != "mps"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+
+        logger = logging.get_logger(pipe.__module__)
+        logger.setLevel(level=diffusers.logging.FATAL)
+
+        # batchify inputs
+        batched_inputs = {}
+        batch_size = batch_size
+        for name, value in inputs.items():
+            if name in self.batch_params:
+                # prompt is string
+                if name == "prompt":
+                    len_prompt = len(value)
+                    # make unequal batch sizes
+                    batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+
+                    # make last batch super long
+                    batched_inputs[name][-1] = 100 * "very long"
+                elif name == "image":
+                    batched_images = []
+
+                    for image in value:
+                        batched_images.append(batch_size * [image])
+
+                    batched_inputs[name] = batched_images
+                else:
+                    batched_inputs[name] = batch_size * [value]
+            elif name == "batch_size":
+                batched_inputs[name] = batch_size
+            elif name == "generator":
+                batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
+            else:
+                batched_inputs[name] = value
+
+        for arg in additional_params_copy_to_batched_inputs:
+            batched_inputs[arg] = inputs[arg]
+
+        if self.pipeline_class.__name__ != "DanceDiffusionPipeline":
+            batched_inputs["output_type"] = "np"
+
+        output_batch = pipe(**batched_inputs)
+        assert output_batch[0].shape[0] == batch_size
+
+        inputs["generator"] = self.get_generator(0)
+
+        output = pipe(**inputs)
+
+        logger.setLevel(level=diffusers.logging.WARNING)
+        if test_max_difference:
+            if relax_max_difference:
+                # Taking the median of the largest differences
+                # is resilient to outliers
+                diff = np.abs(output_batch[0][0] - output[0][0])
+                diff = diff.flatten()
+                diff.sort()
+                max_diff = np.median(diff[-5:])
+            else:
+                max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
+            assert max_diff < expected_max_diff
+
+        if test_mean_pixel_difference:
+            assert_mean_pixel_difference(output_batch[0][0], output[0][0])
+
+    # We do not support saving pipelines with multiple adapters. Each adapter
+    # should be saved as its own independent pipeline.
+
+    def test_save_load_local(self):
+        ...
+
+    def test_save_load_optional_components(self):
+        ...
+
+
 @slow
 @require_torch_gpu
 class StableDiffusionAdapterPipelineSlowTests(unittest.TestCase):