@@ -14,7 +14,7 @@
 
 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -75,7 +75,7 @@ def get_empty_store():
 
     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)
 
         self.cur_att_layer += 1
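A quick sanity check of the new comparison (an illustrative sketch, not part of the diff): for a square map the product of the two dimensions equals the old square, and for a non-square map it matches the flattened token count that `self.attn_res**2` could never hit.

import numpy as np

# Square case: np.prod((16, 16)) == 16**2 == 256, same behaviour as before.
assert np.prod((16, 16)) == 16**2

# Non-square case (hypothetical 768x512 image -> 24x16 attention map):
# the flattened spatial axis has 24 * 16 = 384 entries, which attn_res**2 would miss.
assert np.prod((24, 16)) == 384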
@@ -97,7 +97,7 @@ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
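For illustration only (the tensor shapes below are assumptions, not taken from this file): indexing `attn_res` per axis lets the reshape recover a non-square map from the flattened spatial dimension.

import torch

attn_res = (24, 16)                # hypothetical non-square attention resolution
item = torch.rand(8, 24 * 16, 77)  # (heads, flattened spatial positions, text tokens)
cross_maps = item.reshape(-1, attn_res[0], attn_res[1], item.shape[-1])
print(cross_maps.shape)            # torch.Size([8, 24, 16, 77])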
@@ -108,7 +108,7 @@ def reset(self):
         self.step_store = self.get_empty_store()
         self.attention_store = {}
 
-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -715,7 +715,7 @@ def __call__(
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
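A hedged usage sketch of the new keyword (the prompt, token indices, and `pipe` object are illustrative): callers can still pin the attention resolution by passing a tuple, or omit it to let the pipeline derive one from `width` and `height`.

# Sketch only; assumes an Attend-and-Excite pipeline has already been loaded as `pipe`.
image = pipe(
    prompt="a cat and a frog",
    token_indices=[2, 5],
    attn_res=(16, 16),  # explicit 2D resolution; leave as None to use the computed default
).images[0]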
@@ -787,8 +787,8 @@ def __call__(
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.
 
         Examples:
 
@@ -861,7 +861,9 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()
 
         # default config for step size from original repo
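Worked arithmetic for the new fallback (assuming, as the code does, that the relevant cross-attention maps sit at 1/32 of the pixel resolution): a 512x512 generation reproduces the previously hard-coded value, while other sizes get a matching, possibly non-square, map.

import numpy as np

def default_attn_res(width, height):
    # Mirrors the fallback added above.
    return int(np.ceil(width / 32)), int(np.ceil(height / 32))

print(default_attn_res(512, 512))  # (16, 16) -> identical to the old attn_res=16 default
print(default_attn_res(768, 512))  # (24, 16) -> non-square attention map
print(default_attn_res(520, 512))  # (17, 16) -> rounds up when the size is not a multiple of 32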