11import torch
22import torch .nn as nn
33import torch .nn .functional as F
4- from comfy .ldm .modules .diffusionmodules .model import ResnetBlock , AttnBlock , VideoConv3d
4+ from comfy .ldm .modules .diffusionmodules .model import ResnetBlock , AttnBlock , VideoConv3d , Normalize
55import comfy .ops
66import comfy .ldm .models .autoencoder
77ops = comfy .ops .disable_weight_init
@@ -17,11 +17,12 @@ def forward(self, x):
1717 return F .normalize (x , dim = 1 ) * self .scale * self .gamma
1818
1919class DnSmpl (nn .Module ):
20- def __init__ (self , ic , oc , tds = True ):
20+ def __init__ (self , ic , oc , tds = True , refiner_vae = True , op = VideoConv3d ):
2121 super ().__init__ ()
2222 fct = 2 * 2 * 2 if tds else 1 * 2 * 2
2323 assert oc % fct == 0
24- self .conv = VideoConv3d (ic , oc // fct , kernel_size = 3 )
24+ self .conv = op (ic , oc // fct , kernel_size = 3 , stride = 1 , padding = 1 )
25+ self .refiner_vae = refiner_vae
2526
2627 self .tds = tds
2728 self .gs = fct * ic // oc
@@ -30,7 +31,7 @@ def forward(self, x):
3031 r1 = 2 if self .tds else 1
3132 h = self .conv (x )
3233
33- if self .tds :
34+ if self .tds and self . refiner_vae :
3435 hf = h [:, :, :1 , :, :]
3536 b , c , f , ht , wd = hf .shape
3637 hf = hf .reshape (b , c , f , ht // 2 , 2 , wd // 2 , 2 )
@@ -66,6 +67,7 @@ def forward(self, x):
6667 sc = torch .cat ([xf , xn ], dim = 2 )
6768 else :
6869 b , c , frms , ht , wd = h .shape
70+
6971 nf = frms // r1
7072 h = h .reshape (b , c , nf , r1 , ht // 2 , 2 , wd // 2 , 2 )
7173 h = h .permute (0 , 3 , 5 , 7 , 1 , 2 , 4 , 6 )
@@ -83,10 +85,11 @@ def forward(self, x):
8385
8486
8587class UpSmpl (nn .Module ):
86- def __init__ (self , ic , oc , tus = True ):
88+ def __init__ (self , ic , oc , tus = True , refiner_vae = True , op = VideoConv3d ):
8789 super ().__init__ ()
8890 fct = 2 * 2 * 2 if tus else 1 * 2 * 2
89- self .conv = VideoConv3d (ic , oc * fct , kernel_size = 3 )
91+ self .conv = op (ic , oc * fct , kernel_size = 3 , stride = 1 , padding = 1 )
92+ self .refiner_vae = refiner_vae
9093
9194 self .tus = tus
9295 self .rp = fct * oc // ic
@@ -95,7 +98,7 @@ def forward(self, x):
9598 r1 = 2 if self .tus else 1
9699 h = self .conv (x )
97100
98- if self .tus :
101+ if self .tus and self . refiner_vae :
99102 hf = h [:, :, :1 , :, :]
100103 b , c , f , ht , wd = hf .shape
101104 nc = c // (2 * 2 )
@@ -148,43 +151,56 @@ def forward(self, x):
148151
149152class Encoder (nn .Module ):
150153 def __init__ (self , in_channels , z_channels , block_out_channels , num_res_blocks ,
151- ffactor_spatial , ffactor_temporal , downsample_match_channel = True , ** _ ):
154+ ffactor_spatial , ffactor_temporal , downsample_match_channel = True , refiner_vae = True , ** _ ):
152155 super ().__init__ ()
153156 self .z_channels = z_channels
154157 self .block_out_channels = block_out_channels
155158 self .num_res_blocks = num_res_blocks
156- self .conv_in = VideoConv3d (in_channels , block_out_channels [0 ], 3 , 1 , 1 )
159+ self .ffactor_temporal = ffactor_temporal
160+
161+ self .refiner_vae = refiner_vae
162+ if self .refiner_vae :
163+ conv_op = VideoConv3d
164+ norm_op = RMS_norm
165+ else :
166+ conv_op = ops .Conv3d
167+ norm_op = Normalize
168+
169+ self .conv_in = conv_op (in_channels , block_out_channels [0 ], 3 , 1 , 1 )
157170
158171 self .down = nn .ModuleList ()
159172 ch = block_out_channels [0 ]
160173 depth = (ffactor_spatial >> 1 ).bit_length ()
161- depth_temporal = ((ffactor_spatial // ffactor_temporal ) >> 1 ).bit_length ()
174+ depth_temporal = ((ffactor_spatial // self . ffactor_temporal ) >> 1 ).bit_length ()
162175
163176 for i , tgt in enumerate (block_out_channels ):
164177 stage = nn .Module ()
165178 stage .block = nn .ModuleList ([ResnetBlock (in_channels = ch if j == 0 else tgt ,
166179 out_channels = tgt ,
167180 temb_channels = 0 ,
168- conv_op = VideoConv3d , norm_op = RMS_norm )
181+ conv_op = conv_op , norm_op = norm_op )
169182 for j in range (num_res_blocks )])
170183 ch = tgt
171184 if i < depth :
172185 nxt = block_out_channels [i + 1 ] if i + 1 < len (block_out_channels ) and downsample_match_channel else ch
173- stage .downsample = DnSmpl (ch , nxt , tds = i >= depth_temporal )
186+ stage .downsample = DnSmpl (ch , nxt , tds = i >= depth_temporal , refiner_vae = self . refiner_vae , op = conv_op )
174187 ch = nxt
175188 self .down .append (stage )
176189
177190 self .mid = nn .Module ()
178- self .mid .block_1 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = VideoConv3d , norm_op = RMS_norm )
179- self .mid .attn_1 = AttnBlock (ch , conv_op = ops .Conv3d , norm_op = RMS_norm )
180- self .mid .block_2 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = VideoConv3d , norm_op = RMS_norm )
191+ self .mid .block_1 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = conv_op , norm_op = norm_op )
192+ self .mid .attn_1 = AttnBlock (ch , conv_op = ops .Conv3d , norm_op = norm_op )
193+ self .mid .block_2 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = conv_op , norm_op = norm_op )
181194
182- self .norm_out = RMS_norm (ch )
183- self .conv_out = VideoConv3d (ch , z_channels << 1 , 3 , 1 , 1 )
195+ self .norm_out = norm_op (ch )
196+ self .conv_out = conv_op (ch , z_channels << 1 , 3 , 1 , 1 )
184197
185198 self .regul = comfy .ldm .models .autoencoder .DiagonalGaussianRegularizer ()
186199
187200 def forward (self , x ):
201+ if not self .refiner_vae and x .shape [2 ] == 1 :
202+ x = x .expand (- 1 , - 1 , self .ffactor_temporal , - 1 , - 1 )
203+
188204 x = self .conv_in (x )
189205
190206 for stage in self .down :
@@ -200,31 +216,42 @@ def forward(self, x):
200216 skip = x .view (b , c // grp , grp , t , h , w ).mean (2 )
201217
202218 out = self .conv_out (F .silu (self .norm_out (x ))) + skip
203- out = self .regul (out )[0 ]
204219
205- out = torch .cat ((out [:, :, :1 ], out ), dim = 2 )
206- out = out .permute (0 , 2 , 1 , 3 , 4 )
207- b , f_times_2 , c , h , w = out .shape
208- out = out .reshape (b , f_times_2 // 2 , 2 * c , h , w )
209- out = out .permute (0 , 2 , 1 , 3 , 4 ).contiguous ()
220+ if self .refiner_vae :
221+ out = self .regul (out )[0 ]
222+
223+ out = torch .cat ((out [:, :, :1 ], out ), dim = 2 )
224+ out = out .permute (0 , 2 , 1 , 3 , 4 )
225+ b , f_times_2 , c , h , w = out .shape
226+ out = out .reshape (b , f_times_2 // 2 , 2 * c , h , w )
227+ out = out .permute (0 , 2 , 1 , 3 , 4 ).contiguous ()
228+
210229 return out
211230
212231class Decoder (nn .Module ):
213232 def __init__ (self , z_channels , out_channels , block_out_channels , num_res_blocks ,
214- ffactor_spatial , ffactor_temporal , upsample_match_channel = True , ** _ ):
233+ ffactor_spatial , ffactor_temporal , upsample_match_channel = True , refiner_vae = True , ** _ ):
215234 super ().__init__ ()
216235 block_out_channels = block_out_channels [::- 1 ]
217236 self .z_channels = z_channels
218237 self .block_out_channels = block_out_channels
219238 self .num_res_blocks = num_res_blocks
220239
240+ self .refiner_vae = refiner_vae
241+ if self .refiner_vae :
242+ conv_op = VideoConv3d
243+ norm_op = RMS_norm
244+ else :
245+ conv_op = ops .Conv3d
246+ norm_op = Normalize
247+
221248 ch = block_out_channels [0 ]
222- self .conv_in = VideoConv3d (z_channels , ch , 3 )
249+ self .conv_in = conv_op (z_channels , ch , kernel_size = 3 , stride = 1 , padding = 1 )
223250
224251 self .mid = nn .Module ()
225- self .mid .block_1 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = VideoConv3d , norm_op = RMS_norm )
226- self .mid .attn_1 = AttnBlock (ch , conv_op = ops .Conv3d , norm_op = RMS_norm )
227- self .mid .block_2 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = VideoConv3d , norm_op = RMS_norm )
252+ self .mid .block_1 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = conv_op , norm_op = norm_op )
253+ self .mid .attn_1 = AttnBlock (ch , conv_op = ops .Conv3d , norm_op = norm_op )
254+ self .mid .block_2 = ResnetBlock (in_channels = ch , out_channels = ch , temb_channels = 0 , conv_op = conv_op , norm_op = norm_op )
228255
229256 self .up = nn .ModuleList ()
230257 depth = (ffactor_spatial >> 1 ).bit_length ()
@@ -235,25 +262,26 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
235262 stage .block = nn .ModuleList ([ResnetBlock (in_channels = ch if j == 0 else tgt ,
236263 out_channels = tgt ,
237264 temb_channels = 0 ,
238- conv_op = VideoConv3d , norm_op = RMS_norm )
265+ conv_op = conv_op , norm_op = norm_op )
239266 for j in range (num_res_blocks + 1 )])
240267 ch = tgt
241268 if i < depth :
242269 nxt = block_out_channels [i + 1 ] if i + 1 < len (block_out_channels ) and upsample_match_channel else ch
243- stage .upsample = UpSmpl (ch , nxt , tus = i < depth_temporal )
270+ stage .upsample = UpSmpl (ch , nxt , tus = i < depth_temporal , refiner_vae = self . refiner_vae , op = conv_op )
244271 ch = nxt
245272 self .up .append (stage )
246273
247- self .norm_out = RMS_norm (ch )
248- self .conv_out = VideoConv3d (ch , out_channels , 3 )
274+ self .norm_out = norm_op (ch )
275+ self .conv_out = conv_op (ch , out_channels , 3 , stride = 1 , padding = 1 )
249276
250277 def forward (self , z ):
251- z = z .permute (0 , 2 , 1 , 3 , 4 )
252- b , f , c , h , w = z .shape
253- z = z .reshape (b , f , 2 , c // 2 , h , w )
254- z = z .permute (0 , 1 , 2 , 3 , 4 , 5 ).reshape (b , f * 2 , c // 2 , h , w )
255- z = z .permute (0 , 2 , 1 , 3 , 4 )
256- z = z [:, :, 1 :]
278+ if self .refiner_vae :
279+ z = z .permute (0 , 2 , 1 , 3 , 4 )
280+ b , f , c , h , w = z .shape
281+ z = z .reshape (b , f , 2 , c // 2 , h , w )
282+ z = z .permute (0 , 1 , 2 , 3 , 4 , 5 ).reshape (b , f * 2 , c // 2 , h , w )
283+ z = z .permute (0 , 2 , 1 , 3 , 4 )
284+ z = z [:, :, 1 :]
257285
258286 x = self .conv_in (z ) + z .repeat_interleave (self .block_out_channels [0 ] // self .z_channels , 1 )
259287 x = self .mid .block_2 (self .mid .attn_1 (self .mid .block_1 (x )))
@@ -264,4 +292,10 @@ def forward(self, z):
264292 if hasattr (stage , 'upsample' ):
265293 x = stage .upsample (x )
266294
267- return self .conv_out (F .silu (self .norm_out (x )))
295+ out = self .conv_out (F .silu (self .norm_out (x )))
296+
297+ if not self .refiner_vae :
298+ if z .shape [- 3 ] == 1 :
299+ out = out [:, :, - 1 :]
300+
301+ return out
0 commit comments