diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 5538a7b8249d..1085c452b076 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -60,7 +60,6 @@ def __init__(
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -74,18 +73,25 @@ def __init__(
         self._use_memory_efficient_attention_xformers = False
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
@@ -134,14 +140,25 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        use_torch_2_0_attn = (
+            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
+        )
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
@@ -162,7 +179,7 @@ def forward(self, hidden_states):
             hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
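For reviewers, a minimal standalone sketch (not part of the patch; shapes are illustrative) of the equivalence the dispatch above relies on: `F.scaled_dot_product_attention` consumes the unmerged `(batch, heads, seq_len, head_dim)` layout and applies `1 / sqrt(head_dim)` scaling by default, which equals this block's `scale = 1 / math.sqrt(self.channels / self.num_heads)`. That is why the `attn.scale` TODO only matters for non-default scales, and why the SDPA branch should match the merged-batch `baddbmm`/`softmax`/`bmm` fallback:

```python
# Standalone sketch, not the PR's code: check that torch >= 2.0 SDPA on
# unmerged (batch, heads, seq, head_dim) tensors matches the merged-batch
# bmm fallback path. Shapes below are illustrative, not from diffusers.
import math

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 16, 8
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

# SDPA path: heads stay a separate dim; the default scale is 1/sqrt(head_dim),
# the same value as scale = 1/sqrt(channels / num_heads) in AttentionBlock.
out_sdpa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Fallback path: fold heads into the batch dim, scale explicitly via baddbmm
# (beta=0 ignores the uninitialized input), softmax, then weight the values.
scale = 1 / math.sqrt(head_dim)
q_m, k_m, v_m = (t.reshape(batch * heads, seq, head_dim) for t in (q, k, v))
scores = torch.baddbmm(
    torch.empty(batch * heads, seq, seq), q_m, k_m.transpose(-1, -2), beta=0, alpha=scale
)
out_bmm = torch.bmm(scores.softmax(dim=-1), v_m).reshape(batch, heads, seq, head_dim)

assert torch.allclose(out_sdpa, out_bmm, atol=1e-5)
```

The same reasoning covers the xformers branch: its default scale is also `1 / sqrt(head_dim)`, so passing `scale=scale` explicitly is a no-op today but keeps the three branches in sync if the block's scale ever changes.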
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index fe0041850bb4..6cb71bebb9c0 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -319,6 +319,40 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
 
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
 
+    @parameterized.expand([13, 16, 27])
+    @require_torch_gpu
+    def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
+        model = self.get_sd_vae_model(fp16=True)
+        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
+
+        with torch.no_grad():
+            sample = model.decode(encoding).sample
+
+        model.enable_xformers_memory_efficient_attention()
+        with torch.no_grad():
+            sample_2 = model.decode(encoding).sample
+
+        assert list(sample.shape) == [3, 3, 512, 512]
+
+        assert torch_all_close(sample, sample_2, atol=1e-1)
+
+    @parameterized.expand([13, 16, 37])
+    @require_torch_gpu
+    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
+        model = self.get_sd_vae_model()
+        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
+
+        with torch.no_grad():
+            sample = model.decode(encoding).sample
+
+        model.enable_xformers_memory_efficient_attention()
+        with torch.no_grad():
+            sample_2 = model.decode(encoding).sample
+
+        assert list(sample.shape) == [3, 3, 512, 512]
+
+        assert torch_all_close(sample, sample_2, atol=1e-2)
+
     @parameterized.expand(
         [
             # fmt: off
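The new tests boil down to decoding the same latents twice and comparing the default path (SDPA on torch >= 2.0) against xformers; the fp16 variant uses the looser `atol=1e-1` because half precision accumulates more error. A minimal sketch of that check outside the test harness, assuming a CUDA GPU, xformers installed, and the `CompVis/stable-diffusion-v1-4` VAE as a stand-in for whatever checkpoint `get_sd_vae_model` actually loads:

```python
# Minimal sketch mirroring the new tests; the checkpoint below is an
# assumption, not necessarily what get_sd_vae_model() uses.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to("cuda")
latents = torch.randn(3, 4, 64, 64, generator=torch.manual_seed(13)).to("cuda")

with torch.no_grad():
    sample = vae.decode(latents).sample  # default attention: SDPA on torch >= 2.0

vae.enable_xformers_memory_efficient_attention()
with torch.no_grad():
    sample_2 = vae.decode(latents).sample  # xformers attention

# 8x spatial upsampling: (3, 4, 64, 64) latents -> (3, 3, 512, 512) images
assert list(sample.shape) == [3, 3, 512, 512]
assert torch.allclose(sample, sample_2, atol=1e-2)
```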