@@ -75,7 +75,7 @@ def get_empty_store():
 
     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)
 
         self.cur_att_layer += 1
@@ -97,7 +97,7 @@ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
@@ -108,7 +108,7 @@ def reset(self):
         self.step_store = self.get_empty_store()
         self.attention_store = {}
 
-    def __init__(self, attn_res=16):
+    def __init__(self, width, height):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -118,7 +118,8 @@ def __init__(self, attn_res=16):
         self.step_store = self.get_empty_store()
         self.attention_store = {}
         self.curr_step_index = 0
-        self.attn_res = attn_res
+        self.attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+
 
 
 class AttendExciteAttnProcessor:
@@ -715,7 +716,6 @@ def __call__(
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -787,8 +787,6 @@ def __call__(
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
 
         Examples:
 
@@ -861,7 +859,7 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        self.attention_store = AttentionStore(attn_res=attn_res)
+        self.attention_store = AttentionStore(width, height)
         self.register_attention_control()
 
         # default config for step size from original repo
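For context on the arithmetic this change introduces: the stored cross-attention maps are taken at 1/32 of the requested pixel resolution, so for a non-square image the token count of a stored map is the product of the two per-axis resolutions rather than a single value squared. Below is a minimal standalone sketch of that bookkeeping; the attn_res_for helper name is hypothetical, and only the /32 rule, the np.prod check, and the reshape mirror the diff.

import numpy as np
import torch

def attn_res_for(width: int, height: int) -> tuple:
    # Same rule as the new AttentionStore.__init__: one attention token per 32x32-pixel patch.
    return int(np.ceil(width / 32)), int(np.ceil(height / 32))

attn_res = attn_res_for(768, 512)       # (24, 16) for a non-square 768x512 generation
seq_len = int(np.prod(attn_res))        # 384, the value __call__ now compares attn.shape[1] against

# A dummy cross-attention map with that token count and 77 text tokens,
# reshaped the same way aggregate_attention now does it.
attn = torch.rand(8, seq_len, 77)       # (heads, image tokens, text tokens)
cross_maps = attn.reshape(-1, attn_res[0], attn_res[1], attn.shape[-1])
print(cross_maps.shape)                 # torch.Size([8, 24, 16, 77])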