
Commit e1ab6bb

EasyCache: Fix for mismatch in input/output channels with some models (#10788)
Slices the model input down to the output channel count so the cache tracks only the noise channels, resolving the channel mismatch with models like WanVideo I2V. Also fixes a multi-axis slicing deprecation in PyTorch 2.9.
1 parent 048f49a commit e1ab6bb


comfy_extras/nodes_easycache.py

Lines changed: 12 additions & 10 deletions
```diff
@@ -11,13 +11,13 @@

 def easycache_forward_wrapper(executor, *args, **kwargs):
     # get values from args
-    x: torch.Tensor = args[0]
     transformer_options: dict[str] = args[-1]
     if not isinstance(transformer_options, dict):
         transformer_options = kwargs.get("transformer_options")
         if not transformer_options:
             transformer_options = args[-2]
     easycache: EasyCacheHolder = transformer_options["easycache"]
+    x: torch.Tensor = args[0][:, :easycache.output_channels]
     sigmas = transformer_options["sigmas"]
     uuids = transformer_options["uuids"]
     if sigmas is not None and easycache.is_past_end_timestep(sigmas):
```
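The relocated slice is the heart of the fix: image-to-video models such as WanVideo I2V concatenate conditioning latents onto the noise channels, so `args[0]` can be wider than the tensor the model actually denoises, and caching diffs against the full input produces a channel mismatch. A minimal sketch of the problem and the slice (channel counts are illustrative, not WanVideo's exact layout):

```python
import torch

output_channels = 16                          # channels the model denoises (latent format)
noise = torch.randn(1, output_channels, 4, 32, 32)
conditioning = torch.randn(1, 20, 4, 32, 32)  # extra channels concatenated by I2V models
model_input = torch.cat([noise, conditioning], dim=1)       # 36 channels in
model_output = torch.randn(1, output_channels, 4, 32, 32)   # 16 channels out

x = model_input[:, :output_channels]  # the fix: track only the noise channels
diff = model_output - x               # shapes agree; with model_input this raises
assert diff.shape == model_output.shape
```

Note that the assignment also had to move below the `easycache` lookup, since the channel count now lives on the holder.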
```diff
@@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):

 def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     # get values from args
-    x: torch.Tensor = args[0]
     timestep: float = args[1]
     model_options: dict[str] = args[2]
     easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
     if easycache.is_past_end_timestep(timestep):
         return executor(*args, **kwargs)
     # prepare next x_prev
+    x: torch.Tensor = args[0][:, :easycache.output_channels]
     next_x_prev = x
     input_change = None
     do_easycache = easycache.should_do_easycache(timestep)
```
```diff
@@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):


 class EasyCacheHolder:
-    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
+    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
         self.name = "EasyCache"
         self.reuse_threshold = reuse_threshold
         self.start_percent = start_percent
```
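`output_channels` defaults to `None`, which conveniently preserves the old behavior: a `:None` stop in a slice is `slice(None)`, so holders created without a channel count keep caching the full input. A quick check of that Python slicing detail:

```python
import torch

x = torch.randn(1, 36, 4, 8, 8)
print(x[:, :None].shape)  # torch.Size([1, 36, 4, 8, 8]); :None keeps every channel
print(x[:, :16].shape)    # torch.Size([1, 16, 4, 8, 8]); sliced to the denoised channels
```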
```diff
@@ -202,6 +202,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
         self.allow_mismatch = True
         self.cut_from_start = True
         self.state_metadata = None
+        self.output_channels = output_channels

     def is_past_end_timestep(self, timestep: float) -> bool:
         return not (timestep[0] > self.end_t).item()
```
```diff
@@ -264,7 +265,7 @@ def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
             else:
                 slicing.append(slice(None))
             batch_slice = batch_slice + slicing
-            x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
+            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
         return x

     def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
```
```diff
@@ -283,7 +284,7 @@ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[U
             else:
                 slicing.append(slice(None))
                 skip_dim = False
-        x = x[slicing]
+        x = x[tuple(slicing)]
         diff = output - x
         batch_offset = diff.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
```
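The `tuple(...)` wrapping addresses the second part of the commit message: indexing a tensor with a plain Python list of slices has long been deprecated (the list is interpreted as advanced indexing rather than one slice per dimension) and, per the commit, trips the deprecation in PyTorch 2.9. A small sketch of the before/after:

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)
slicing = [slice(0, 1), slice(None), slice(1, 3)]

# Deprecated: a list of slices is not treated as one index per dimension.
# t[slicing]                    # warns (and eventually errors) on newer PyTorch

# The commit's fix: an explicit tuple indexes one dimension per element.
print(t[tuple(slicing)].shape)  # torch.Size([1, 3, 2])
```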
```diff
@@ -323,7 +324,7 @@ def reset(self):
         return self

     def clone(self):
-        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
+        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)


 class EasyCacheNode(io.ComfyNode):
```
```diff
@@ -350,15 +351,15 @@ def define_schema(cls) -> io.Schema:
     @classmethod
     def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
         model = model.clone()
-        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
+        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
         return io.NodeOutput(model)


 class LazyCacheHolder:
-    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
+    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
         self.name = "LazyCache"
         self.reuse_threshold = reuse_threshold
         self.start_percent = start_percent
```
```diff
@@ -382,6 +383,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
         self.state_metadata = None
+        self.output_channels = output_channels

     def has_cache_diff(self) -> bool:
         return self.cache_diff is not None
```
```diff
@@ -456,7 +458,7 @@ def reset(self):
         return self

     def clone(self):
-        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
+        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)

 class LazyCacheNode(io.ComfyNode):
     @classmethod
```
```diff
@@ -482,7 +484,7 @@ def define_schema(cls) -> io.Schema:
     @classmethod
     def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
         model = model.clone()
-        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
+        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
         model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
         return io.NodeOutput(model)
```
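Both `execute` methods now pull the channel count from the model's latent format, since `latent_channels` is the width of what the sampler actually denoises regardless of what gets concatenated onto the input; `clone()` forwards it so per-sample copies of the holder keep slicing correctly. A sketch of the shared pattern (the `make_holder` helper is hypothetical, used here only to show the wiring):

```python
# Hypothetical helper illustrating the pattern shared by both nodes in this
# commit: derive output_channels from the model's latent format so the holder
# can slice inputs down to the denoised channels.
def make_holder(holder_cls, model, reuse_threshold, start_percent, end_percent, verbose):
    return holder_cls(
        reuse_threshold, start_percent, end_percent,
        subsample_factor=8, offload_cache_diff=False, verbose=verbose,
        output_channels=model.model.latent_format.latent_channels,
    )
```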
