Skip to content

Commit f0c74e9

Browse files
Add unet act fn to other model components (#3136)
Adding act fn config to the unet timestep class embedding and conv activation. The custom activation defaults to silu, which is the default activation function for both the conv act and the timestep class embeddings, so default behavior is not changed. The only UNet that uses the custom activation is the stable diffusion latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json (I ran a script against the hub to confirm). The latent upscaler uses neither the conv activation nor the timestep class embeddings, so we don't change its behavior.
1 parent 4bc157f commit f0c74e9

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def __init__(
248248
if class_embed_type is None and num_class_embeds is not None:
249249
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
250250
elif class_embed_type == "timestep":
251-
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
251+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
252252
elif class_embed_type == "identity":
253253
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
254254
elif class_embed_type == "projection":
@@ -437,7 +437,18 @@ def __init__(
437437
self.conv_norm_out = nn.GroupNorm(
438438
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
439439
)
440-
self.conv_act = nn.SiLU()
440+
441+
if act_fn == "swish":
442+
self.conv_act = lambda x: F.silu(x)
443+
elif act_fn == "mish":
444+
self.conv_act = nn.Mish()
445+
elif act_fn == "silu":
446+
self.conv_act = nn.SiLU()
447+
elif act_fn == "gelu":
448+
self.conv_act = nn.GELU()
449+
else:
450+
raise ValueError(f"Unsupported activation function: {act_fn}")
451+
441452
else:
442453
self.conv_norm_out = None
443454
self.conv_act = None

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def __init__(
345345
if class_embed_type is None and num_class_embeds is not None:
346346
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
347347
elif class_embed_type == "timestep":
348-
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
348+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
349349
elif class_embed_type == "identity":
350350
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
351351
elif class_embed_type == "projection":
@@ -534,7 +534,18 @@ def __init__(
534534
self.conv_norm_out = nn.GroupNorm(
535535
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
536536
)
537-
self.conv_act = nn.SiLU()
537+
538+
if act_fn == "swish":
539+
self.conv_act = lambda x: F.silu(x)
540+
elif act_fn == "mish":
541+
self.conv_act = nn.Mish()
542+
elif act_fn == "silu":
543+
self.conv_act = nn.SiLU()
544+
elif act_fn == "gelu":
545+
self.conv_act = nn.GELU()
546+
else:
547+
raise ValueError(f"Unsupported activation function: {act_fn}")
548+
538549
else:
539550
self.conv_norm_out = None
540551
self.conv_act = None

0 commit comments

Comments (0)