Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 73 additions & 39 deletions comfy/ldm/hunyuan_video/vae_refiner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
Expand All @@ -17,11 +17,12 @@ def forward(self, x):
return F.normalize(x, dim=1) * self.scale * self.gamma

class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae

self.tds = tds
self.gs = fct * ic // oc
Expand All @@ -30,7 +31,7 @@ def forward(self, x):
r1 = 2 if self.tds else 1
h = self.conv(x)

if self.tds:
if self.tds and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
Expand Down Expand Up @@ -66,6 +67,7 @@ def forward(self, x):
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape

nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
Expand All @@ -83,10 +85,11 @@ def forward(self, x):


class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
self.refiner_vae = refiner_vae

self.tus = tus
self.rp = fct * oc // ic
Expand All @@ -95,7 +98,7 @@ def forward(self, x):
r1 = 2 if self.tus else 1
h = self.conv(x)

if self.tus:
if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
Expand Down Expand Up @@ -148,43 +151,56 @@ def forward(self, x):

class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
self.ffactor_temporal = ffactor_temporal

self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize

self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)

self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()

for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.down.append(stage)

self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)

self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

def forward(self, x):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)

x = self.conv_in(x)

for stage in self.down:
Expand All @@ -200,31 +216,42 @@ def forward(self, x):
skip = x.view(b, c // grp, grp, t, h, w).mean(2)

out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]

out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
if self.refiner_vae:
out = self.regul(out)[0]

out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()

return out

class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks

self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
norm_op = Normalize

ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, 3)
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
Expand All @@ -235,25 +262,26 @@ def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
ch = nxt
self.up.append(stage)

self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

def forward(self, z):
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]

x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
Expand All @@ -264,4 +292,10 @@ def forward(self, z):
if hasattr(stage, 'upsample'):
x = stage.upsample(x)

return self.conv_out(F.silu(self.norm_out(x)))
out = self.conv_out(F.silu(self.norm_out(x)))

if not self.refiner_vae:
if z.shape[-3] == 1:
out = out[:, :, -1:]

return out
68 changes: 42 additions & 26 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,35 +332,51 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})

self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)

elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}

if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})

self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})

self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}

if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4

self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
Expand Down