Skip to content

Commit e309542

Browse files
committed
flatten conditional
1 parent fcc4c81 commit e309542

File tree

2 files changed

+22
-24
lines changed

2 files changed

+22
-24
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,19 +272,18 @@ def __init__(
272272
else:
273273
self.class_embedding = None
274274

275-
if time_embedding_act_fn is not None:
276-
if time_embedding_act_fn == "swish":
277-
self.time_embed_act = lambda x: F.silu(x)
278-
elif time_embedding_act_fn == "mish":
279-
self.time_embed_act = nn.Mish()
280-
elif time_embedding_act_fn == "silu":
281-
self.time_embed_act = nn.SiLU()
282-
elif time_embedding_act_fn == "gelu":
283-
self.time_embed_act = nn.GELU()
284-
else:
285-
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
286-
else:
275+
if time_embedding_act_fn is None:
287276
self.time_embed_act = None
277+
elif time_embedding_act_fn == "swish":
278+
self.time_embed_act = lambda x: F.silu(x)
279+
elif time_embedding_act_fn == "mish":
280+
self.time_embed_act = nn.Mish()
281+
elif time_embedding_act_fn == "silu":
282+
self.time_embed_act = nn.SiLU()
283+
elif time_embedding_act_fn == "gelu":
284+
self.time_embed_act = nn.GELU()
285+
else:
286+
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
288287

289288
self.down_blocks = nn.ModuleList([])
290289
self.up_blocks = nn.ModuleList([])

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -364,19 +364,18 @@ def __init__(
364364
else:
365365
self.class_embedding = None
366366

367-
if time_embedding_act_fn is not None:
368-
if time_embedding_act_fn == "swish":
369-
self.time_embed_act = lambda x: F.silu(x)
370-
elif time_embedding_act_fn == "mish":
371-
self.time_embed_act = nn.Mish()
372-
elif time_embedding_act_fn == "silu":
373-
self.time_embed_act = nn.SiLU()
374-
elif time_embedding_act_fn == "gelu":
375-
self.time_embed_act = nn.GELU()
376-
else:
377-
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
378-
else:
367+
if time_embedding_act_fn is None:
379368
self.time_embed_act = None
369+
elif time_embedding_act_fn == "swish":
370+
self.time_embed_act = lambda x: F.silu(x)
371+
elif time_embedding_act_fn == "mish":
372+
self.time_embed_act = nn.Mish()
373+
elif time_embedding_act_fn == "silu":
374+
self.time_embed_act = nn.SiLU()
375+
elif time_embedding_act_fn == "gelu":
376+
self.time_embed_act = nn.GELU()
377+
else:
378+
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
380379

381380
self.down_blocks = nn.ModuleList([])
382381
self.up_blocks = nn.ModuleList([])

0 commit comments

Comments
 (0)