From be2402e415c0b69b5689f70ed3f6ab5c81ea663c Mon Sep 17 00:00:00 2001
From: Joseph Coffland
Date: Sun, 26 Mar 2023 18:37:49 +0300
Subject: [PATCH] Allow stable diffusion attend and excite pipeline to work
 with any size output image. Re: #2476, #2603

---
 ...eline_stable_diffusion_attend_and_excite.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index ae92ba5526a8..8dd2d3446611 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -14,7 +14,7 @@
 
 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -75,7 +75,7 @@ def get_empty_store():
 
     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)
 
         self.cur_att_layer += 1
@@ -97,7 +97,7 @@ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
@@ -108,7 +108,7 @@ def reset(self):
         self.step_store = self.get_empty_store()
         self.attention_store = {}
 
-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -715,7 +715,7 @@ def __call__(
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -787,8 +787,8 @@ def __call__(
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.
 
         Examples:
 
@@ -861,7 +861,9 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
        self.register_attention_control()
 
         # default config for step size from original repo
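
Usage sketch (illustrative, not part of the patch): the checkpoint, prompt, and token indices below are placeholders. With this change, attn_res defaults to (ceil(width / 32), ceil(height / 32)), e.g. (24, 18) for a 768x576 output, and may still be passed explicitly as a 2-tuple.

import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

# Placeholder checkpoint; any Stable Diffusion 1.x checkpoint is assumed to work the same way.
pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

prompt = "a cat and a frog"
# Indices of the prompt tokens to attend to and excite.
token_indices = [2, 5]

# Non-square output; attn_res is derived from width and height when left as None.
image = pipe(
    prompt=prompt,
    token_indices=token_indices,
    width=768,
    height=576,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
image.save("cat_and_frog_768x576.png")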