
Commit c28a3f6

Merge branch 'main' of https://github.com/huggingface/diffusers into diffedit-inpainting-pipeline
2 parents: a5986a9 + fc18839

2 files changed: +36 -6 lines


src/diffusers/models/unet_2d_condition.py

Lines changed: 18 additions & 3 deletions
@@ -248,7 +248,7 @@ def __init__(
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -437,7 +437,18 @@ def __init__(
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
@@ -648,7 +659,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +673,10 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
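
For reference, a minimal standalone sketch of the activation dispatch this file now performs instead of the hard-coded nn.SiLU(). The helper name `_select_conv_act` is hypothetical and only mirrors the string-to-module mapping shown in the hunk above; it is not part of the diffusers API.

import torch
import torch.nn as nn
import torch.nn.functional as F


def _select_conv_act(act_fn: str):
    # Hypothetical helper mirroring the dispatch added in this commit:
    # map the config string to the callable used as `conv_act` after the
    # output GroupNorm.
    if act_fn == "swish":
        return lambda x: F.silu(x)
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "silu":
        return nn.SiLU()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


# Quick check: every supported name yields a callable that preserves shape.
x = torch.randn(1, 4, 8, 8)
for name in ("swish", "mish", "silu", "gelu"):
    assert _select_conv_act(name)(x).shape == x.shape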

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 18 additions & 3 deletions
@@ -345,7 +345,7 @@ def __init__(
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
@@ -534,7 +534,18 @@ def __init__(
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
@@ -745,7 +756,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -759,6 +770,10 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
             if self.config.class_embeddings_concat:
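
As context for the cast both files add, here is a minimal sketch (assumptions only, not the diffusers implementation) of why the projected class labels need the explicit `.to(dtype=sample.dtype)`: a sinusoidal projection built from `arange`/`exp` has no weights, so it comes back as float32 even when the rest of the UNet runs in float16.

import torch

def sinusoidal_proj(timesteps: torch.Tensor, dim: int = 8) -> torch.Tensor:
    # Simplified stand-in for the `Timesteps` projection: it has no weights,
    # so its output dtype is always float32 regardless of the model's dtype.
    half = dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32))
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

sample_dtype = torch.float16                # dtype of `sample` / the UNet weights
class_labels = torch.tensor([3, 7])
proj = sinusoidal_proj(class_labels)
print(proj.dtype)                           # torch.float32
proj = proj.to(dtype=sample_dtype)          # the cast this commit adds
print(proj.dtype)                           # torch.float16, matching the class embedding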
