diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b2814356939b..29de8734d4e7 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -248,7 +248,7 @@ def __init__( if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -437,7 +437,18 @@ def __init__( self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - self.conv_act = nn.SiLU() + + if act_fn == "swish": + self.conv_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.conv_act = nn.Mish() + elif act_fn == "silu": + self.conv_act = nn.SiLU() + elif act_fn == "gelu": + self.conv_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + else: self.conv_norm_out = None self.conv_act = None diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 4377be1181a8..b20f18c485d0 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -345,7 +345,7 @@ def __init__( if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -534,7 +534,18 @@ def __init__( self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - self.conv_act = nn.SiLU() + + if act_fn == "swish": + self.conv_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.conv_act = nn.Mish() + elif act_fn == "silu": + self.conv_act = nn.SiLU() + elif act_fn == "gelu": + self.conv_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + else: self.conv_norm_out = None self.conv_act = None