Attend-and-Excite pipeline cannot generate non-square images #2603

@JanSoltysik


Describe the bug

While using the StableDiffusionAttendAndExcitePipeline class, I was unable to generate non-square images. Switching to the older 2-1 weights did not resolve the problem, and the solution proposed in the PR did not, as far as I can tell, help with different values of height and width either.
If this is a limitation of the method itself, then sorry for the misunderstanding.

Reproduction

I've built diffusers from source and executed the following snippet:

import torch

import diffusers
model_id = "stabilityai/stable-diffusion-2-1-base"
ae_pipe = diffusers.StableDiffusionAttendAndExcitePipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.float16
).to("cuda")
generator = torch.Generator().manual_seed(0)

images = ae_pipe(
    prompt="a photo of lion with snickers",
    guidance_scale=7.5,
    generator=generator,
    width=1024,
    height=512,
    num_inference_steps=50,
    token_indices=[4, 6],
).images
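
For comparison, and continuing from the snippet above, the same call with a square resolution completes without error as far as I can tell, which points at the non-square width/height combination as the trigger:

images = ae_pipe(
    prompt="a photo of lion with snickers",
    guidance_scale=7.5,
    generator=generator,
    width=512,   # square resolution: no error
    height=512,
    num_inference_steps=50,
    token_indices=[4, 6],
).images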

Logs

The above generation produces the following traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [10], line 3
      1 generator = torch.Generator().manual_seed(0)
----> 3 images = ae_pipe(
      4         prompt="a photo of lion with snickers",
      5         guidance_scale=7.5,
      6         generator=generator,
      7         width=1024,
      8         height=512,
      9         num_inference_steps=50,
     10         token_indices=[4, 6]
     11 ).images

File ~/stable_diff/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/stable_diff/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py:907, in StableDiffusionAttendAndExcitePipeline.__call__(self, prompt, token_indices, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, output_type, return_dict, callback, callback_steps, cross_attention_kwargs, max_iter_to_alter, thresholds, scale_factor, attn_res)
    904 self.unet.zero_grad()
    906 # Get max activation value for each subject token
--> 907 max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
    908     indices=index,
    909 )
    911 loss = self._compute_loss(max_attention_per_index=max_attention_per_index)
    913 # If this is an iterative refinement step, verify we have reached the desired threshold for all

File ~/stable_diff/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py:591, in StableDiffusionAttendAndExcitePipeline._aggregate_and_get_max_attention_per_token(self, indices)
    586 def _aggregate_and_get_max_attention_per_token(
    587     self,
    588     indices: List[int],
    589 ):
    590     """Aggregates the attention for each token and computes the max activation value for each token to alter."""
--> 591     attention_maps = self.attention_store.aggregate_attention(
    592         from_where=("up", "down", "mid"),
    593     )
    594     max_attention_per_index = self._compute_max_attention_per_index(
    595         attention_maps=attention_maps,
    596         indices=indices,
    597     )
    598     return max_attention_per_index

File ~/stable_diff/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py:102, in AttentionStore.aggregate_attention(self, from_where)
    100         cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
    101         out.append(cross_maps)
--> 102 out = torch.cat(out, dim=0)
    103 out = out.sum(0) / out.shape[0]
    104 return out

RuntimeError: torch.cat(): expected a non-empty list of Tensors
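
Looking at aggregate_attention in the traceback, every stored map is reshaped to (-1, attn_res, attn_res, ...), i.e. a square grid, and apparently no map matches that square shape when width=1024 and height=512, so out stays empty and torch.cat raises. Below is a minimal sketch of the kind of change I would expect to fix it; the tuple-valued attn_res and the helper name are my own illustration, not the actual diffusers code:

import numpy as np
import torch

# Illustration only: track a (height, width) attention resolution instead of a
# single square attn_res. The factor 32 is the downsampling between pixel
# space and the cross-attention maps the pipeline aggregates.
height, width = 512, 1024
attn_res = (int(np.ceil(height / 32)), int(np.ceil(width / 32)))  # -> (16, 32)

def aggregate_attention_nonsquare(stored_maps, attn_res):
    # Keep only maps whose token count matches the non-square resolution,
    # then reshape with separate height and width dimensions.
    out = []
    for item in stored_maps:
        if item.shape[1] == attn_res[0] * attn_res[1]:
            out.append(item.reshape(-1, attn_res[0], attn_res[1], item.shape[-1]))
    out = torch.cat(out, dim=0)  # no longer empty when width != height
    return out.sum(0) / out.shape[0]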


System Info

- `diffusers` version: 0.15.0.dev0
- Platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.35
- Python version: 3.10.6
- PyTorch version (GPU?): 1.13.1+cu117 (True)
- Huggingface_hub version: 0.12.1
- Transformers version: 4.26.1
- Accelerate version: 0.16.0
- xFormers version: 0.0.16
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
