@@ -453,13 +453,13 @@ def prepare_extra_step_kwargs(self, generator, eta):
453453 extra_step_kwargs ["generator" ] = generator
454454 return extra_step_kwargs
455455
456- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
457456 def check_inputs (
458457 self ,
459458 prompt ,
460459 height ,
461460 width ,
462461 callback_steps ,
462+ image ,
463463 negative_prompt = None ,
464464 prompt_embeds = None ,
465465 negative_prompt_embeds = None ,
@@ -501,6 +501,17 @@ def check_inputs(
501501 f" { negative_prompt_embeds .shape } ."
502502 )
503503
504+ if isinstance (self .adapter , MultiAdapter ):
505+ if not isinstance (image , list ):
506+ raise ValueError (
507+ "MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`."
508+ )
509+
510+ if len (image ) != len (self .adapter .adapters ):
511+ raise ValueError (
512+ f"MultiAdapter requires passing the same number of images as adapters. Given { len (image )} images and { len (self .adapter .adapters )} adapters."
513+ )
514+
504515 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
505516 def prepare_latents (self , batch_size , num_channels_latents , height , width , dtype , device , generator , latents = None ):
506517 shape = (batch_size , num_channels_latents , height // self .vae_scale_factor , width // self .vae_scale_factor )
@@ -653,17 +664,19 @@ def __call__(
653664
654665 # 1. Check inputs. Raise error if not correct
655666 self .check_inputs (
656- prompt , height , width , callback_steps , negative_prompt , prompt_embeds , negative_prompt_embeds
667+ prompt , height , width , callback_steps , image , negative_prompt , prompt_embeds , negative_prompt_embeds
657668 )
658669
659- is_multi_adapter = isinstance (self .adapter , MultiAdapter )
660- if is_multi_adapter :
661- adapter_input = [_preprocess_adapter_image (img , height , width ).to (device ) for img in image ]
662- n , c , h , w = adapter_input [0 ].shape
663- adapter_input = torch .stack ([x .reshape ([n * c , h , w ]) for x in adapter_input ])
670+ if isinstance (self .adapter , MultiAdapter ):
671+ adapter_input = []
672+
673+ for one_image in image :
674+ one_image = _preprocess_adapter_image (one_image , height , width )
675+ one_image = one_image .to (device = device , dtype = self .adapter .dtype )
676+ adapter_input .append (one_image )
664677 else :
665- adapter_input = _preprocess_adapter_image (image , height , width ). to ( device )
666- adapter_input = adapter_input .to (self .adapter .dtype )
678+ adapter_input = _preprocess_adapter_image (image , height , width )
679+ adapter_input = adapter_input .to (device = device , dtype = self .adapter .dtype )
667680
668681 # 2. Define call parameters
669682 if prompt is not None and isinstance (prompt , str ):
0 commit comments