@@ -8,7 +8,7 @@
 
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.flux.math import apply_rope1
 import comfy.ldm.common_dit
 import comfy.model_management
 import comfy.patcher_extension
@@ -60,20 +60,24 @@ def forward(self, x, freqs, transformer_options={}):
         """
         b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
 
-        # query, key, value function
-        def qkv_fn(x):
+        def qkv_fn_q(x):
             q = self.norm_q(self.q(x)).view(b, s, n, d)
+            return apply_rope1(q, freqs)
+
+        def qkv_fn_k(x):
             k = self.norm_k(self.k(x)).view(b, s, n, d)
-            v = self.v(x).view(b, s, n * d)
-            return q, k, v
+            return apply_rope1(k, freqs)
 
-        q, k, v = qkv_fn(x)
-        q, k = apply_rope(q, k, freqs)
+        # These two are VRAM hogs, so we want to do all of the q computation and
+        # have PyTorch garbage collect the intermediates on the sub-function
+        # return before we touch k.
+        q = qkv_fn_q(x)
+        k = qkv_fn_k(x)
 
         x = optimized_attention(
             q.view(b, s, n * d),
             k.view(b, s, n * d),
-            v,
+            self.v(x).view(b, s, n * d),
             heads=self.num_heads,
             transformer_options=transformer_options,
         )
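The import now pulls in apply_rope1 instead of apply_rope, but its definition is not part of this hunk. Assuming it is simply the single-tensor counterpart of the existing two-tensor apply_rope in comfy.ldm.flux.math, a minimal sketch of the expected behaviour would be (the real signature and body may differ):

import torch

def apply_rope1(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Hypothetical single-tensor RoPE helper, derived from the two-tensor
    # apply_rope in comfy.ldm.flux.math; shown only to illustrate the idea.
    x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape).type_as(x)

Because it touches only one projection, q can be rotated and its pre-rotation intermediates released before k is ever created, which is what the split qkv_fn_q / qkv_fn_k helpers rely on.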
|
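The comment in the second hunk is the core of the change: by moving each projection into its own small function, the un-rotated tensor and any other locals become unreachable as soon as that function returns, so PyTorch's caching allocator can reuse their memory before the next projection is materialized. The same reasoning applies to passing self.v(x) straight into optimized_attention instead of binding it to v first. A toy sketch of the pattern (rotate, fused and scoped are illustrative names, not ComfyUI code):

import torch

def rotate(t):
    # Stand-in for apply_rope1: returns a new tensor, so the input becomes
    # collectable once nothing else references it.
    return t * 2.0

def fused(x, wq, wk):
    # q and k stay bound until the return, so both un-rotated projections
    # and both rotated copies coexist at the memory peak.
    q = x @ wq
    k = x @ wk
    return rotate(q), rotate(k)

def scoped(x, wq, wk):
    def make_q():
        q = x @ wq          # lives only inside this call
        return rotate(q)

    def make_k():
        k = x @ wk
        return rotate(k)

    # q's pre-rotation buffer is released when make_q returns, before
    # make_k allocates anything, so fewer projection-sized tensors are
    # alive at the peak.
    return make_q(), make_k()

At the long sequence lengths this model runs at, these projections dominate VRAM, so reducing how many of them are alive at the same time directly lowers the peak allocation.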