add self.use_ada_layer_norm_* params back to BasicTransformerBlock
#6841
First changed file:

@@ -538,7 +538,7 @@ def hack_CrossAttnDownBlock2D_forward(
         return hidden_states, output_states
 
-    def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+    def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
         eps = 1e-6
 
         output_states = ()

@@ -634,7 +634,9 @@ def hacked_CrossAttnUpBlock2D_forward(
         return hidden_states
 
-    def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+    def hacked_UpBlock2D_forward(
+        self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
+    ):
         eps = 1e-6
         for i, resnet in enumerate(self.resnets):
             # pop res hidden states

Member (on the `**kwargs` line): Same.
Second changed file (same change at different line numbers):

@@ -507,7 +507,7 @@ def hack_CrossAttnDownBlock2D_forward(
         return hidden_states, output_states
 
-    def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+    def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
         eps = 1e-6
 
         output_states = ()

@@ -603,7 +603,9 @@ def hacked_CrossAttnUpBlock2D_forward(
         return hidden_states
 
-    def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+    def hacked_UpBlock2D_forward(
+        self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
+    ):
         eps = 1e-6
         for i, resnet in enumerate(self.resnets):
             # pop res hidden states

Member (on the `**kwargs` line): Same.
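For context, these pipelines do not subclass the UNet blocks; they replace the blocks' `forward` methods at runtime, which is why the replacement signatures have to track whatever the library passes. Below is a minimal, simplified sketch of that pattern; the real hacked forwards also inject reference features, and the `MethodType` binding shown here is an illustrative assumption, not the pipelines' exact code.

```python
from types import MethodType

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel()  # randomly initialized, purely for illustration


def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
    # **kwargs absorbs any extra arguments newer diffusers versions pass
    # (e.g. the `scale` keyword added by the LoRA refactor), so this patched
    # signature keeps matching the caller without further changes.
    output_states = ()
    for resnet in self.resnets:
        hidden_states = resnet(hidden_states, temb)
        output_states += (hidden_states,)
    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)
        output_states += (hidden_states,)
    return hidden_states, output_states


# Bind the replacement onto every plain DownBlock2D in the UNet.
for block in unet.down_blocks:
    if block.__class__.__name__ == "DownBlock2D":
        block.forward = MethodType(hacked_DownBlock2D_forward, block)
```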
Third changed file (`BasicTransformerBlock.__init__`):

@@ -158,6 +158,12 @@ def __init__(
         super().__init__()
         self.only_cross_attention = only_cross_attention
 
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+        self.use_layer_norm = norm_type == "layer_norm"
+        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
                 f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"

Member (comment on lines +161 to +165): I am okay with this.

Member: Sorry for the late reply. I'd maybe just add a comment to state that they are kept for back-compatibility reasons.
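These `use_*` flags are derivable from `norm_type`, but they were public attributes before the norm refactor, and downstream code (including the pipelines touched above) reads them directly. A small sketch of that kind of consumer code, assuming a stock `BasicTransformerBlock` constructed with default arguments:

```python
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40)

# Code written against older diffusers keys off these attributes instead of
# inspecting `norm_type`, so dropping them breaks existing callers even though
# the same information is available elsewhere.
if block.use_ada_layer_norm_zero:
    print("uses AdaLayerNormZero (needs class/timestep embeddings)")
elif block.use_ada_layer_norm:
    print("uses AdaLayerNorm")
else:
    print("uses plain LayerNorm")  # the default norm_type is "layer_norm"
```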
Comment: Why is `kwargs` needed?

Reply: `hacked_DownBlock2D_forward` is a function they wrote to replace the `forward` method of `DownBlock2D`, hence the signature has to match. We added a new argument, `scale`, for the LoRA refactor, and it causes an error here without `**kwargs`. I think they should write their own custom blocks instead.
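To make the failure mode concrete, here is a tiny standalone sketch (plain functions, no diffusers imports) of what happens when the caller starts passing an extra `scale` keyword to a patched forward that was not written to accept it:

```python
def old_forward(self, hidden_states, temb=None):
    # signature without **kwargs, as in the pipelines before this PR
    return hidden_states


def new_forward(self, hidden_states, temb=None, **kwargs):
    # extra keywords such as `scale` are silently absorbed and ignored
    return hidden_states


block = object()

# After the LoRA refactor the UNet calls blocks roughly like
# forward(hidden_states, temb=temb, scale=lora_scale):
try:
    old_forward(block, "hidden", temb=None, scale=1.0)
except TypeError as err:
    print(err)  # old_forward() got an unexpected keyword argument 'scale'

new_forward(block, "hidden", temb=None, scale=1.0)  # no error
```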