Commit dd611a7

Support the HuMo 17B model. (#9912)
1 parent 9288c78 commit dd611a7

2 files changed: +27 -4 lines changed

comfy/ldm/wan/model.py

Lines changed: 1 addition & 1 deletion
@@ -1364,7 +1364,7 @@ class AudioCrossAttentionWrapper(nn.Module):
     def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
         super().__init__()
 
-        self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm, kv_dim, eps, operation_settings=operation_settings)
+        self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, kv_dim=kv_dim, eps=eps, operation_settings=operation_settings)
         self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
     def forward(self, x, audio, transformer_options={}):
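
This one-line fix passes qk_norm, kv_dim and eps to WanT2VCrossAttentionGather as keyword arguments instead of positionally. If the Gather class declares those parameters in a different order than the wrapper's old call listed them (for example with eps ahead of kv_dim), the positional form silently binds values to the wrong slots; keywords make the binding explicit. A minimal sketch of that failure mode, using a hypothetical toy class rather than the real ComfyUI signature:

# Hypothetical toy constructor whose parameter order differs from the call
# site; this is NOT the actual WanT2VCrossAttentionGather signature.
class ToyCrossAttention:
    def __init__(self, dim, num_heads, qk_norm=True, eps=1e-6, kv_dim=None):
        self.dim = dim
        self.num_heads = num_heads
        self.qk_norm = qk_norm
        self.eps = eps
        self.kv_dim = kv_dim

# Old style: qk_norm, kv_dim and eps supplied by position, so kv_dim (1024)
# silently lands in the eps slot and eps lands in kv_dim.
broken = ToyCrossAttention(5120, 40, True, 1024, 1e-6)
print(broken.eps, broken.kv_dim)   # -> 1024 1e-06 (swapped)

# New style: each value reaches the intended parameter regardless of the
# declaration order.
fixed = ToyCrossAttention(5120, 40, qk_norm=True, kv_dim=1024, eps=1e-6)
print(fixed.eps, fixed.kv_dim)     # -> 1e-06 1024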

comfy/model_base.py

Lines changed: 26 additions & 3 deletions
@@ -1220,14 +1220,37 @@ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False
 
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
+        noise = kwargs.get("noise", None)
 
         audio_embed = kwargs.get("audio_embed", None)
         if audio_embed is not None:
             out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
 
-        reference_latents = kwargs.get("reference_latents", None)
-        if reference_latents is not None:
-            out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+        if "c_concat" not in out: # 1.7B model
+            reference_latents = kwargs.get("reference_latents", None)
+            if reference_latents is not None:
+                out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
+        else:
+            noise_shape = list(noise.shape)
+            noise_shape[1] += 4
+            concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
+            zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1)
+            zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1)
+            zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1)
+            concat_latent[:, 4:] = zero_vae_values
+            concat_latent[:, 4:, :1] = zero_vae_values_first
+            concat_latent[:, 4:, 1:2] = zero_vae_values_second
+            out['c_concat'] = comfy.conds.CONDNoiseShape(concat_latent)
+            reference_latents = kwargs.get("reference_latents", None)
+            if reference_latents is not None:
+                ref_latent = self.process_latent_in(reference_latents[-1])
+                ref_latent_shape = list(ref_latent.shape)
+                ref_latent_shape[1] += 4 + ref_latent_shape[1]
+                ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype)
+                ref_latent_full[:, 20:] = ref_latent
+                ref_latent_full[:, 16:20] = 1.0
+                out['reference_latent'] = comfy.conds.CONDRegular(ref_latent_full)
+
         return out
 
 class WAN22_S2V(WAN21):
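
The new else branch carries the 17B-specific conditioning. Per the inline comment, the branch taken when "c_concat" is missing keeps the old behaviour for the 1.7B model. Otherwise, out['c_concat'] is set to a 20-channel latent: four extra leading channels stay at zero and the 16 latent channels are filled with what appear to be precomputed "zero VAE" constants, with dedicated values for the first two frames. The reference latent, when present, is packed into a 36-channel tensor (16 zero channels, a four-channel block of ones, then the 16-channel reference). Below is a standalone sketch of that channel arithmetic, with assumed shapes (16 latent channels, small dummy sizes) and dummy values in place of the real constants:

# Sketch of the channel layout built in the 17B branch above; shapes and
# per-channel values here are placeholders, not the real VAE statistics.
import torch

B, C, T, H, W = 1, 16, 3, 2, 2
noise = torch.randn(B, C, T, H, W)

# c_concat: 4 extra leading channels (left at zero) plus the 16 latent
# channels filled with per-channel constants; frames 0 and 1 get their own.
concat = torch.zeros(B, C + 4, T, H, W)
zero_vae = torch.full((1, C, 1, 1, 1), 0.5)      # stand-in for zero_vae_values
zero_vae_f0 = torch.full((1, C, 1, 1, 1), 0.1)   # stand-in for ..._first
zero_vae_f1 = torch.full((1, C, 1, 1, 1), 0.2)   # stand-in for ..._second
concat[:, 4:] = zero_vae
concat[:, 4:, :1] = zero_vae_f0
concat[:, 4:, 1:2] = zero_vae_f1
assert concat.shape[1] == 20 and (concat[:, :4] == 0).all()

# reference latent: 16 zero channels, a 4-channel block of ones, then the
# 16-channel reference itself -> C + 4 + C = 36 channels total.
ref = torch.randn(B, C, 1, H, W)
full = torch.zeros(B, C + 4 + C, 1, H, W)
full[:, 20:] = ref
full[:, 16:20] = 1.0
assert full.shape[1] == 36
assert torch.equal(full[:, 20:], ref) and (full[:, :16] == 0).all()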
