@@ -98,7 +98,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     def __init__(
         self,
         sample_size: Optional[int] = None,
-        in_channels: int = 4,
+        in_channels: int = 8,
         out_channels: int = 4,
         down_block_types: Tuple[str, ...] = (
             "CrossAttnDownBlock3D",
@@ -161,7 +161,7 @@ def __init__(
         conv_out_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in = nn.Conv2d(
-            in_channels + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
         )

         self.transformer_in = TransformerTemporalModel(
@@ -174,28 +174,28 @@ def __init__(

         # image embedding
         self.local_image_concat = nn.Sequential(
-            nn.Conv2d(4, in_channels * 4, 3, padding=1),
+            nn.Conv2d(4, in_channels * 2, 3, padding=1),
             nn.SiLU(),
-            nn.Conv2d(in_channels * 4, in_channels * 4, 3, stride=1, padding=1),
+            nn.Conv2d(in_channels * 2, in_channels * 2, 3, stride=1, padding=1),
             nn.SiLU(),
-            nn.Conv2d(in_channels * 4, in_channels, 3, stride=1, padding=1),
+            nn.Conv2d(in_channels * 2, in_channels // 2, 3, stride=1, padding=1),
         )
         # print("local_image_concat parameters", sum(p.numel() for p in self.local_image_concat.parameters() if p.requires_grad))
         self.local_temporal_encoder = BasicTransformerBlock(
             norm_type="layer_norm_i2vgen",
-            dim=in_channels,
+            dim=in_channels // 2,
             num_attention_heads=2,
-            ff_inner_dim=in_channels * 4,
-            attention_head_dim=in_channels,
+            ff_inner_dim=in_channels * 2,
+            attention_head_dim=in_channels // 2,
             activation_fn="gelu",
         )
         self.local_image_embedding = nn.Sequential(
-            nn.Conv2d(4, in_channels * 8, 3, padding=1),
+            nn.Conv2d(4, in_channels * 4, 3, padding=1),
             nn.SiLU(),
             nn.AdaptiveAvgPool2d((32, 32)),
-            nn.Conv2d(in_channels * 8, in_channels * 16, 3, stride=2, padding=1),
+            nn.Conv2d(in_channels * 4, in_channels * 8, 3, stride=2, padding=1),
             nn.SiLU(),
-            nn.Conv2d(in_channels * 16, 1024, 3, stride=2, padding=1),
+            nn.Conv2d(in_channels * 8, cross_attention_dim, 3, stride=2, padding=1),
         )
         # print("local_image_embedding parameters", sum(p.numel() for p in self.local_image_embedding.parameters() if p.requires_grad))
@@ -212,7 +212,7 @@ def __init__(
         self.context_embedding = nn.Sequential(
             nn.Linear(cross_attention_dim, time_embed_dim),
             nn.SiLU(),
-            nn.Linear(time_embed_dim, cross_attention_dim * in_channels),
+            nn.Linear(time_embed_dim, cross_attention_dim * in_channels // 2),
         )
         # print("context_embedding parameters", sum(p.numel() for p in self.context_embedding.parameters() if p.requires_grad))
         self.fps_embedding = nn.Sequential(