@@ -249,13 +249,15 @@ def reshape_batch_dim_to_heads(self, tensor):
         return tensor
 
     def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+        batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)
 
+        dim = query.shape[-1]
+
         query = self.reshape_heads_to_batch_dim(query)
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
@@ -283,7 +285,7 @@ def _attention(self, query, key, value):
     def _sliced_attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, sequence_length, dim), device=query.device, dtype=query.dtype
         )
         slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
         for i in range(hidden_states.shape[0] // slice_size):
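
Read together, the two hunks move the source of `dim` from the incoming `hidden_states` to the projected `query`, and size the `_sliced_attention` output buffer from that value. This matters when the `to_q` projection changes the channel count, as in cross-attention blocks whose input width differs from the attention width. A minimal sketch of that shape mismatch, with illustrative sizes and a standalone `to_q` layer that are assumptions for the example rather than values from this commit:

import torch
from torch import nn

# Illustrative sizes (assumptions): the input channel width differs from
# the projected attention width, which is the case the old allocation broke on.
batch, seq_len, query_dim = 2, 16, 320
heads, dim_head = 8, 64                  # projected width = 8 * 64 = 512
inner_dim = heads * dim_head

to_q = nn.Linear(query_dim, inner_dim, bias=False)

hidden_states = torch.randn(batch, seq_len, query_dim)
query = to_q(hidden_states)

# Before the change, `dim` was read from the input tensor; after it, from the
# projected query, which is the width the attention computation actually works in.
print(hidden_states.shape[-1])  # 320 -> stale input width
print(query.shape[-1])          # 512 -> width of the projected query

With the old code, a model whose `query_dim` did not equal `heads * dim_head` would allocate a sliced-attention buffer of the wrong width; reading `dim` after the projection keeps the buffer consistent with the tensors being attended over.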