Skip to content

Commit f899f92

Browse files
committed
Add type hints and use keyword arguments
1 parent 4001c86 commit f899f92

File tree

1 file changed

+90
-62
lines changed

1 file changed

+90
-62
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 90 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -238,16 +238,16 @@ def __init__(
238238

239239
# Check inputs
240240
self._check_config(
241-
down_block_types,
242-
up_block_types,
243-
only_cross_attention,
244-
block_out_channels,
245-
layers_per_block,
246-
cross_attention_dim,
247-
transformer_layers_per_block,
248-
reverse_transformer_layers_per_block,
249-
attention_head_dim,
250-
num_attention_heads,
241+
down_block_types=down_block_types,
242+
up_block_types=up_block_types,
243+
only_cross_attention=only_cross_attention,
244+
block_out_channels=block_out_channels,
245+
layers_per_block=layers_per_block,
246+
cross_attention_dim=cross_attention_dim,
247+
transformer_layers_per_block=transformer_layers_per_block,
248+
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
249+
attention_head_dim=attention_head_dim,
250+
num_attention_heads=num_attention_heads,
251251
)
252252

253253
# input
@@ -258,7 +258,11 @@ def __init__(
258258

259259
# time
260260
time_embed_dim, timestep_input_dim = self._set_time_proj(
261-
flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim
261+
time_embedding_type,
262+
block_out_channels=block_out_channels,
263+
flip_sin_to_cos=flip_sin_to_cos,
264+
freq_shift=freq_shift,
265+
time_embedding_dim=time_embedding_dim,
262266
)
263267

264268
self.time_embedding = TimestepEmbedding(
@@ -269,28 +273,32 @@ def __init__(
269273
cond_proj_dim=time_cond_proj_dim,
270274
)
271275

272-
self._set_encoder_hid_proj(cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type)
276+
self._set_encoder_hid_proj(
277+
encoder_hid_dim_type,
278+
cross_attention_dim=cross_attention_dim,
279+
encoder_hid_dim=encoder_hid_dim,
280+
)
273281

274282
# class embedding
275283
self._set_class_embedding(
276-
act_fn,
277284
class_embed_type,
278-
num_class_embeds,
279-
projection_class_embeddings_input_dim,
280-
time_embed_dim,
281-
timestep_input_dim,
285+
act_fn=act_fn,
286+
num_class_embeds=num_class_embeds,
287+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
288+
time_embed_dim=time_embed_dim,
289+
timestep_input_dim=timestep_input_dim,
282290
)
283291

284292
self._set_add_embedding(
285-
flip_sin_to_cos,
286-
freq_shift,
287-
cross_attention_dim,
288-
encoder_hid_dim,
289293
addition_embed_type,
290-
addition_time_embed_dim,
291-
projection_class_embeddings_input_dim,
292-
addition_embed_type_num_heads,
293-
time_embed_dim,
294+
addition_embed_type_num_heads=addition_embed_type_num_heads,
295+
addition_time_embed_dim=addition_time_embed_dim,
296+
cross_attention_dim=cross_attention_dim,
297+
encoder_hid_dim=encoder_hid_dim,
298+
flip_sin_to_cos=flip_sin_to_cos,
299+
freq_shift=freq_shift,
300+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
301+
time_embed_dim=time_embed_dim,
294302
)
295303

296304
if time_embedding_act_fn is None:
@@ -468,20 +476,20 @@ def __init__(
468476
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
469477
)
470478

471-
self._set_pos_net_if_use_gligen(cross_attention_dim, attention_type)
479+
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
472480

473481
def _check_config(
474482
self,
475-
down_block_types,
476-
up_block_types,
477-
only_cross_attention,
478-
block_out_channels,
479-
layers_per_block,
480-
cross_attention_dim,
481-
transformer_layers_per_block,
482-
reverse_transformer_layers_per_block,
483-
attention_head_dim,
484-
num_attention_heads,
483+
down_block_types: Tuple[str],
484+
up_block_types: Tuple[str],
485+
only_cross_attention: Union[bool, Tuple[bool]],
486+
block_out_channels: Tuple[int],
487+
layers_per_block: Union[int, Tuple[int]],
488+
cross_attention_dim: Union[int, Tuple[int]],
489+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
490+
reverse_transformer_layers_per_block: bool,
491+
attention_head_dim: int,
492+
num_attention_heads: Optional[Union[int, Tuple[int]]],
485493
):
486494
if len(down_block_types) != len(up_block_types):
487495
raise ValueError(
@@ -522,7 +530,14 @@ def _check_config(
522530
if isinstance(layer_number_per_block, list):
523531
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
524532

525-
def _set_time_proj(self, flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim):
533+
def _set_time_proj(
534+
self,
535+
time_embedding_type: str,
536+
block_out_channels: Tuple[int],
537+
flip_sin_to_cos: bool,
538+
freq_shift: float,
539+
time_embedding_dim: int,
540+
) -> Tuple[int, int]:
526541
if time_embedding_type == "fourier":
527542
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
528543
if time_embed_dim % 2 != 0:
@@ -543,7 +558,12 @@ def _set_time_proj(self, flip_sin_to_cos, freq_shift, block_out_channels, time_e
543558

544559
return time_embed_dim, timestep_input_dim
545560

546-
def _set_encoder_hid_proj(self, cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type):
561+
def _set_encoder_hid_proj(
562+
self,
563+
encoder_hid_dim_type: Optional[str],
564+
cross_attention_dim: Union[int, Tuple[int]],
565+
encoder_hid_dim: Optional[int],
566+
):
547567
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
548568
encoder_hid_dim_type = "text_proj"
549569
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
@@ -580,12 +600,12 @@ def _set_encoder_hid_proj(self, cross_attention_dim, encoder_hid_dim, encoder_hi
580600

581601
def _set_class_embedding(
582602
self,
583-
act_fn,
584-
class_embed_type,
585-
num_class_embeds,
586-
projection_class_embeddings_input_dim,
587-
time_embed_dim,
588-
timestep_input_dim,
603+
class_embed_type: Optional[str],
604+
act_fn: str,
605+
num_class_embeds: Optional[int],
606+
projection_class_embeddings_input_dim: Optional[int],
607+
time_embed_dim: int,
608+
timestep_input_dim: int,
589609
):
590610
if class_embed_type is None and num_class_embeds is not None:
591611
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
@@ -617,15 +637,15 @@ def _set_class_embedding(
617637

618638
def _set_add_embedding(
619639
self,
620-
flip_sin_to_cos,
621-
freq_shift,
622-
cross_attention_dim,
623-
encoder_hid_dim,
624-
addition_embed_type,
625-
addition_time_embed_dim,
626-
projection_class_embeddings_input_dim,
627-
addition_embed_type_num_heads,
628-
time_embed_dim,
640+
addition_embed_type: Optional[str],
641+
addition_embed_type_num_heads: int,
642+
addition_time_embed_dim: Optional[int],
643+
flip_sin_to_cos: bool,
644+
freq_shift: float,
645+
cross_attention_dim: Optional[int],
646+
encoder_hid_dim: Optional[int],
647+
projection_class_embeddings_input_dim: Optional[int],
648+
time_embed_dim: int,
629649
):
630650
if addition_embed_type == "text":
631651
if encoder_hid_dim is not None:
@@ -655,7 +675,7 @@ def _set_add_embedding(
655675
elif addition_embed_type is not None:
656676
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
657677

658-
def _set_pos_net_if_use_gligen(self, cross_attention_dim, attention_type):
678+
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
659679
if attention_type in ["gated", "gated-text-image"]:
660680
positive_len = 768
661681
if isinstance(cross_attention_dim, int):
@@ -889,7 +909,9 @@ def unload_lora(self):
889909
if hasattr(module, "set_lora_layer"):
890910
module.set_lora_layer(None)
891911

892-
def get_time_embed(self, sample, timestep):
912+
def get_time_embed(
913+
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
914+
) -> Optional[torch.Tensor]:
893915
timesteps = timestep
894916
if not torch.is_tensor(timesteps):
895917
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
@@ -913,7 +935,7 @@ def get_time_embed(self, sample, timestep):
913935
t_emb = t_emb.to(dtype=sample.dtype)
914936
return t_emb
915937

916-
def get_class_embed(self, sample, class_labels):
938+
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
917939
class_emb = None
918940
if self.class_embedding is not None:
919941
if class_labels is None:
@@ -929,7 +951,9 @@ def get_class_embed(self, sample, class_labels):
929951
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
930952
return class_emb
931953

932-
def get_aug_embed(self, encoder_hidden_states, added_cond_kwargs, emb):
954+
def get_aug_embed(
955+
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
956+
) -> Optional[torch.Tensor]:
933957
aug_emb = None
934958
if self.config.addition_embed_type == "text":
935959
aug_emb = self.add_embedding(encoder_hidden_states)
@@ -979,7 +1003,7 @@ def get_aug_embed(self, encoder_hidden_states, added_cond_kwargs, emb):
9791003
aug_emb = self.add_embedding(image_embs, hint)
9801004
return aug_emb
9811005

982-
def process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs):
1006+
def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
9831007
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
9841008
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
9851009
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
@@ -1121,18 +1145,20 @@ def forward(
11211145
sample = 2 * sample - 1.0
11221146

11231147
# 1. time
1124-
t_emb = self.get_time_embed(sample, timestep)
1148+
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
11251149
emb = self.time_embedding(t_emb, timestep_cond)
11261150
aug_emb = None
11271151

1128-
class_emb = self.get_class_embed(sample, class_labels)
1152+
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
11291153
if class_emb is not None:
11301154
if self.config.class_embeddings_concat:
11311155
emb = torch.cat([emb, class_emb], dim=-1)
11321156
else:
11331157
emb = emb + class_emb
11341158

1135-
aug_emb = self.get_aug_embed(encoder_hidden_states, added_cond_kwargs, emb)
1159+
aug_emb = self.get_aug_embed(
1160+
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1161+
)
11361162
if self.config.addition_embed_type == "image_hint":
11371163
aug_emb, hint = aug_emb
11381164
sample = torch.cat([sample, hint], dim=1)
@@ -1141,7 +1167,9 @@ def forward(
11411167
if self.time_embed_act is not None:
11421168
emb = self.time_embed_act(emb)
11431169

1144-
encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
1170+
encoder_hidden_states = self.process_encoder_hidden_states(
1171+
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1172+
)
11451173

11461174
# 2. pre-process
11471175
sample = self.conv_in(sample)

0 commit comments

Comments
 (0)