diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 3eeb132fe65e..04ead2adcf6e 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -81,7 +81,7 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim

         if norm_num_groups is not None:
-            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
         else:
             self.group_norm = None

@@ -93,8 +93,8 @@ def __init__(
         self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
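For context, here is a minimal, self-contained sketch of the dimension flow the diff is correcting. It is not the diffusers implementation; the toy sizes, the 32-group count, and the standalone layer names are assumptions made only for illustration. The point it demonstrates: the group norm runs on the incoming hidden states, whose channel dimension is `query_dim` (not `inner_dim`, which only exists after the `to_q`/`to_k`/`to_v` projections), and the extra key/value projections have to output `inner_dim` so their results can be concatenated with the projected key/value tensors.

```python
# Dimension-flow sketch (illustrative only, not the diffusers code).
# query_dim / inner_dim / added_kv_proj_dim mirror the names in the diff;
# the concrete sizes below are arbitrary assumptions.
import torch
import torch.nn as nn

query_dim = 320                # channels of the hidden states entering the block
heads, dim_head = 8, 40
inner_dim = heads * dim_head   # 320; width after to_q/to_k/to_v
added_kv_proj_dim = 768        # e.g. a text-encoder hidden size for the extra K/V

# GroupNorm sees the *unprojected* hidden states, so num_channels must be query_dim.
group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=32, eps=1e-5, affine=True)

to_k = nn.Linear(query_dim, inner_dim, bias=False)    # regular K projection
to_v = nn.Linear(query_dim, inner_dim, bias=False)    # regular V projection
add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)  # extra K must also be inner_dim wide
add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)  # so it can be concatenated below

hidden_states = torch.randn(2, query_dim, 64)           # (batch, channels, seq)
encoder_states = torch.randn(2, 77, added_kv_proj_dim)  # (batch, tokens, channels)

normed = group_norm(hidden_states).transpose(1, 2)      # (batch, seq, query_dim)
key = torch.cat([add_k_proj(encoder_states), to_k(normed)], dim=1)
value = torch.cat([add_v_proj(encoder_states), to_v(normed)], dim=1)
print(key.shape, value.shape)  # torch.Size([2, 141, 320]) for both
```

With the pre-patch shapes, the concatenation above would fail whenever `cross_attention_dim != inner_dim`, and the group norm would fail whenever `query_dim != inner_dim`; the patch makes both paths consistent regardless of how those dimensions are configured.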