@@ -165,15 +165,15 @@ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atte
165165
166166 def forward (self , hidden_states , context = None ):
167167 # note: if no context is given, cross-attention defaults to self-attention
168- batch , channel , height , weight = hidden_states .shape
168+ batch , channel , height , width = hidden_states .shape
169169 residual = hidden_states
170170 hidden_states = self .norm (hidden_states )
171171 hidden_states = self .proj_in (hidden_states )
172172 inner_dim = hidden_states .shape [1 ]
173- hidden_states = hidden_states .permute (0 , 2 , 3 , 1 ).reshape (batch , height * weight , inner_dim )
173+ hidden_states = hidden_states .permute (0 , 2 , 3 , 1 ).reshape (batch , height * width , inner_dim )
174174 for block in self .transformer_blocks :
175175 hidden_states = block (hidden_states , context = context )
176- hidden_states = hidden_states .reshape (batch , height , weight , inner_dim ).permute (0 , 3 , 1 , 2 )
176+ hidden_states = hidden_states .reshape (batch , height , width , inner_dim ).permute (0 , 3 , 1 , 2 )
177177 hidden_states = self .proj_out (hidden_states )
178178 return hidden_states + residual
179179
0 commit comments