@@ -412,9 +412,12 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
-            self.downscale_ratio = 16
-            self.upscale_ratio = 16
+            ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+            self.latent_channels = 64
+            self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+            self.upscale_index_formula = (4, 16, 16)
+            self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+            self.downscale_index_formula = (4, 16, 16)
             self.latent_dim = 3
             self.not_video = True
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -684,8 +687,11 @@ def encode(self, pixel_samples):
        self.throw_exception_if_invalid()
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
-        if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5:
-            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+        if self.latent_dim == 3 and pixel_samples.ndim < 5:
+            if not self.not_video:
+                pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+            else:
+                pixel_samples = pixel_samples.unsqueeze(2)
        try:
            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -719,7 +725,10 @@ def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, ti
        dims = self.latent_dim
        pixel_samples = pixel_samples.movedim(-1, 1)
        if dims == 3:
-            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+            if not self.not_video:
+                pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
+            else:
+                pixel_samples = pixel_samples.unsqueeze(2)

        memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)