@@ -248,7 +248,7 @@ def __init__(
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
         elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
         elif class_embed_type == "identity":
             self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         elif class_embed_type == "projection":
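For orientation: `TimestepEmbedding` is a small two-layer MLP, and before this change its hidden nonlinearity was effectively fixed to SiLU when used as the class embedding, so the model-wide `act_fn` setting was ignored there. A rough, illustrative stand-in is sketched below; the class name and the reduced set of supported activations are assumptions, not the library source.

```python
import torch
import torch.nn as nn


class TimestepEmbeddingSketch(nn.Module):
    """Illustrative stand-in for diffusers' TimestepEmbedding: Linear -> act -> Linear."""

    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu"):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        # The diff forwards the UNet-wide `act_fn` here instead of always using SiLU;
        # only two activation choices are sketched.
        self.act = nn.SiLU() if act_fn in ("silu", "swish") else nn.GELU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        return self.linear_2(self.act(self.linear_1(sample)))


emb = TimestepEmbeddingSketch(in_channels=320, time_embed_dim=1280, act_fn="gelu")
print(emb(torch.randn(2, 320)).shape)  # torch.Size([2, 1280])
```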
@@ -437,7 +437,18 @@ def __init__(
             self.conv_norm_out = nn.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
         else:
             self.conv_norm_out = None
             self.conv_act = None
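The string-to-activation dispatch added above can be read in isolation. A self-contained sketch follows; the helper name `select_activation` is hypothetical, but the mapping mirrors the diff.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def select_activation(act_fn: str):
    # Mirrors the dispatch added to __init__ above (helper name is hypothetical).
    if act_fn == "swish":
        return lambda x: F.silu(x)  # plain callable, not an nn.Module
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "silu":
        return nn.SiLU()
    elif act_fn == "gelu":
        return nn.GELU()
    raise ValueError(f"Unsupported activation function: {act_fn}")


x = torch.randn(2, 4)
# "swish" and "silu" compute the same function, x * sigmoid(x)
assert torch.allclose(select_activation("swish")(x), select_activation("silu")(x))
```

One consequence of the `"swish"` branch returning a bare lambda is that `conv_act` is not registered as a submodule in that case; this is harmless because activations carry no parameters, but it is a small asymmetry relative to the other branches.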
@@ -648,7 +659,7 @@ def forward(
 
         t_emb = self.time_proj(timesteps)
 
-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=self.dtype)
@@ -662,6 +673,10 @@ def forward(
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
 
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
 
         if self.config.class_embeddings_concat:
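A hedged end-to-end sketch of what the new plumbing enables is shown below. The constructor values are typical test-sized settings chosen for illustration, not taken from this PR, and the example assumes the rest of the UNet (resnets, timestep embedding) accepts the same activation string.

```python
from diffusers import UNet2DConditionModel

# Small, test-sized configuration (illustrative values).
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    act_fn="mish",                # now also selects conv_act instead of a hard-coded SiLU
    class_embed_type="timestep",  # class labels go through Timesteps + TimestepEmbedding
)

print(type(unet.conv_act))         # nn.Mish after this change (previously always nn.SiLU)
print(type(unet.class_embedding))  # TimestepEmbedding, now built with act_fn="mish"
```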