Merged
6 changes: 4 additions & 2 deletions examples/community/stable_diffusion_reference.py
@@ -538,7 +538,7 @@ def hack_CrossAttnDownBlock2D_forward(

return hidden_states, output_states

- def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
Member:
Why is kwargs needed?

Collaborator Author:

hacked_DownBlock2D_forward is a function they wrote to replace the forward method of DownBlock2D, hence the signature has to match. We added a new scale argument for the LoRA refactor, and it causes an error here without **kwargs.

I think they should write their own custom blocks instead.
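
To make the failure mode concrete, here is a minimal, self-contained sketch (toy classes with made-up names, not diffusers code) of why the broader signature is needed when a forward method is monkey-patched the way the community pipeline does it:

```python
import types


class Block:
    """Toy stand-in for DownBlock2D: newer callers now pass an extra `scale` keyword."""

    def forward(self, hidden_states, temb=None, scale=1.0):
        return [h * scale for h in hidden_states]


def hacked_forward_old(self, hidden_states, temb=None):
    # Old replacement signature: fails with a TypeError once the caller adds `scale=...`.
    return hidden_states


def hacked_forward_new(self, hidden_states, temb=None, **kwargs):
    # New replacement signature: `scale` (and any future keyword) lands in **kwargs
    # and is simply ignored, so the monkey-patch keeps working.
    return hidden_states


block = Block()
block.forward = types.MethodType(hacked_forward_new, block)  # monkey-patch, like the pipeline does
print(block.forward([1.0, 2.0], temb=None, scale=0.5))  # ok: `scale` is absorbed and ignored
```

Swapping in hacked_forward_old instead reproduces the error this PR works around: the caller passes a keyword the replacement does not accept.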

eps = 1e-6

output_states = ()
@@ -634,7 +634,9 @@ def hacked_CrossAttnUpBlock2D_forward(

return hidden_states

- def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ def hacked_UpBlock2D_forward(
+     self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
Member:

Same.

+ ):
eps = 1e-6
for i, resnet in enumerate(self.resnets):
# pop res hidden states
6 changes: 4 additions & 2 deletions examples/community/stable_diffusion_xl_reference.py
@@ -507,7 +507,7 @@ def hack_CrossAttnDownBlock2D_forward(

return hidden_states, output_states

- def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
eps = 1e-6

output_states = ()
@@ -603,7 +603,9 @@ def hacked_CrossAttnUpBlock2D_forward(

return hidden_states

- def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ def hacked_UpBlock2D_forward(
+     self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
Member:

Same.

+ ):
eps = 1e-6
for i, resnet in enumerate(self.resnets):
# pop res hidden states
6 changes: 6 additions & 0 deletions src/diffusers/models/attention.py
@@ -158,6 +158,12 @@ def __init__(
super().__init__()
self.only_cross_attention = only_cross_attention

+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
Comment on lines +161 to +165
Member:

I am okay with this.

Member:

Sorry for the late reply. I'd maybe just add a comment to state that they are kept for back-compatibility reasons.
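
For illustration, here is one way to read the added flags: a small, standalone sketch (the function name and return type are assumptions, not the merged code, which sets these as instance attributes in __init__) showing how the back-compat booleans derive from `norm_type`:

```python
# Illustrative only: the real block keeps these as attributes so that existing code
# that reads e.g. `use_ada_layer_norm` keeps working (the back-compat reason noted above).
def derive_norm_flags(norm_type: str, num_embeds_ada_norm=None) -> dict:
    """Boolean flags derived from `norm_type`, kept for backward compatibility."""
    return {
        "use_ada_layer_norm_zero": num_embeds_ada_norm is not None and norm_type == "ada_norm_zero",
        "use_ada_layer_norm": num_embeds_ada_norm is not None and norm_type == "ada_norm",
        "use_ada_layer_norm_single": norm_type == "ada_norm_single",
        "use_layer_norm": norm_type == "layer_norm",
        "use_ada_layer_norm_continuous": norm_type == "ada_norm_continuous",
    }


print(derive_norm_flags("ada_norm", num_embeds_ada_norm=1000))
# -> only "use_ada_layer_norm" is True; every other flag is False
```

A short comment along these lines above the attributes would address the review suggestion.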


if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
    raise ValueError(
        f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"