@@ -238,16 +238,16 @@ def __init__(
 
         # Check inputs
         self._check_config(
-            down_block_types,
-            up_block_types,
-            only_cross_attention,
-            block_out_channels,
-            layers_per_block,
-            cross_attention_dim,
-            transformer_layers_per_block,
-            reverse_transformer_layers_per_block,
-            attention_head_dim,
-            num_attention_heads,
+            down_block_types=down_block_types,
+            up_block_types=up_block_types,
+            only_cross_attention=only_cross_attention,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            cross_attention_dim=cross_attention_dim,
+            transformer_layers_per_block=transformer_layers_per_block,
+            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+            attention_head_dim=attention_head_dim,
+            num_attention_heads=num_attention_heads,
         )
 
         # input
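This first hunk converts the `_check_config` call to keyword arguments. With ten parameters, several sharing a type, a transposed positional call can pass validation and misconfigure the model without any error; keyword calls make the binding explicit. A minimal sketch of the hazard (`check` is a hypothetical stand-in, not the diffusers API):

```python
def check(down_block_types, up_block_types, block_out_channels):
    # Hypothetical stand-in for _check_config: all three tuples describe the same blocks.
    if len(down_block_types) != len(up_block_types):
        raise ValueError("down_block_types and up_block_types must have the same length")

down = ("CrossAttnDownBlock2D", "DownBlock2D")
up = ("UpBlock2D", "CrossAttnUpBlock2D")

# Positional call with two arguments transposed: nothing fails, the model is
# just silently misconfigured, because both tuples happen to have length 2.
check(down, (320, 640), up)

# Keyword call: the binding is explicit, and a misspelled or misplaced name
# raises a TypeError at the call site.
check(down_block_types=down, up_block_types=up, block_out_channels=(320, 640))
```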
@@ -258,7 +258,11 @@ def __init__(
 
         # time
         time_embed_dim, timestep_input_dim = self._set_time_proj(
-            flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim
+            time_embedding_type,
+            block_out_channels=block_out_channels,
+            flip_sin_to_cos=flip_sin_to_cos,
+            freq_shift=freq_shift,
+            time_embedding_dim=time_embedding_dim,
         )
 
         self.time_embedding = TimestepEmbedding(
@@ -269,28 +273,32 @@ def __init__(
             cond_proj_dim=time_cond_proj_dim,
         )
 
-        self._set_encoder_hid_proj(cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type)
+        self._set_encoder_hid_proj(
+            encoder_hid_dim_type,
+            cross_attention_dim=cross_attention_dim,
+            encoder_hid_dim=encoder_hid_dim,
+        )
 
         # class embedding
         self._set_class_embedding(
-            act_fn,
             class_embed_type,
-            num_class_embeds,
-            projection_class_embeddings_input_dim,
-            time_embed_dim,
-            timestep_input_dim,
+            act_fn=act_fn,
+            num_class_embeds=num_class_embeds,
+            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+            time_embed_dim=time_embed_dim,
+            timestep_input_dim=timestep_input_dim,
         )
 
         self._set_add_embedding(
-            flip_sin_to_cos,
-            freq_shift,
-            cross_attention_dim,
-            encoder_hid_dim,
             addition_embed_type,
-            addition_time_embed_dim,
-            projection_class_embeddings_input_dim,
-            addition_embed_type_num_heads,
-            time_embed_dim,
+            addition_embed_type_num_heads=addition_embed_type_num_heads,
+            addition_time_embed_dim=addition_time_embed_dim,
+            cross_attention_dim=cross_attention_dim,
+            encoder_hid_dim=encoder_hid_dim,
+            flip_sin_to_cos=flip_sin_to_cos,
+            freq_shift=freq_shift,
+            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+            time_embed_dim=time_embed_dim,
         )
 
         if time_embedding_act_fn is None:
@@ -468,20 +476,20 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
-        self._set_pos_net_if_use_gligen(cross_attention_dim, attention_type)
+        self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
 
     def _check_config(
         self,
-        down_block_types,
-        up_block_types,
-        only_cross_attention,
-        block_out_channels,
-        layers_per_block,
-        cross_attention_dim,
-        transformer_layers_per_block,
-        reverse_transformer_layers_per_block,
-        attention_head_dim,
-        num_attention_heads,
+        down_block_types: Tuple[str],
+        up_block_types: Tuple[str],
+        only_cross_attention: Union[bool, Tuple[bool]],
+        block_out_channels: Tuple[int],
+        layers_per_block: Union[int, Tuple[int]],
+        cross_attention_dim: Union[int, Tuple[int]],
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
+        reverse_transformer_layers_per_block: bool,
+        attention_head_dim: int,
+        num_attention_heads: Optional[Union[int, Tuple[int]]],
     ):
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
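The new annotations make the prevailing convention visible: most `Union[int, Tuple[int]]` settings accept either a single value applied to every block or a tuple with one entry per entry of `down_block_types`. A sketch of that broadcast-or-match rule (`normalize_per_block` is a hypothetical helper, not part of diffusers):

```python
from typing import Tuple, Union

def normalize_per_block(value: Union[int, Tuple[int, ...]], num_blocks: int) -> Tuple[int, ...]:
    # A single int applies to every block; a tuple must give one entry per block.
    if isinstance(value, int):
        return (value,) * num_blocks
    if len(value) != num_blocks:
        raise ValueError(f"expected {num_blocks} entries, got {len(value)}")
    return tuple(value)

print(normalize_per_block(2, 4))             # (2, 2, 2, 2)
print(normalize_per_block((1, 2, 4, 4), 4))  # (1, 2, 4, 4)
```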
@@ -522,7 +530,14 @@ def _check_config(
         if isinstance(layer_number_per_block, list):
             raise ValueError("Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet.")
 
-    def _set_time_proj(self, flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim):
+    def _set_time_proj(
+        self,
+        time_embedding_type: str,
+        block_out_channels: Tuple[int],
+        flip_sin_to_cos: bool,
+        freq_shift: float,
+        time_embedding_dim: Optional[int],
+    ) -> Tuple[int, int]:
         if time_embedding_type == "fourier":
             time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
             if time_embed_dim % 2 != 0:
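The reordered signature puts `time_embedding_type` first because it selects the branch; the fourier branch visible here defaults the width to `block_out_channels[0] * 2` and requires it to be even, since each frequency contributes a sine/cosine pair. A sketch of the sizing rule; the `"positional"` defaults are an assumption carried over from the library, not shown in this hunk:

```python
from typing import Optional, Tuple

def time_proj_dims(
    time_embedding_type: str,
    block_out_channels: Tuple[int, ...],
    time_embedding_dim: Optional[int] = None,
) -> Tuple[int, int]:
    if time_embedding_type == "fourier":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
        if time_embed_dim % 2 != 0:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
        return time_embed_dim, time_embed_dim
    # "positional" branch, assumed from the library defaults (not shown in this hunk).
    return time_embedding_dim or block_out_channels[0] * 4, block_out_channels[0]

print(time_proj_dims("fourier", (320, 640, 1280, 1280)))     # (640, 640)
print(time_proj_dims("positional", (320, 640, 1280, 1280)))  # (1280, 320)
```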
@@ -543,7 +558,12 @@ def _set_time_proj(self, flip_sin_to_cos, freq_shift, block_out_channels, time_embedding_type, time_embedding_dim):
 
         return time_embed_dim, timestep_input_dim
 
-    def _set_encoder_hid_proj(self, cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type):
+    def _set_encoder_hid_proj(
+        self,
+        encoder_hid_dim_type: Optional[str],
+        cross_attention_dim: Union[int, Tuple[int]],
+        encoder_hid_dim: Optional[int],
+    ):
         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
             self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
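The defaulting rule in the body is worth noting: passing `encoder_hid_dim` without a type silently selects `"text_proj"` and writes it back to the config so the choice survives save/load. A distilled sketch (`resolve_hid_proj_type` is hypothetical):

```python
from typing import Optional

def resolve_hid_proj_type(encoder_hid_dim: Optional[int], encoder_hid_dim_type: Optional[str]) -> Optional[str]:
    # Mirrors the defaulting logic above: a dim without a type implies "text_proj".
    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        return "text_proj"
    return encoder_hid_dim_type

print(resolve_hid_proj_type(768, None))   # text_proj
print(resolve_hid_proj_type(None, None))  # None
```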
@@ -580,12 +600,12 @@ def _set_encoder_hid_proj(self, cross_attention_dim, encoder_hid_dim, encoder_hid_dim_type):
 
     def _set_class_embedding(
         self,
-        act_fn,
-        class_embed_type,
-        num_class_embeds,
-        projection_class_embeddings_input_dim,
-        time_embed_dim,
-        timestep_input_dim,
+        class_embed_type: Optional[str],
+        act_fn: str,
+        num_class_embeds: Optional[int],
+        projection_class_embeddings_input_dim: Optional[int],
+        time_embed_dim: int,
+        timestep_input_dim: int,
     ):
         if class_embed_type is None and num_class_embeds is not None:
             self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
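As the surviving context shows, `_set_class_embedding` attaches an optional submodule: with `class_embed_type=None` and `num_class_embeds` set, class conditioning is a plain `nn.Embedding` lookup sized to the time embedding so the two can be combined later in `forward`. A standalone sketch of that pattern (class and names are illustrative):

```python
import torch
from torch import nn

class ClassConditioned(nn.Module):
    def __init__(self, num_class_embeds: int, time_embed_dim: int):
        super().__init__()
        # One learned vector per class, sized to match the time embedding
        # so the two can simply be summed.
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

    def forward(self, emb: torch.Tensor, class_labels: torch.Tensor) -> torch.Tensor:
        return emb + self.class_embedding(class_labels)

m = ClassConditioned(num_class_embeds=10, time_embed_dim=1280)
emb = torch.randn(2, 1280)
print(m(emb, torch.tensor([3, 7])).shape)  # torch.Size([2, 1280])
```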
@@ -617,15 +637,15 @@ def _set_class_embedding(
 
     def _set_add_embedding(
         self,
-        flip_sin_to_cos,
-        freq_shift,
-        cross_attention_dim,
-        encoder_hid_dim,
-        addition_embed_type,
-        addition_time_embed_dim,
-        projection_class_embeddings_input_dim,
-        addition_embed_type_num_heads,
-        time_embed_dim,
+        addition_embed_type: str,
+        addition_embed_type_num_heads: int,
+        addition_time_embed_dim: Optional[int],
+        flip_sin_to_cos: bool,
+        freq_shift: float,
+        cross_attention_dim: Optional[int],
+        encoder_hid_dim: Optional[int],
+        projection_class_embeddings_input_dim: Optional[int],
+        time_embed_dim: int,
     ):
         if addition_embed_type == "text":
             if encoder_hid_dim is not None:
@@ -655,7 +675,7 @@ def _set_add_embedding(
         elif addition_embed_type is not None:
             raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
 
-    def _set_pos_net_if_use_gligen(self, cross_attention_dim, attention_type):
+    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
         if attention_type in ["gated", "gated-text-image"]:
             positive_len = 768
             if isinstance(cross_attention_dim, int):
@@ -889,7 +909,9 @@ def unload_lora(self):
             if hasattr(module, "set_lora_layer"):
                 module.set_lora_layer(None)
 
-    def get_time_embed(self, sample, timestep):
+    def get_time_embed(
+        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+    ) -> Optional[torch.Tensor]:
         timesteps = timestep
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
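The unchanged body below the new signature normalizes `timestep` before embedding it: Python scalars are wrapped into a device-resident tensor (the source of the CPU/GPU sync the TODO warns about) and 0-d tensors are broadcast to the batch. A simplified sketch of that normalization (the real method also selects an integer or floating dtype per backend):

```python
import torch

def normalize_timestep(sample: torch.Tensor, timestep) -> torch.Tensor:
    # Accept a float/int or a 0-d/1-d tensor and return a 1-d tensor with one
    # entry per batch element, on sample's device.
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # Creating the tensor here forces a host-to-device copy, hence the
        # advice in the TODO to pass timesteps as tensors when possible.
        timesteps = torch.tensor([timesteps], dtype=torch.float32, device=sample.device)
    elif timesteps.ndim == 0:
        timesteps = timesteps[None].to(sample.device)
    return timesteps.expand(sample.shape[0])

x = torch.randn(4, 4, 32, 32)
print(normalize_timestep(x, 10).shape)  # torch.Size([4])
```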
@@ -913,7 +935,7 @@ def get_time_embed(self, sample, timestep):
         t_emb = t_emb.to(dtype=sample.dtype)
         return t_emb
 
-    def get_class_embed(self, sample, class_labels):
+    def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
         class_emb = None
         if self.class_embedding is not None:
             if class_labels is None:
@@ -929,7 +951,9 @@ def get_class_embed(self, sample, class_labels):
             class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
         return class_emb
 
-    def get_aug_embed(self, encoder_hidden_states, added_cond_kwargs, emb):
+    def get_aug_embed(
+        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
+    ) -> Optional[torch.Tensor]:
         aug_emb = None
         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
@@ -979,7 +1003,7 @@ def get_aug_embed(self, encoder_hidden_states, added_cond_kwargs, emb):
             aug_emb = self.add_embedding(image_embs, hint)
         return aug_emb
 
-    def process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs):
+    def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict) -> torch.Tensor:
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
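The `"text_proj"` branch shown here simply runs the text encoder states through `self.encoder_hid_proj`; in the library that projection is a plain linear layer from `encoder_hid_dim` to `cross_attention_dim` (an assumption from the constructor, not visible in this hunk). A shape-level sketch:

```python
import torch
from torch import nn

encoder_hid_dim, cross_attention_dim = 1024, 768
# Assumed form of the "text_proj" projection: a single learned linear map.
encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)

states = torch.randn(2, 77, encoder_hid_dim)
print(encoder_hid_proj(states).shape)  # torch.Size([2, 77, 768])
```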
@@ -1121,18 +1145,20 @@ def forward(
             sample = 2 * sample - 1.0
 
         # 1. time
-        t_emb = self.get_time_embed(sample, timestep)
+        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
         emb = self.time_embedding(t_emb, timestep_cond)
         aug_emb = None
 
-        class_emb = self.get_class_embed(sample, class_labels)
+        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
         if class_emb is not None:
             if self.config.class_embeddings_concat:
                 emb = torch.cat([emb, class_emb], dim=-1)
             else:
                 emb = emb + class_emb
 
-        aug_emb = self.get_aug_embed(encoder_hidden_states, added_cond_kwargs, emb)
+        aug_emb = self.get_aug_embed(
+            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+        )
         if self.config.addition_embed_type == "image_hint":
             aug_emb, hint = aug_emb
             sample = torch.cat([sample, hint], dim=1)
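These context lines show the two ways `class_emb` merges into the conditioning vector, selected by the `class_embeddings_concat` config flag. A quick illustration of the shape consequences:

```python
import torch

emb = torch.randn(2, 1280)        # time embedding
class_emb = torch.randn(2, 1280)  # class embedding, same width

# class_embeddings_concat=True: feature dims add up, so downstream layers
# must expect a wider conditioning vector.
print(torch.cat([emb, class_emb], dim=-1).shape)  # torch.Size([2, 2560])

# class_embeddings_concat=False (the default): widths must match exactly.
print((emb + class_emb).shape)                    # torch.Size([2, 1280])
```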
@@ -1141,7 +1167,9 @@ def forward(
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
-        encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
+        encoder_hidden_states = self.process_encoder_hidden_states(
+            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+        )
 
         # 2. pre-process
         sample = self.conv_in(sample)
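None of these changes touch the public `forward` contract, so existing call sites keep working. A smoke test against the released class, assuming a diffusers version that includes this refactor (the tiny config below is chosen for speed, not quality):

```python
import torch
from diffusers import UNet2DConditionModel

# Minimal two-block UNet; all constructor arguments are standard
# UNet2DConditionModel parameters.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=32,
    norm_num_groups=8,
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```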