Commit be2402e

Allow the Stable Diffusion Attend-and-Excite pipeline to work with any size output image. Re: huggingface#2476, huggingface#2603
1 parent 40a7b86 commit be2402e

File tree: 1 file changed (+10 -8 lines)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

Lines changed: 10 additions & 8 deletions
@@ -14,7 +14,7 @@

 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -75,7 +75,7 @@ def get_empty_store():

     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)

         self.cur_att_layer += 1
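
With attn_res now a two-component tuple rather than a single integer, the store keeps a cross-attention map only when its flattened spatial length equals the product of the two components, so rectangular maps pass the filter as well. A minimal sketch of the check, with hypothetical shapes for a non-square output:

import numpy as np
import torch

attn_res = (24, 16)                                 # hypothetical, e.g. a 768x512 output image
attn = torch.rand(8, int(np.prod(attn_res)), 77)    # (heads, spatial tokens, text tokens)

# old check assumed a square map: attn.shape[1] == attn_res ** 2
# new check also matches rectangular maps
assert attn.shape[1] == np.prod(attn_res)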
@@ -97,7 +97,7 @@ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
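
The aggregation step changes accordingly: the flattened maps are reshaped using the two components of attn_res instead of assuming a square grid. A small sketch of the reshape under the same assumed shapes:

import torch

attn_res = (24, 16)                  # hypothetical non-square resolution
item = torch.rand(8, 24 * 16, 77)    # one stored cross-attention map

# -1 recovers the head dimension; the spatial axis is split into the two components
cross_maps = item.reshape(-1, attn_res[0], attn_res[1], item.shape[-1])
print(cross_maps.shape)              # torch.Size([8, 24, 16, 77])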
@@ -108,7 +108,7 @@ def reset(self):
         self.step_store = self.get_empty_store()
         self.attention_store = {}

-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -715,7 +715,7 @@ def __call__(
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
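
Because attn_res is now optional and defaults to None, a caller does not have to supply a resolution; the pipeline derives it from the requested output size. A hedged usage sketch (model id, prompt, token indices and sizes are illustrative, not taken from this commit):

import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# non-square output; attn_res is left unset and computed from width and height
image = pipe(
    prompt="a cat and a frog",
    token_indices=[2, 5],
    width=768,
    height=512,
    attn_res=None,    # or pass an explicit tuple such as (24, 16)
).images[0]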
@@ -787,8 +787,8 @@ def __call__(
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.

             Examples:
@@ -861,7 +861,9 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()

         # default config for step size from original repo
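
The fallback ties the attention-map resolution to the requested output size: the previous hard-coded value of 16 corresponds to a 512x512 output, i.e. a total downsampling factor of 32 between pixels and the attention grid, and np.ceil covers sizes that are not exact multiples of 32. A small sketch of the computation for a few hypothetical sizes:

import numpy as np

def default_attn_res(width: int, height: int) -> tuple:
    # mirrors the fallback added in this commit
    return int(np.ceil(width / 32)), int(np.ceil(height / 32))

print(default_attn_res(512, 512))   # (16, 16) -> matches the old hard-coded default
print(default_attn_res(768, 512))   # (24, 16)
print(default_attn_res(600, 400))   # (19, 13) -> ceil handles non-multiples of 32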
