
Commit c0442f1

Hunyuan refiner VAE now works with tiled encoding/decoding. (comfyanonymous#9836)
Parent: 545642f

2 files changed: +15 −7 lines


comfy/ldm/hunyuan_video/vae_refiner.py

Lines changed: 0 additions & 1 deletion
@@ -185,7 +185,6 @@ def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
         self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
 
     def forward(self, x):
-        x = x.unsqueeze(2)
         x = self.conv_in(x)
 
         for stage in self.down:
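With the unsqueeze(2) removed from forward(), the refiner encoder no longer adds the time axis itself; the callers in comfy/sd.py (see the encode changes below) now hand it a 5D [B, C, T, H, W] tensor directly, so the regular and tiled encode paths share one layout. A minimal shape sketch (illustrative tensors only, not ComfyUI code):

    import torch

    # The caller now adds the length-1 time axis before invoking the encoder,
    # instead of the encoder unsqueezing internally.
    images = torch.randn(2, 3, 512, 512)  # [B, C, H, W] batch of images
    clips = images.unsqueeze(2)           # [B, C, T=1, H, W] one-frame "clips"
    print(clips.shape)                    # torch.Size([2, 3, 1, 512, 512])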

comfy/sd.py

Lines changed: 15 additions & 6 deletions
@@ -412,9 +412,12 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
             self.working_dtypes = [torch.bfloat16, torch.float32]
         elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
-            self.downscale_ratio = 16
-            self.upscale_ratio = 16
+            ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+            self.latent_channels = 64
+            self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+            self.upscale_index_formula = (4, 16, 16)
+            self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+            self.downscale_index_formula = (4, 16, 16)
             self.latent_dim = 3
             self.not_video = True
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
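The scalar 16x ratios become per-axis (T, H, W) tuples: the spatial axes still scale by 16, while the temporal axis follows ffactor_temporal = 4 with an offset, so t latent frames decode to max(0, 4t - 3) pixel frames and p pixel frames encode to floor((p + 3) / 4) latent frames; the *_index_formula tuples carry the raw (4, 16, 16) strides for tile index arithmetic, and latent_channels is now pinned to 64 instead of mirroring z_channels. A standalone sanity check of the two temporal mappings (sketch only, not ComfyUI code):

    import math

    up = lambda t: max(0, t * 4 - 3)                  # latent frames -> pixel frames
    down = lambda p: max(0, math.floor((p + 3) / 4))  # pixel frames -> latent frames

    for t in range(1, 6):
        p = up(t)
        assert down(p) == t  # round-trips exactly: 1<->1, 2<->5, 3<->9, ...
        print(f"{t} latent frame(s) <-> {p} pixel frame(s)")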
@@ -684,8 +687,11 @@ def encode(self, pixel_samples):
         self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
-        if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5:
-            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+        if self.latent_dim == 3 and pixel_samples.ndim < 5:
+            if not self.not_video:
+                pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+            else:
+                pixel_samples = pixel_samples.unsqueeze(2)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
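encode() previously applied the video reshape only when not_video was false, so a 3D-latent image model like this refiner fell through and handed a 4D tensor to an encoder that expects 5D. Both branches now produce 5D input: a video model folds its batch of frames into a single clip, while a not_video model gives each image its own length-1 time axis. Illustrative shapes (assumed batch size, torch only):

    import torch

    pixels = torch.randn(8, 3, 256, 256)  # [B, C, H, W], after movedim(-1, 1)

    # Video model: the batch dim is a stack of frames from one clip.
    video_in = pixels.movedim(1, 0).unsqueeze(0)  # -> [1, 3, 8, 256, 256]

    # not_video model: 8 independent images, each a one-frame clip.
    image_in = pixels.unsqueeze(2)                # -> [8, 3, 1, 256, 256]

    print(video_in.shape, image_in.shape)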
@@ -719,7 +725,10 @@ def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         dims = self.latent_dim
         pixel_samples = pixel_samples.movedim(-1, 1)
         if dims == 3:
-            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+            if not self.not_video:
+                pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+            else:
+                pixel_samples = pixel_samples.unsqueeze(2)
 
         memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
         model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
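encode_tiled() gets the same branch, which is the fix behind the commit title: the tiled path used to apply the video reshape unconditionally, and that clashed with the unsqueeze(2) the refiner's forward() previously did internally. Moving the reshape out of forward() and into both callers means the encoder always sees the same 5D layout, whether the input arrives whole or in tiles.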
