@@ -1220,14 +1220,37 @@ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False
12201220
def extra_conds(self, **kwargs):
    """Extend the parent's extra conds with audio and reference-latent entries.

    Reads ``noise``, ``audio_embed`` and ``reference_latents`` from *kwargs*.

    * ``audio_embed`` (when given) is stored under ``'audio_embed'``.
    * If the parent did not produce ``'c_concat'``, the last reference
      latent (when given) is stored, after ``process_latent_in``, under
      ``'reference_latent'``.
    * If the parent did produce ``'c_concat'``, it is replaced by a tensor
      shaped like ``noise`` with 4 extra channels, and the reference latent
      (when given) is stored in a widened layout of ``2 * C + 4`` channels.

    Returns the resulting dict of conds.
    """
    conds = super().extra_conds(**kwargs)
    noise = kwargs.get("noise", None)
    audio = kwargs.get("audio_embed", None)
    references = kwargs.get("reference_latents", None)

    if audio is not None:
        conds['audio_embed'] = comfy.conds.CONDRegular(audio)

    if "c_concat" in conds:
        # Replace the parent's c_concat: same shape as the noise, plus 4
        # leading channels that stay zero.
        concat_shape = list(noise.shape)
        concat_shape[1] += 4
        concat = torch.zeros(concat_shape, device=noise.device, dtype=noise.dtype)

        # Per-channel fill constants (16 channels each).
        # NOTE(review): presumably the VAE latent of an all-zero input for the
        # first frame, the second frame and all later frames -- confirm.
        fill_first = torch.tensor([
            0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433,
            -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938,
        ]).view(1, 16, 1, 1, 1)
        fill_second = torch.tensor([
            1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659,
            -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583,
        ]).view(1, 16, 1, 1, 1)
        fill_rest = torch.tensor([
            0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065,
            -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745,
        ]).view(1, 16, 1, 1, 1)

        # Order matters: fill every position on dim 2 first, then overwrite
        # index 0 and index 1 of that dim with their own constants.
        concat[:, 4:] = fill_rest
        concat[:, 4:, :1] = fill_first
        concat[:, 4:, 1:2] = fill_second
        conds['c_concat'] = comfy.conds.CONDNoiseShape(concat)

        if references is not None:
            reference = self.process_latent_in(references[-1])
            # Widen the channel dim from C to 2 * C + 4.
            widened_shape = list(reference.shape)
            widened_shape[1] = 2 * widened_shape[1] + 4
            widened = torch.zeros(widened_shape, device=reference.device, dtype=reference.dtype)
            # Channels 0-15 stay zero, 16-19 are set to one, and the
            # reference latent occupies channel 20 onwards.
            widened[:, 20:] = reference
            widened[:, 16:20] = 1.0
            conds['reference_latent'] = comfy.conds.CONDRegular(widened)
    elif references is not None:  # 1.7B model
        conds['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(references[-1]))

    return conds
12321255
12331256class WAN22_S2V (WAN21 ):
0 commit comments