
Commit ae79281

rattus128 authored and adlerfaulkner committed
WAN2.2: Fix cache VRAM leak on error (comfyanonymous#10308)
Same change pattern as 7e8dd27, applied to WAN2.2.

If this code suffers an exception (such as a VRAM OOM), it exits the encode() and decode() methods without cleaning up the WAN feature cache. The comfy node cache then ultimately keeps a reference to this object, which in turn references large tensors from the failed execution.

The feature cache is currently set up as a class variable on the encoder/decoder; however, encode() and decode() always clear it on both entry and exit of normal execution. The design intent was likely to make this usable as a streaming encoder where the input arrives in batches, but the functions as written today don't support that. So simplify by making the cache a local variable again, so that if a VRAM OOM does occur, the cache is properly garbage collected once the encode()/decode() frames disappear from the stack.
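The failure mode is easy to reproduce in miniature. Below is a minimal, hypothetical sketch (the LeakyVAE/FixedVAE classes and the simulated OOM are illustrative only, not code from the ComfyUI repository) of why a cache stored on self outlives a failed call, while a local variable does not:

```python
import torch

class LeakyVAE:
    """Cache stored as an instance attribute, as before this commit."""
    def encode(self, x):
        # If the encoder raises mid-call (e.g. a CUDA OOM), this attribute
        # still holds the cached tensors afterwards, and any object graph
        # that references the VAE keeps them alive.
        self._feat_map = [x * 2]  # stand-in for cached feature tensors
        raise RuntimeError("simulated VRAM OOM")

class FixedVAE:
    """Cache as a local variable, as after this commit."""
    def encode(self, x):
        # When the exception unwinds this frame, feat_map goes out of
        # scope and its tensors become garbage-collectable immediately.
        feat_map = [x * 2]
        raise RuntimeError("simulated VRAM OOM")

vae = LeakyVAE()
try:
    vae.encode(torch.zeros(4))
except RuntimeError:
    pass
# The failed call's tensors are still reachable through the object:
print(len(vae._feat_map))  # -> 1; with FixedVAE there is nothing left to leak
```

Dropping the reference is enough here because nothing else shares the cache: once the local goes out of scope, the allocator can reuse the VRAM on the next call.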
1 parent dbe06ad · commit ae79281

File tree: 1 file changed (+14, −23 lines)


comfy/ldm/wan/vae2_2.py

Lines changed: 14 additions & 23 deletions
```diff
@@ -657,51 +657,51 @@ def __init__(
         )
 
     def encode(self, x):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.encoder)
         x = patchify(x, patch_size=2)
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
         for i in range(iter_):
-            self._enc_conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.encoder(
                     x[:, :, :1, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
             else:
                 out_ = self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
         mu, log_var = self.conv1(out).chunk(2, dim=1)
-        self.clear_cache()
         return mu
 
     def decode(self, z):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.decoder)
         iter_ = z.shape[2]
         x = self.conv2(z)
         for i in range(iter_):
-            self._conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                     first_chunk=True,
                 )
             else:
                 out_ = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
         out = unpatchify(out, patch_size=2)
-        self.clear_cache()
         return out
 
     def reparameterize(self, mu, log_var):
@@ -715,12 +715,3 @@ def sample(self, imgs, deterministic=False):
             return mu
         std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
         return mu + std * torch.randn_like(std)
-
-    def clear_cache(self):
-        self._conv_num = count_conv3d(self.decoder)
-        self._conv_idx = [0]
-        self._feat_map = [None] * self._conv_num
-        # cache encode
-        self._enc_conv_num = count_conv3d(self.encoder)
-        self._enc_conv_idx = [0]
-        self._enc_feat_map = [None] * self._enc_conv_num
```
