diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index 88a7febae650..924548b35ca3 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -538,7 +538,7 @@ def hack_CrossAttnDownBlock2D_forward( return hidden_states, output_states - def hacked_DownBlock2D_forward(self, hidden_states, temb=None): + def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs): eps = 1e-6 output_states = () @@ -634,7 +634,9 @@ def hacked_CrossAttnUpBlock2D_forward( return hidden_states - def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def hacked_UpBlock2D_forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs + ): eps = 1e-6 for i, resnet in enumerate(self.resnets): # pop res hidden states diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index fbfb6bdd6160..4c7efa4b5f7a 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -507,7 +507,7 @@ def hack_CrossAttnDownBlock2D_forward( return hidden_states, output_states - def hacked_DownBlock2D_forward(self, hidden_states, temb=None): + def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs): eps = 1e-6 output_states = () @@ -603,7 +603,9 @@ def hacked_CrossAttnUpBlock2D_forward( return hidden_states - def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def hacked_UpBlock2D_forward( + self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs + ): eps = 1e-6 for i, resnet in enumerate(self.resnets): # pop res hidden states diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index afb022c8d612..d4d611250ad0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -158,6 +158,12 @@ def __init__( super().__init__() self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"