1111
1212def easycache_forward_wrapper (executor , * args , ** kwargs ):
1313 # get values from args
14- x : torch .Tensor = args [0 ]
1514 transformer_options : dict [str ] = args [- 1 ]
1615 if not isinstance (transformer_options , dict ):
1716 transformer_options = kwargs .get ("transformer_options" )
1817 if not transformer_options :
1918 transformer_options = args [- 2 ]
2019 easycache : EasyCacheHolder = transformer_options ["easycache" ]
20+ x : torch .Tensor = args [0 ][:, :easycache .output_channels ]
2121 sigmas = transformer_options ["sigmas" ]
2222 uuids = transformer_options ["uuids" ]
2323 if sigmas is not None and easycache .is_past_end_timestep (sigmas ):
@@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
8282
8383def lazycache_predict_noise_wrapper (executor , * args , ** kwargs ):
8484 # get values from args
85- x : torch .Tensor = args [0 ]
8685 timestep : float = args [1 ]
8786 model_options : dict [str ] = args [2 ]
8887 easycache : LazyCacheHolder = model_options ["transformer_options" ]["easycache" ]
8988 if easycache .is_past_end_timestep (timestep ):
9089 return executor (* args , ** kwargs )
9190 # prepare next x_prev
91+ x : torch .Tensor = args [0 ][:, :easycache .output_channels ]
9292 next_x_prev = x
9393 input_change = None
9494 do_easycache = easycache .should_do_easycache (timestep )
@@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
173173
174174
175175class EasyCacheHolder :
176- def __init__ (self , reuse_threshold : float , start_percent : float , end_percent : float , subsample_factor : int , offload_cache_diff : bool , verbose : bool = False ):
176+ def __init__ (self , reuse_threshold : float , start_percent : float , end_percent : float , subsample_factor : int , offload_cache_diff : bool , verbose : bool = False , output_channels : int = None ):
177177 self .name = "EasyCache"
178178 self .reuse_threshold = reuse_threshold
179179 self .start_percent = start_percent
@@ -202,6 +202,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
202202 self .allow_mismatch = True
203203 self .cut_from_start = True
204204 self .state_metadata = None
205+ self .output_channels = output_channels
205206
206207 def is_past_end_timestep (self , timestep : float ) -> bool :
207208 return not (timestep [0 ] > self .end_t ).item ()
@@ -264,7 +265,7 @@ def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
264265 else :
265266 slicing .append (slice (None ))
266267 batch_slice = batch_slice + slicing
267- x [batch_slice ] += self .uuid_cache_diffs [uuid ].to (x .device )
268+ x [tuple ( batch_slice ) ] += self .uuid_cache_diffs [uuid ].to (x .device )
268269 return x
269270
270271 def update_cache_diff (self , output : torch .Tensor , x : torch .Tensor , uuids : list [UUID ]):
@@ -283,7 +284,7 @@ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[U
283284 else :
284285 slicing .append (slice (None ))
285286 skip_dim = False
286- x = x [slicing ]
287+ x = x [tuple ( slicing ) ]
287288 diff = output - x
288289 batch_offset = diff .shape [0 ] // len (uuids )
289290 for i , uuid in enumerate (uuids ):
@@ -323,7 +324,7 @@ def reset(self):
323324 return self
324325
325326 def clone (self ):
326- return EasyCacheHolder (self .reuse_threshold , self .start_percent , self .end_percent , self .subsample_factor , self .offload_cache_diff , self .verbose )
327+ return EasyCacheHolder (self .reuse_threshold , self .start_percent , self .end_percent , self .subsample_factor , self .offload_cache_diff , self .verbose , output_channels = self . output_channels )
327328
328329
329330class EasyCacheNode (io .ComfyNode ):
@@ -350,15 +351,15 @@ def define_schema(cls) -> io.Schema:
350351 @classmethod
351352 def execute (cls , model : io .Model .Type , reuse_threshold : float , start_percent : float , end_percent : float , verbose : bool ) -> io .NodeOutput :
352353 model = model .clone ()
353- model .model_options ["transformer_options" ]["easycache" ] = EasyCacheHolder (reuse_threshold , start_percent , end_percent , subsample_factor = 8 , offload_cache_diff = False , verbose = verbose )
354+ model .model_options ["transformer_options" ]["easycache" ] = EasyCacheHolder (reuse_threshold , start_percent , end_percent , subsample_factor = 8 , offload_cache_diff = False , verbose = verbose , output_channels = model . model . latent_format . latent_channels )
354355 model .add_wrapper_with_key (comfy .patcher_extension .WrappersMP .OUTER_SAMPLE , "easycache" , easycache_sample_wrapper )
355356 model .add_wrapper_with_key (comfy .patcher_extension .WrappersMP .CALC_COND_BATCH , "easycache" , easycache_calc_cond_batch_wrapper )
356357 model .add_wrapper_with_key (comfy .patcher_extension .WrappersMP .DIFFUSION_MODEL , "easycache" , easycache_forward_wrapper )
357358 return io .NodeOutput (model )
358359
359360
360361class LazyCacheHolder :
361- def __init__ (self , reuse_threshold : float , start_percent : float , end_percent : float , subsample_factor : int , offload_cache_diff : bool , verbose : bool = False ):
362+ def __init__ (self , reuse_threshold : float , start_percent : float , end_percent : float , subsample_factor : int , offload_cache_diff : bool , verbose : bool = False , output_channels : int = None ):
362363 self .name = "LazyCache"
363364 self .reuse_threshold = reuse_threshold
364365 self .start_percent = start_percent
@@ -382,6 +383,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
382383 self .approx_output_change_rates = []
383384 self .total_steps_skipped = 0
384385 self .state_metadata = None
386+ self .output_channels = output_channels
385387
386388 def has_cache_diff (self ) -> bool :
387389 return self .cache_diff is not None
@@ -456,7 +458,7 @@ def reset(self):
456458 return self
457459
458460 def clone (self ):
459- return LazyCacheHolder (self .reuse_threshold , self .start_percent , self .end_percent , self .subsample_factor , self .offload_cache_diff , self .verbose )
461+ return LazyCacheHolder (self .reuse_threshold , self .start_percent , self .end_percent , self .subsample_factor , self .offload_cache_diff , self .verbose , output_channels = self . output_channels )
460462
461463class LazyCacheNode (io .ComfyNode ):
462464 @classmethod
@@ -482,7 +484,7 @@ def define_schema(cls) -> io.Schema:
482484 @classmethod
483485 def execute (cls , model : io .Model .Type , reuse_threshold : float , start_percent : float , end_percent : float , verbose : bool ) -> io .NodeOutput :
484486 model = model .clone ()
485- model .model_options ["transformer_options" ]["easycache" ] = LazyCacheHolder (reuse_threshold , start_percent , end_percent , subsample_factor = 8 , offload_cache_diff = False , verbose = verbose )
487+ model .model_options ["transformer_options" ]["easycache" ] = LazyCacheHolder (reuse_threshold , start_percent , end_percent , subsample_factor = 8 , offload_cache_diff = False , verbose = verbose , output_channels = model . model . latent_format . latent_channels )
486488 model .add_wrapper_with_key (comfy .patcher_extension .WrappersMP .OUTER_SAMPLE , "lazycache" , easycache_sample_wrapper )
487489 model .add_wrapper_with_key (comfy .patcher_extension .WrappersMP .PREDICT_NOISE , "lazycache" , lazycache_predict_noise_wrapper )
488490 return io .NodeOutput (model )
0 commit comments