import inspect
import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -76,7 +76,7 @@ def get_empty_store():
    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                self.step_store[place_in_unet].append(attn)

        self.cur_att_layer += 1
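With attn_res now a 2D tuple, a stored cross-attention map has attn_res[0] * attn_res[1] spatial positions rather than a square number, which is what the np.prod check captures. A minimal sketch, assuming a hypothetical non-square resolution of (24, 16) and the usual 77 text tokens:

import numpy as np
import torch

attn_res = (24, 16)                        # hypothetical non-square attention resolution
attn = torch.rand(8, 24 * 16, 77)          # (batch * heads, spatial_tokens, text_tokens)

assert attn.shape[1] == np.prod(attn_res)  # 384 == 24 * 16, so this map gets stored
assert attn.shape[1] != attn_res[0] ** 2   # a square-only check (24**2 == 576) would have skipped it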
@@ -98,7 +98,7 @@ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
        attention_maps = self.get_average_attention()
        for location in from_where:
            for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                out.append(cross_maps)
        out = torch.cat(out, dim=0)
        out = out.sum(0) / out.shape[0]
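Under the same assumption (a hypothetical attn_res of (24, 16)), the reshape above now produces an explicit 2D grid per head instead of an assumed square one; a rough sketch of the resulting shape:

import torch

attn_res = (24, 16)                  # hypothetical non-square attention resolution
item = torch.rand(8, 24 * 16, 77)    # (heads, spatial_tokens, text_tokens)

cross_maps = item.reshape(-1, attn_res[0], attn_res[1], item.shape[-1])
print(cross_maps.shape)              # torch.Size([8, 24, 16, 77])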
@@ -109,7 +109,7 @@ def reset(self):
        self.step_store = self.get_empty_store()
        self.attention_store = {}

-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
        """
        Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
        process
@@ -724,7 +724,7 @@ def __call__(
        max_iter_to_alter: int = 25,
        thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
        scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ def __call__(
                Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
            scale_factor (`int`, *optional*, default to 20):
                Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.

        Examples:
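For the stock 512x512 output, the fallback introduced in the hunk below works out to ceil(512 / 32) = 16 cells per axis, so leaving attn_res unset reproduces the previous hard-coded 16x16 attention resolution; an explicit tuple is only needed to override that grid.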
@@ -870,7 +870,9 @@ def __call__(
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
        self.register_attention_control()

        # default config for step size from original repo
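A short sketch of the fallback above, wrapped in a hypothetical helper (default_attn_res is illustrative, not part of the pipeline), showing the grid it yields for a non-square generation:

import numpy as np

def default_attn_res(width, height):
    # One attention cell per 32 output pixels on each axis, mirroring the fallback above.
    return int(np.ceil(width / 32)), int(np.ceil(height / 32))

attn_res = default_attn_res(768, 512)
print(attn_res)                # (24, 16)
print(int(np.prod(attn_res)))  # 384 spatial tokens, the count the AttentionStore now checks against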