Skip to content

Commit 3cfe187

Browse files
authored
Cleanup ControlnetXS (#7701)
* update * update
1 parent 90250d9 commit 3cfe187

File tree

3 files changed

+61
-35
lines changed

3 files changed

+61
-35
lines changed

src/diffusers/models/controlnet_xs.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@
2222
from ..configuration_utils import ConfigMixin, register_to_config
2323
from ..utils import BaseOutput, is_torch_version, logging
2424
from ..utils.torch_utils import apply_freeu
25-
from .attention_processor import Attention, AttentionProcessor
25+
from .attention_processor import (
26+
ADDED_KV_ATTENTION_PROCESSORS,
27+
CROSS_ATTENTION_PROCESSORS,
28+
Attention,
29+
AttentionProcessor,
30+
AttnAddedKVProcessor,
31+
AttnProcessor,
32+
)
2633
from .controlnet import ControlNetConditioningEmbedding
2734
from .embeddings import TimestepEmbedding, Timesteps
2835
from .modeling_utils import ModelMixin
@@ -869,7 +876,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
869876

870877
return processors
871878

872-
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
879+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
873880
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
874881
r"""
875882
Sets the attention processor to use to compute attention.
@@ -904,7 +911,23 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
904911
for name, module in self.named_children():
905912
fn_recursive_attn_processor(name, module, processor)
906913

907-
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
914+
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
915+
def set_default_attn_processor(self):
916+
"""
917+
Disables custom attention processors and sets the default attention implementation.
918+
"""
919+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
920+
processor = AttnAddedKVProcessor()
921+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
922+
processor = AttnProcessor()
923+
else:
924+
raise ValueError(
925+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
926+
)
927+
928+
self.set_attn_processor(processor)
929+
930+
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
908931
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
909932
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
910933
@@ -929,7 +952,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
929952
setattr(upsample_block, "b1", b1)
930953
setattr(upsample_block, "b2", b2)
931954

932-
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
955+
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
933956
def disable_freeu(self):
934957
"""Disables the FreeU mechanism."""
935958
freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -938,7 +961,7 @@ def disable_freeu(self):
938961
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
939962
setattr(upsample_block, k, None)
940963

941-
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
964+
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
942965
def fuse_qkv_projections(self):
943966
"""
944967
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -962,7 +985,7 @@ def fuse_qkv_projections(self):
962985
if isinstance(module, Attention):
963986
module.fuse_projections(fuse=True)
964987

965-
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
988+
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
966989
def unfuse_qkv_projections(self):
967990
"""Disables the fused QKV projection if enabled.
968991

src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from ...schedulers import KarrasDiffusionSchedulers
4242
from ...utils import (
4343
USE_PEFT_BACKEND,
44-
deprecate,
4544
logging,
4645
replace_example_docstring,
4746
scale_lora_layers,
@@ -462,7 +461,6 @@ def check_inputs(
462461
prompt,
463462
prompt_2,
464463
image,
465-
callback_steps,
466464
negative_prompt=None,
467465
negative_prompt_2=None,
468466
prompt_embeds=None,
@@ -474,12 +472,6 @@ def check_inputs(
474472
control_guidance_end=1.0,
475473
callback_on_step_end_tensor_inputs=None,
476474
):
477-
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
478-
raise ValueError(
479-
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
480-
f" {type(callback_steps)}."
481-
)
482-
483475
if callback_on_step_end_tensor_inputs is not None and not all(
484476
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
485477
):
@@ -749,7 +741,6 @@ def __call__(
749741
clip_skip: Optional[int] = None,
750742
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
751743
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
752-
**kwargs,
753744
):
754745
r"""
755746
The call function to the pipeline for generation.
@@ -878,30 +869,13 @@ def __call__(
878869
returned, otherwise a `tuple` is returned containing the output images.
879870
"""
880871

881-
callback = kwargs.pop("callback", None)
882-
callback_steps = kwargs.pop("callback_steps", None)
883-
884-
if callback is not None:
885-
deprecate(
886-
"callback",
887-
"1.0.0",
888-
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
889-
)
890-
if callback_steps is not None:
891-
deprecate(
892-
"callback_steps",
893-
"1.0.0",
894-
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
895-
)
896-
897872
unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
898873

899874
# 1. Check inputs. Raise error if not correct
900875
self.check_inputs(
901876
prompt,
902877
prompt_2,
903878
image,
904-
callback_steps,
905879
negative_prompt,
906880
negative_prompt_2,
907881
prompt_embeds,
@@ -1089,9 +1063,6 @@ def __call__(
10891063
# call the callback, if provided
10901064
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
10911065
progress_bar.update()
1092-
if callback is not None and i % callback_steps == 0:
1093-
step_idx = i // getattr(self.scheduler, "order", 1)
1094-
callback(step_idx, t, latents)
10951066

10961067
# manually for max memory savings
10971068
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:

tests/pipelines/controlnet_xs/test_controlnetxs.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@
6969
enable_full_determinism()
7070

7171

72+
def to_np(tensor):
73+
if isinstance(tensor, torch.Tensor):
74+
tensor = tensor.detach().cpu().numpy()
75+
76+
return tensor
77+
78+
7279
# Will be run via run_test_in_subprocess
7380
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
7481
error = None
@@ -299,6 +306,31 @@ def test_multi_vae(self):
299306

300307
assert out_vae_np.shape == out_np.shape
301308

309+
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
310+
def test_to_device(self):
311+
components = self.get_dummy_components()
312+
pipe = self.pipeline_class(**components)
313+
pipe.set_progress_bar_config(disable=None)
314+
315+
pipe.to("cpu")
316+
# pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
317+
model_devices = [
318+
component.device.type for component in pipe.components.values() if hasattr(component, "device")
319+
]
320+
self.assertTrue(all(device == "cpu" for device in model_devices))
321+
322+
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
323+
self.assertTrue(np.isnan(output_cpu).sum() == 0)
324+
325+
pipe.to("cuda")
326+
model_devices = [
327+
component.device.type for component in pipe.components.values() if hasattr(component, "device")
328+
]
329+
self.assertTrue(all(device == "cuda" for device in model_devices))
330+
331+
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
332+
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
333+
302334

303335
@slow
304336
@require_torch_gpu

0 commit comments

Comments
 (0)