Commit 47e46ca

Support the new hunyuan vae. (comfyanonymous#10150)
1 parent cfcffeb commit 47e46ca

2 files changed: +115 −65 lines changed

comfy/ldm/hunyuan_video/vae_refiner.py

Lines changed: 73 additions & 39 deletions
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
 import comfy.ops
 import comfy.ldm.models.autoencoder
 ops = comfy.ops.disable_weight_init
@@ -17,11 +17,12 @@ def forward(self, x):
         return F.normalize(x, dim=1) * self.scale * self.gamma

 class DnSmpl(nn.Module):
-    def __init__(self, ic, oc, tds=True):
+    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
         super().__init__()
         fct = 2 * 2 * 2 if tds else 1 * 2 * 2
         assert oc % fct == 0
-        self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
+        self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
+        self.refiner_vae = refiner_vae

         self.tds = tds
         self.gs = fct * ic // oc
@@ -30,7 +31,7 @@ def forward(self, x):
         r1 = 2 if self.tds else 1
         h = self.conv(x)

-        if self.tds:
+        if self.tds and self.refiner_vae:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -66,6 +67,7 @@ def forward(self, x):
             sc = torch.cat([xf, xn], dim=2)
         else:
             b, c, frms, ht, wd = h.shape
+
             nf = frms // r1
             h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
             h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@@ -83,10 +85,11 @@ def forward(self, x):


 class UpSmpl(nn.Module):
-    def __init__(self, ic, oc, tus=True):
+    def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
         super().__init__()
         fct = 2 * 2 * 2 if tus else 1 * 2 * 2
-        self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
+        self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
+        self.refiner_vae = refiner_vae

         self.tus = tus
         self.rp = fct * oc // ic
@@ -95,7 +98,7 @@ def forward(self, x):
         r1 = 2 if self.tus else 1
         h = self.conv(x)

-        if self.tus:
+        if self.tus and self.refiner_vae:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             nc = c // (2 * 2)
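Note: with refiner_vae=False, DnSmpl and UpSmpl skip the first-frame special case above and always take the generic rearrangement branch. A minimal sketch of that space(/time)-to-channel fold in DnSmpl, with toy shapes and illustrative names (not part of the patch):

import torch

b, c, frms, ht, wd = 1, 4, 8, 16, 16
r1 = 2                                   # temporal factor; 1 when tds is False
h = torch.randn(b, c, frms, ht, wd)      # stands in for self.conv(x)

nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)    # move the r1 x 2 x 2 factors ahead of c
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
print(h.shape)                           # torch.Size([1, 32, 4, 8, 8])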
@@ -148,43 +151,56 @@ def forward(self, x):

 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
-                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
+                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
         super().__init__()
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
         self.num_res_blocks = num_res_blocks
-        self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
+        self.ffactor_temporal = ffactor_temporal
+
+        self.refiner_vae = refiner_vae
+        if self.refiner_vae:
+            conv_op = VideoConv3d
+            norm_op = RMS_norm
+        else:
+            conv_op = ops.Conv3d
+            norm_op = Normalize
+
+        self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)

         self.down = nn.ModuleList()
         ch = block_out_channels[0]
         depth = (ffactor_spatial >> 1).bit_length()
-        depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
+        depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()

         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
             stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
                                                      temb_channels=0,
-                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                                     conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
-                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
+                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                 ch = nxt
             self.down.append(stage)

         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
-        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

-        self.norm_out = RMS_norm(ch)
-        self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
+        self.norm_out = norm_op(ch)
+        self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)

         self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

     def forward(self, x):
+        if not self.refiner_vae and x.shape[2] == 1:
+            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
+
         x = self.conv_in(x)

         for stage in self.down:
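Note: the lines added at the top of Encoder.forward are what let the non-refiner VAE encode a single still image: a lone frame is broadcast to ffactor_temporal frames before conv_in. A hedged standalone illustration with toy shapes (not part of the patch):

import torch

ffactor_temporal = 4
x = torch.randn(1, 3, 1, 256, 256)                  # (batch, rgb, frames=1, h, w)
if x.shape[2] == 1:
    x = x.expand(-1, -1, ffactor_temporal, -1, -1)  # broadcast view, no copy
print(x.shape)                                      # torch.Size([1, 3, 4, 256, 256])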
@@ -200,31 +216,42 @@ def forward(self, x):
         skip = x.view(b, c // grp, grp, t, h, w).mean(2)

         out = self.conv_out(F.silu(self.norm_out(x))) + skip
-        out = self.regul(out)[0]

-        out = torch.cat((out[:, :, :1], out), dim=2)
-        out = out.permute(0, 2, 1, 3, 4)
-        b, f_times_2, c, h, w = out.shape
-        out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
-        out = out.permute(0, 2, 1, 3, 4).contiguous()
+        if self.refiner_vae:
+            out = self.regul(out)[0]
+
+            out = torch.cat((out[:, :, :1], out), dim=2)
+            out = out.permute(0, 2, 1, 3, 4)
+            b, f_times_2, c, h, w = out.shape
+            out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+            out = out.permute(0, 2, 1, 3, 4).contiguous()
+
         return out

 class Decoder(nn.Module):
     def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
-                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
+                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
         super().__init__()
         block_out_channels = block_out_channels[::-1]
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
         self.num_res_blocks = num_res_blocks

+        self.refiner_vae = refiner_vae
+        if self.refiner_vae:
+            conv_op = VideoConv3d
+            norm_op = RMS_norm
+        else:
+            conv_op = ops.Conv3d
+            norm_op = Normalize
+
         ch = block_out_channels[0]
-        self.conv_in = VideoConv3d(z_channels, ch, 3)
+        self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
-        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
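Note: the Encoder.forward block now gated behind refiner_vae duplicates the first latent frame and folds consecutive frame pairs into the channel dim, so f frames become (f + 1) / 2 frames with doubled channels. A toy trace of those exact ops (illustrative shapes only):

import torch

b, c, f, h, w = 1, 32, 3, 8, 8
out = torch.randn(b, c, f, h, w)

out = torch.cat((out[:, :, :1], out), dim=2)     # f -> f + 1 (here 4)
out = out.permute(0, 2, 1, 3, 4)                 # (b, f + 1, c, h, w)
b_, f2, c_, h_, w_ = out.shape
out = out.reshape(b_, f2 // 2, 2 * c_, h_, w_)   # fold frame pairs into channels
out = out.permute(0, 2, 1, 3, 4).contiguous()
print(out.shape)                                 # torch.Size([1, 64, 2, 8, 8])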
@@ -235,25 +262,26 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
             stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
                                                      temb_channels=0,
-                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                                     conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks + 1)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
-                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
+                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                 ch = nxt
             self.up.append(stage)

-        self.norm_out = RMS_norm(ch)
-        self.conv_out = VideoConv3d(ch, out_channels, 3)
+        self.norm_out = norm_op(ch)
+        self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

     def forward(self, z):
-        z = z.permute(0, 2, 1, 3, 4)
-        b, f, c, h, w = z.shape
-        z = z.reshape(b, f, 2, c // 2, h, w)
-        z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
-        z = z.permute(0, 2, 1, 3, 4)
-        z = z[:, :, 1:]
+        if self.refiner_vae:
+            z = z.permute(0, 2, 1, 3, 4)
+            b, f, c, h, w = z.shape
+            z = z.reshape(b, f, 2, c // 2, h, w)
+            z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+            z = z.permute(0, 2, 1, 3, 4)
+            z = z[:, :, 1:]

         x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
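Note: the newly gated block at the top of Decoder.forward is the inverse of the encoder's frame-pair fold: each latent frame is split back into two frames of half the channels and the duplicated first frame is dropped. A toy trace (the no-op permute(0, 1, 2, 3, 4, 5) from the original is omitted; shapes illustrative only):

import torch

b, c, f, h, w = 1, 64, 2, 8, 8
z = torch.randn(b, c, f, h, w)

z = z.permute(0, 2, 1, 3, 4)          # (b, f, c, h, w)
z = z.reshape(b, f, 2, c // 2, h, w)  # split channels into 2 frames each
z = z.reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]                       # drop the duplicated first frame
print(z.shape)                        # torch.Size([1, 32, 3, 8, 8])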
@@ -264,4 +292,10 @@ def forward(self, z):
             if hasattr(stage, 'upsample'):
                 x = stage.upsample(x)

-        return self.conv_out(F.silu(self.norm_out(x)))
+        out = self.conv_out(F.silu(self.norm_out(x)))
+
+        if not self.refiner_vae:
+            if z.shape[-3] == 1:
+                out = out[:, :, -1:]
+
+        return out
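Taken together, the refiner_vae=False flag turns the refiner architecture into a plain image VAE: one frame in, one frame out. A hedged end-to-end sketch, assuming a ComfyUI checkout on PYTHONPATH and reusing the ddconfig that comfy/sd.py (below) builds for 32-channel checkpoints; the shapes in comments are inferred from the code above, not verified output:

import torch
from comfy.ldm.hunyuan_video.vae_refiner import Encoder, Decoder

ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3,
            "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16,
            "ffactor_temporal": 4, "downsample_match_channel": True,
            "upsample_match_channel": True, "refiner_vae": False, "z_channels": 32}

enc = Encoder(**ddconfig).eval()
dec = Decoder(**ddconfig).eval()

with torch.no_grad():
    img = torch.randn(1, 3, 1, 64, 64)  # one frame; forward expands it to 4
    z = enc(img)                        # (1, 64, 1, 4, 4): mean/logvar stacked, 16x spatial
    mean, _ = z.chunk(2, dim=1)         # AutoencodingEngine's regularizer samples here
    rec = dec(mean)                     # (1, 3, 1, 64, 64): last frame cropped back out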

comfy/sd.py

Lines changed: 42 additions & 26 deletions
@@ -332,35 +332,51 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
-        elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
-            ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
-            self.downscale_ratio = 32
-            self.upscale_ratio = 32
-            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
-            self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
-                                                        encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
-                                                        decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
-
-            self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
-
         elif "decoder.conv_in.weight" in sd:
-            #default SD1.x/SD2.x VAE parameters
-            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-
-            if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
-                ddconfig['ch_mult'] = [1, 2, 4]
-                self.downscale_ratio = 4
-                self.upscale_ratio = 4
+            if sd['decoder.conv_in.weight'].shape[1] == 64:
+                ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.downscale_ratio = 32
+                self.upscale_ratio = 32
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
+                                                            decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
+
+                self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
+            elif sd['decoder.conv_in.weight'].shape[1] == 32:
+                ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+                self.upscale_index_formula = (4, 16, 16)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+                self.downscale_index_formula = (4, 16, 16)
+                self.latent_dim = 3
+                self.not_video = True
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+                                                            decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})

-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
-            if 'post_quant_conv.weight' in sd:
-                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+                self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
             else:
-                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
-                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
-                                                            decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
+                #default SD1.x/SD2.x VAE parameters
+                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+
+                if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+                    ddconfig['ch_mult'] = [1, 2, 4]
+                    self.downscale_ratio = 4
+                    self.upscale_ratio = 4
+
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                if 'post_quant_conv.weight' in sd:
+                    self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+                else:
+                    self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                                encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
+                                                                decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
         elif "decoder.layers.1.layers.0.beta" in sd:
             self.first_stage_model = AudioOobleckVAE()
             self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
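Note: the (lambda, 16, 16) ratio tuples added in the 32-channel branch encode the refiner VAE's asymmetric temporal mapping: t latent frames decode to 4t − 3 pixel frames, 16x spatially in each direction. A quick standalone check of the round trip (pure math, no comfy imports needed):

import math

up = lambda a: max(0, a * 4 - 3)                  # latent frames -> pixel frames
down = lambda a: max(0, math.floor((a + 3) / 4))  # pixel frames -> latent frames

for pixel_frames in (1, 4, 5, 8, 9):
    lat = down(pixel_frames)
    print(pixel_frames, "->", lat, "->", up(lat))
# prints: 1 -> 1 -> 1, 4 -> 1 -> 1, 5 -> 2 -> 5, 8 -> 2 -> 5, 9 -> 3 -> 9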
