@@ -60,7 +60,6 @@ def __init__(
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -74,18 +73,25 @@ def __init__(
         self._use_memory_efficient_attention_xformers = False
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_size, seq_len, dim = tensor.shape
+            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
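
Aside, not part of the patch: a minimal standalone sketch, with made-up sizes, of the layouts the two reshape helpers now produce under the new flags. With the merge/unmerge flags left at True (the default, used by the xformers and fallback paths) the heads are folded into the batch dimension; with them set to False (the torch 2.0 path) the tensor keeps the (batch, heads, seq_len, head_dim) layout that F.scaled_dot_product_attention expects.

import torch

batch, seq_len, channels, heads = 2, 16, 8, 2  # illustrative sizes, not from the PR
head_dim = channels // heads
x = torch.randn(batch, seq_len, channels)

# merge_head_and_batch=True: heads folded into the batch dim -> (batch * heads, seq_len, head_dim)
merged = x.reshape(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)
merged = merged.reshape(batch * heads, seq_len, head_dim)
print(merged.shape)  # torch.Size([4, 16, 4])

# merge_head_and_batch=False: heads kept as their own dim -> (batch, heads, seq_len, head_dim)
unmerged = x.reshape(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)
print(unmerged.shape)  # torch.Size([2, 2, 16, 4])

# reshape_batch_dim_to_heads inverts either layout back to (batch, seq_len, channels)
back_merged = merged.reshape(batch, heads, seq_len, head_dim).permute(0, 2, 1, 3).reshape(batch, seq_len, channels)
back_unmerged = unmerged.permute(0, 2, 1, 3).reshape(batch, seq_len, channels)
print(back_merged.shape, back_unmerged.shape)  # torch.Size([2, 16, 8]) torch.Size([2, 16, 8])
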
@@ -134,14 +140,25 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        use_torch_2_0_attn = (
+            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
+        )
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
@@ -162,7 +179,7 @@ def forward(self, hidden_states):
             hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
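
Aside, not part of the patch: a rough sanity-check sketch, with illustrative sizes, of the dispatch this diff introduces. It runs F.scaled_dot_product_attention on the unmerged (batch, heads, seq_len, head_dim) layout (requires PyTorch >= 2.0) and compares it against a manual scaled-softmax attention on the merged layout, written here in the spirit of the else-branch fallback whose final torch.bmm line appears in the last hunk; the exact scores computation below is an assumption. The manual path applies scale explicitly, while the SDPA call relies on its default 1 / sqrt(head_dim) scaling, which equals 1 / math.sqrt(channels / num_heads) here; that is why passing the scale is left as a TODO for Torch 2.1.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 16, 32  # illustrative sizes, not from the PR
scale = 1 / math.sqrt(head_dim)  # same value as 1 / math.sqrt(channels / num_heads)

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Torch 2.0 path: unmerged layout, default scaling of 1 / sqrt(head_dim)
out_sdpa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Manual path (assumed fallback shape): merged (batch * heads, seq_len, head_dim) layout, explicit scale
qm, km, vm = (t.reshape(batch * heads, seq_len, head_dim) for t in (q, k, v))
attention_scores = torch.baddbmm(
    torch.empty(batch * heads, seq_len, seq_len, dtype=q.dtype),
    qm,
    km.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
out_manual = torch.bmm(attention_probs, vm).reshape(batch, heads, seq_len, head_dim)

print(torch.allclose(out_sdpa, out_manual, atol=1e-4))  # True, up to numerical tolerance
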