Closed
Labels: Potential Bug (User is reporting a bug. This should be tested.)
Description
Custom Node Testing
- I have tried disabling custom nodes and the issue persists (see how to disable custom nodes if you need help)
Expected Behavior
Two images are generated (batch_size=2), with some sampling steps skipped (EasyCache).
Actual Behavior
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
Steps to Reproduce
easy-cache-batch-size-error.json
A very simple SDXL workflow that uses Empty Latent with batch_size=2 and EasyCache with its default values (threshold 0.2, start_percent 0.15, end_percent 0.95, as shown in the log).
Debug Logs
Checkpoint files will always be loaded safely.
Total VRAM 12282 MB, total RAM 32694 MB
pytorch version: 2.8.0+cu128
Set vram state to: NORMAL_VRAM
Device: cuda:0 NVIDIA GeForce RTX 4070 : cudaMallocAsync
Using pytorch attention
Python version: 3.12.9 (main, Feb 12 2025, 14:52:31) [MSC v.1942 64 bit (AMD64)]
ComfyUI version: 0.3.65
ComfyUI frontend version: 1.27.10
[Prompt Server] web root: C:\Dev\ComfyUI\.venv\Lib\site-packages\comfyui_frontend_package\static
Skipping loading of custom nodes
Context impl SQLiteImpl.
Will assume non-transactional DDL.
No target revision found.
Starting server
To see the GUI go to: http://127.0.0.1:8188
got prompt
model weight dtype torch.float16, manual cast: None
model_type V_PREDICTION
Using pytorch attention in VAE
Using pytorch attention in VAE
VAE load device: cuda:0, offload device: cpu, dtype: torch.bfloat16
Requested to load SDXLClipModel
loaded completely 9.5367431640625e+25 1560.802734375 True
CLIP/text encoder model load device: cuda:0, offload device: cpu, current: cuda:0, dtype: torch.float16
EasyCache enabled - threshold: 0.2, start_percent: 0.15, end_percent: 0.95
Requested to load SDXL
loaded completely 8057.872177886963 4897.0483474731445 True
25%|████████████████▎ | 5/20 [00:03<00:10, 1.44it/s]
EasyCache - skipped 1/20 steps (1.05x speedup).
!!! Exception during processing !!! The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
Traceback (most recent call last):
File "C:\Dev\ComfyUI\execution.py", line 496, in execute
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\execution.py", line 315, in get_output_data
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\execution.py", line 289, in _async_map_node_over_list
await process_inputs(input_dict, i)
File "C:\Dev\ComfyUI\execution.py", line 277, in process_inputs
result = f(**inputs)
^^^^^^^^^^^
File "C:\Dev\ComfyUI\nodes.py", line 1525, in sample
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\nodes.py", line 1492, in common_ksampler
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\sample.py", line 45, in sample
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 1161, in sample
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 1051, in sample
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 1036, in sample
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 113, in execute
return self.wrappers[self.idx](self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy_extras\nodes_easycache.py", line 156, in easycache_sample_wrapper
return executor(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 105, in __call__
return new_executor.execute(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 112, in execute
return self.original(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 1004, in outer_sample
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 987, in inner_sample
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 112, in execute
return self.original(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 759, in sample
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\.venv\Lib\site-packages\torch\utils\_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\k_diffusion\sampling.py", line 199, in sample_euler
denoised = model(x, sigma_hat * s_in, **extra_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 408, in __call__
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 960, in __call__
return self.outer_predict_noise(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 967, in outer_predict_noise
).execute(x, timestep, model_options, seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 112, in execute
return self.original(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 970, in predict_noise
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 388, in sampling_function
out = calc_cond_batch(model, conds, x, timestep, model_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 206, in calc_cond_batch
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 214, in _calc_cond_batch_outer
return executor.execute(model, conds, x_in, timestep, model_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 113, in execute
return self.wrappers[self.idx](self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy_extras\nodes_easycache.py", line 142, in easycache_calc_cond_batch_wrapper
return executor(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 105, in __call__
return new_executor.execute(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 112, in execute
return self.original(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\samplers.py", line 333, in _calc_cond_batch
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\model_base.py", line 161, in apply_model
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 112, in execute
return self.original(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\model_base.py", line 200, in _apply_model
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\.venv\Lib\site-packages\torch\nn\modules\module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\ldm\modules\diffusionmodules\openaimodel.py", line 831, in forward
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy\patcher_extension.py", line 113, in execute
return self.wrappers[self.idx](self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy_extras\nodes_easycache.py", line 52, in easycache_forward_wrapper
return easycache.apply_cache_diff(x, uuids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Dev\ComfyUI\comfy_extras\nodes_easycache.py", line 266, in apply_cache_diff
x += self.uuid_cache_diffs[uuid].to(x.device)
RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
Other
The error is raised where EasyCache applies a cached result, in nodes_easycache.py:266:
x += self.uuid_cache_diffs[uuid].to(x.device)
From what I can see (with diff being self.uuid_cache_diffs[uuid]):
- for the batch_size=1 case, x.shape[0] == 2 and diff.shape[0] == 1, so the two can be added via broadcasting
- for the batch_size=n case, x.shape[0] == 2*n and diff.shape[0] == n, so the add fails
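The shape rule behind these two cases can be checked without PyTorch. The sketch below is a simplified, stdlib-only model of broadcasting for an in-place add: it only answers "compatible or not" and does not reproduce PyTorch's exact error messages. The function name `can_add_inplace` is hypothetical, and the 128x128 spatial size is an illustrative assumption, not a value taken from the attached workflow.

```python
# Stdlib-only sketch of the broadcasting rule that in-place ops such as
# `x += diff` follow in PyTorch and NumPy. Shapes are plain tuples.

def can_add_inplace(x_shape, diff_shape):
    """Return True if `x += diff` is shape-compatible.

    An in-place add must keep x's shape, so after aligning trailing
    dimensions each dim of `diff` must equal the matching dim of `x`
    or be 1 (a singleton that is broadcast). Missing leading dims of
    `diff` are treated as 1.
    """
    if len(diff_shape) > len(x_shape):
        return False
    offset = len(x_shape) - len(diff_shape)
    for i, d in enumerate(diff_shape):
        if d != 1 and d != x_shape[offset + i]:
            return False
    return True


# SDXL latents have 4 channels; the 128x128 spatial size is only illustrative.
C, H, W = 4, 128, 128

for n in (1, 2):
    x_shape = (2 * n, C, H, W)  # batch dim of x as reported: 2 * batch_size
    diff_shape = (n, C, H, W)   # batch dim of the cached diff as reported: batch_size
    print(f"batch_size={n}: compatible={can_add_inplace(x_shape, diff_shape)}")
# batch_size=1: compatible=True
# batch_size=2: compatible=False
```

With batch_size=2 the check fails on dimension 0 (4 vs 2), the same sizes named in the RuntimeError.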