Support Multi/InfiniteTalk #10179
base: master
Changes from all commits: efe83f5, 460ce7f, 6f6db12, 00c069d, 57567bd, 9c5022e, d0dce6b, 7842a5c, 99dc959, 4cbc1a6, f5d53f2, 897ffeb
@@ -87,7 +87,7 @@ def qkv_fn_k(x):
         )

         x = self.o(x)
-        return x
+        return x, q, k


 class WanT2VCrossAttention(WanSelfAttention):
@@ -178,7 +178,8 @@ def __init__(self,
                  window_size=(-1, -1),
                  qk_norm=True,
                  cross_attn_norm=False,
-                 eps=1e-6, operation_settings={}):
+                 eps=1e-6, operation_settings={},
+                 block_idx=None):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
@@ -187,6 +188,7 @@ def __init__(self,
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        self.block_idx = block_idx

         # layers
         self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

Review comment on `self.block_idx = block_idx`: Instead of having the …
@@ -225,14 +227,16 @@ def forward(
         """
         # assert e.dtype == torch.float32

+        patches = transformer_options.get("patches", {})
+
         if e.ndim < 4:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         else:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
         # assert e[0].dtype == torch.float32

         # self-attention
-        y = self.self_attn(
+        y, q, k = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs, transformer_options=transformer_options)
@@ -241,6 +245,11 @@ def forward(

         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)

+        if "cross_attn" in patches:
+            for p in patches["cross_attn"]:
+                x = x + p({"x": x, "q": q, "k": k, "block_idx": self.block_idx, "transformer_options": transformer_options})
+
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
         return x
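The new `patches["cross_attn"]` hook above is the extension point this PR adds for Multi/InfiniteTalk-style conditioning: each registered patch receives the block output together with the self-attention `q`/`k` and returns a residual that the block adds to `x`. The sketch below shows how a custom node might register such a patch; it is hypothetical and not part of this PR, and it assumes `ModelPatcher.set_model_patch(patch, name)` appends callables to `transformer_options["patches"][name]` as in current ComfyUI. The names `make_audio_cross_attn_patch`, `audio_proj`, `my_audio_proj`, and `model` are illustrative placeholders.

```python
# Hypothetical registration of a "cross_attn" patch for the hook added above.

def make_audio_cross_attn_patch(audio_proj):
    # `audio_proj` stands in for whatever module produces the audio-conditioned residual.
    def cross_attn_patch(kwargs):
        x = kwargs["x"]                  # block hidden states after the built-in cross-attention
        q = kwargs["q"]                  # self-attention queries returned by self_attn
        k = kwargs["k"]                  # self-attention keys returned by self_attn
        block_idx = kwargs["block_idx"]  # which transformer block is calling the patch
        # A real implementation would compute partial/audio attention from q, k and the
        # audio embeddings here; this placeholder just projects x.
        return audio_proj(x)             # the returned residual is added to x by the block
    return cross_attn_patch

m = model.clone()  # `model` assumed to be a ComfyUI ModelPatcher wrapping a WAN model
m.set_model_patch(make_audio_cross_attn_patch(my_audio_proj), "cross_attn")
```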
@@ -262,6 +271,7 @@ def __init__(
     ):
         super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
         self.block_id = block_id
+        self.block_idx = None
         if block_id == 0:
             self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -486,8 +496,8 @@ def __init__(self,
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
-                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
-            for _ in range(num_layers)
+                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings, block_idx=i)
+            for i in range(num_layers)
         ])

         # head
@@ -540,6 +550,7 @@ def forward_orig(
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
+        transformer_options["grid_sizes"] = grid_sizes
         x = x.flatten(2).transpose(1, 2)

         # time embeddings
@@ -722,6 +733,7 @@ def forward_orig(
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
+        transformer_options["grid_sizes"] = grid_sizes
         x = x.flatten(2).transpose(1, 2)

         # time embeddings
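Storing `grid_sizes` in `transformer_options` lets a patch recover the (t, h, w) latent layout of the flattened token sequence. The helper below is a hypothetical illustration (not part of this PR) of how a `cross_attn` patch body could regroup tokens per latent frame, assuming the `x.flatten(2).transpose(1, 2)` ordering shown in the hunks above.

```python
import torch

def cross_attn_patch_per_frame(kwargs):
    # Hypothetical patch body showing how the stored grid sizes can be used.
    x = kwargs["x"]                                        # (batch, t*h*w, dim)
    t, h, w = kwargs["transformer_options"]["grid_sizes"]  # latent grid stored in forward_orig
    frames = x.reshape(x.shape[0], t, h * w, x.shape[-1])  # (batch, t, h*w, dim)
    # ... per-frame conditioning (e.g. aligning audio windows to latent frames) would use `frames` ...
    residual = torch.zeros_like(x)                         # a zero residual leaves the block output unchanged
    return residual
```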
Review comment: There is some uncertainty about whether returning q and k here will, in general, increase the peak memory usage of WAN within native ComfyUI. Instead, comfy suggests adding a patch that replaces the `x = optimized_attention(...)` call on line 81 by reusing the `ModelPatcher.set_model_attn1_replace` functionality (in the unet, attn1 is self-attention and attn2 is cross-attention). That replacement can then do the `optimized_attention` call plus the partial attention that currently happens inside the `cross_attn` patch. To get q and k into the `cross_attn` patch, you can store them in `transformer_options` instead and pop them out after use. The `transformer_index` argument can stay None (not given), since that was something unique to unet models. It would probably be more optimal to stop calling `optimized_attention` altogether and just reuse the logic of the slower partial attention in this code, but comfy said he would be fine if you didn't go that far and just kept both within that attention replacement function.
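For context, here is a rough sketch of the direction that comment describes. None of this is existing ComfyUI behaviour for WAN: it assumes the WAN self-attention would be changed to look up an "attn1" replacement from `transformer_options["patches_replace"]` and call it with `(q, k, v, extra_options)` the way the unet path does; the `("dit", i)` block key, the block count, the `partial_audio_attention` helper, and the exact contents of `extra_options` are all placeholders.

```python
from comfy.ldm.modules.attention import optimized_attention

def attn1_replacement(q, k, v, extra_options):
    # Would be called from the WAN self-attention in place of the optimized_attention call
    # on line 81 (that call site does not exist yet; it is the suggested change).
    heads = extra_options["n_heads"]                 # assumed to be supplied by the call site
    out = optimized_attention(q, k, v, heads=heads)

    # The partial attention currently done inside the "cross_attn" patch would move here,
    # so q and k never have to be returned from the block:
    # out = out + partial_audio_attention(q, k, extra_options)   # hypothetical helper

    # If other patches still need q/k, stash them in transformer_options and pop them
    # after use instead of changing the block's return signature.
    extra_options["transformer_options"]["wan_qk"] = (q, k)
    return out

m = model.clone()  # `model` assumed to be a ComfyUI ModelPatcher wrapping a WAN model
for i in range(40):  # placeholder block count; WAN variants differ
    # transformer_index stays at its default of None: it was only meaningful for unet models.
    m.set_model_attn1_replace(attn1_replacement, "dit", i)
```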