Commit 896d626

better input channels
1 parent 34e7349 commit 896d626

File tree

1 file changed: +12 -12 lines changed


src/diffusers/models/unet_i2vgen_xl.py

Lines changed: 12 additions & 12 deletions
@@ -98,7 +98,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     def __init__(
         self,
         sample_size: Optional[int] = None,
-        in_channels: int = 4,
+        in_channels: int = 8,
         out_channels: int = 4,
         down_block_types: Tuple[str, ...] = (
             "CrossAttnDownBlock3D",
@@ -161,7 +161,7 @@ def __init__(
         conv_out_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in = nn.Conv2d(
-            in_channels + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
         )

         self.transformer_in = TransformerTemporalModel(
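
Together with the new default above, the channel arithmetic for conv_in is unchanged: the old signature declared 4 channels for the noise latent alone and doubled that width at conv_in to make room for the concatenated first-frame latent, while the new one declares the concatenated width (4 + 4 = 8) up front. A minimal sketch of the bookkeeping, assuming the forward pass concatenates the image latent with the noisy latent along the channel dimension; the tensor names and shapes below are illustrative, not taken from this diff:

import torch

old_in_channels, new_in_channels = 4, 8
# conv_in sees the same number of input channels before and after this commit
assert old_in_channels + old_in_channels == new_in_channels == 8

# Illustrative 5D video latents: (batch, channels, frames, height, width)
noisy_latents = torch.randn(1, 4, 16, 32, 32)
image_latents = torch.randn(1, 4, 16, 32, 32)  # first-frame latent repeated across frames (assumption)
unet_input = torch.cat([noisy_latents, image_latents], dim=1)
assert unet_input.shape[1] == new_in_channels
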
@@ -174,28 +174,28 @@ def __init__(

         # image embedding
         self.local_image_concat = nn.Sequential(
-            nn.Conv2d(4, in_channels * 4, 3, padding=1),
+            nn.Conv2d(4, in_channels * 2, 3, padding=1),
             nn.SiLU(),
-            nn.Conv2d(in_channels * 4, in_channels * 4, 3, stride=1, padding=1),
+            nn.Conv2d(in_channels * 2, in_channels * 2, 3, stride=1, padding=1),
             nn.SiLU(),
-            nn.Conv2d(in_channels * 4, in_channels, 3, stride=1, padding=1),
+            nn.Conv2d(in_channels * 2, in_channels // 2, 3, stride=1, padding=1),
         )
         # print("local_image_concat parameters", sum(p.numel() for p in self.local_image_concat.parameters() if p.requires_grad))
         self.local_temporal_encoder = BasicTransformerBlock(
             norm_type="layer_norm_i2vgen",
-            dim=in_channels,
+            dim=in_channels // 2,
             num_attention_heads=2,
-            ff_inner_dim=in_channels * 4,
-            attention_head_dim=in_channels,
+            ff_inner_dim=in_channels * 2,
+            attention_head_dim=in_channels // 2,
             activation_fn="gelu",
         )
         self.local_image_embedding = nn.Sequential(
-            nn.Conv2d(4, in_channels * 8, 3, padding=1),
+            nn.Conv2d(4, in_channels * 4, 3, padding=1),
             nn.SiLU(),
             nn.AdaptiveAvgPool2d((32, 32)),
-            nn.Conv2d(in_channels * 8, in_channels * 16, 3, stride=2, padding=1),
+            nn.Conv2d(in_channels * 4, in_channels * 8, 3, stride=2, padding=1),
             nn.SiLU(),
-            nn.Conv2d(in_channels * 16, 1024, 3, stride=2, padding=1),
+            nn.Conv2d(in_channels * 8, cross_attention_dim, 3, stride=2, padding=1),
         )
         # print("local_image_embedding parameters", sum(p.numel() for p in self.local_image_embedding.parameters() if p.requires_grad))

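Because the default doubles from 4 to 8, rewriting the multipliers (* 4 to * 2, * 8 to * 4, a bare in_channels to in_channels // 2, and the hardcoded 1024 to cross_attention_dim) keeps every layer of the image-embedding stack at the same width as before; only the expressions change. A quick arithmetic check under the defaults, assuming cross_attention_dim is 1024 for this model (the value the last conv previously hardcoded):

old, new = 4, 8                      # in_channels default before / after this commit
cross_attention_dim = 1024           # assumed model default, replacing the hardcoded 1024

# local_image_concat: 4 -> hidden -> hidden -> out
assert old * 4 == new * 2 == 16      # hidden width
assert old * 1 == new // 2 == 4      # output width

# local_temporal_encoder
assert old * 1 == new // 2 == 4      # dim and attention_head_dim
assert old * 4 == new * 2 == 16      # ff_inner_dim

# local_image_embedding: 4 -> 32 -> 64 -> cross_attention_dim
assert old * 8 == new * 4 == 32
assert old * 16 == new * 8 == 64
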
@@ -212,7 +212,7 @@ def __init__(
         self.context_embedding = nn.Sequential(
             nn.Linear(cross_attention_dim, time_embed_dim),
             nn.SiLU(),
-            nn.Linear(time_embed_dim, cross_attention_dim * in_channels),
+            nn.Linear(time_embed_dim, cross_attention_dim * in_channels // 2),
         )
         # print("context_embedding parameters", sum(p.numel() for p in self.context_embedding.parameters() if p.requires_grad))
         self.fps_embedding = nn.Sequential(
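
The same cancellation holds for context_embedding: under the defaults the last linear layer still projects to 1024 * 4 = 4096 features. Note that cross_attention_dim * in_channels // 2 evaluates left to right, i.e. as (cross_attention_dim * in_channels) // 2. A small check, again assuming the 1024 default for cross_attention_dim:

old, new = 4, 8
cross_attention_dim = 1024  # assumed default

assert cross_attention_dim * old == cross_attention_dim * new // 2 == 4096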
