
Commit 6ea2d26

Allow SD attend and excite pipeline to work with any size output images (huggingface#2835)
Allow stable diffusion attend and excite pipeline to work with any size output image. Re: huggingface#2476, huggingface#2603
1 parent 77e3f42 commit 6ea2d26
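
For context, a minimal usage sketch of what this change enables: generating a non-square image with the Attend-and-Excite pipeline. The checkpoint, prompt, and token indices below are illustrative, not part of the commit.

import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Non-square output; previously this tripped the square attn_res assumption
# tracked in huggingface#2476 and huggingface#2603.
image = pipe(
    prompt="a cat and a frog",
    token_indices=[2, 5],  # tokens to attend to and excite (illustrative)
    width=768,
    height=512,
    # attn_res now defaults to (ceil(768 / 32), ceil(512 / 32)) == (24, 16)
).images[0]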

File tree

1 file changed: +10 -8 lines changed

pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

@@ -14,7 +14,7 @@

 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -76,7 +76,7 @@ def get_empty_store():

     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)

         self.cur_att_layer += 1
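
The guard previously assumed a square map with attn_res**2 spatial positions. With attn_res now a tuple, np.prod yields the same count for rectangular maps. A quick sketch of the equivalence, with illustrative values:

import numpy as np

attn_res = (24, 16)  # e.g. a 768x512 output at 1/32 scale
num_positions = int(np.prod(attn_res))  # 384 spatial positions in the map
assert num_positions == 24 * 16

# For the old square default the two checks agree:
assert np.prod((16, 16)) == 16**2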
@@ -98,7 +98,7 @@ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
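
With a rectangular attn_res, the flattened spatial axis of each stored map now unpacks into two distinct dimensions. A small sketch of the reshape, with tensor shapes assumed for illustration:

import torch

attn_res = (24, 16)
# Assumed shape: (heads, spatial positions, text tokens); 77 is the usual
# CLIP text length, used here only for illustration.
item = torch.rand(8, attn_res[0] * attn_res[1], 77)
cross_maps = item.reshape(-1, attn_res[0], attn_res[1], item.shape[-1])
print(cross_maps.shape)  # torch.Size([8, 24, 16, 77])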
@@ -109,7 +109,7 @@ def reset(self):
         self.step_store = self.get_empty_store()
         self.attention_store = {}

-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -724,7 +724,7 @@ def __call__(
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
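
Since the parameter is now Optional, callers wanting the old behavior can still pin the resolution explicitly. Reusing pipe from the sketch above; the tuple mirrors the width/height-derived order used by this diff and is illustrative:

# Illustrative: override the computed default with the old square resolution.
image = pipe(
    prompt="a cat and a frog",
    token_indices=[2, 5],
    attn_res=(16, 16),  # equivalent to the previous hard-coded default of 16
).images[0]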
@@ -796,8 +796,8 @@ def __call__(
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.

         Examples:
@@ -870,7 +870,9 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()

         # default config for step size from original repo
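
The fallback ties the attention resolution to the requested output size: the stored cross-attention maps live at 1/32 of pixel resolution (8x from the VAE latent space plus a further 4x inside the UNet), so the default allocates one cell per 32 pixels per axis. A worked sketch:

import numpy as np

width, height = 768, 520  # illustrative; neither square nor a multiple of 32

# Mirrors the fallback added above: one attention cell per 32 output pixels,
# rounded up so partial cells are kept.
attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
print(attn_res)  # (24, 17)

# The previous hard-coded default is recovered for 512x512 outputs:
assert int(np.ceil(512 / 32)) == 16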
