diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py
index 4bbe1dd4dc25..a4f9e61f37c7 100644
--- a/src/diffusers/models/controlnet_xs.py
+++ b/src/diffusers/models/controlnet_xs.py
@@ -22,7 +22,14 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, is_torch_version, logging
 from ..utils.torch_utils import apply_freeu
-from .attention_processor import Attention, AttentionProcessor
+from .attention_processor import (
+    ADDED_KV_ATTENTION_PROCESSORS,
+    CROSS_ATTENTION_PROCESSORS,
+    Attention,
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnProcessor,
+)
 from .controlnet import ControlNetConditioningEmbedding
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
@@ -869,7 +876,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Sets the attention processor to use to compute attention.
@@ -904,7 +911,23 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
 
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor)
+
+    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -929,7 +952,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
             setattr(upsample_block, "b1", b1)
             setattr(upsample_block, "b2", b2)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
     def disable_freeu(self):
         """Disables the FreeU mechanism."""
         freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -938,7 +961,7 @@ def disable_freeu(self):
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -962,7 +985,7 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
 
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index ff270d20d11e..dd30ed70ae87 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -41,7 +41,6 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
-    deprecate,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -462,7 +461,6 @@ def check_inputs(
         prompt,
         prompt_2,
         image,
-        callback_steps,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,
@@ -474,12 +472,6 @@
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
     ):
-        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
@@ -744,7 +736,6 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -873,22 +864,6 @@
                 returned, otherwise a `tuple` is returned containing the output images.
         """
 
-        callback = kwargs.pop("callback", None)
-        callback_steps = kwargs.pop("callback_steps", None)
-
-        if callback is not None:
-            deprecate(
-                "callback",
-                "1.0.0",
-                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
-            )
-        if callback_steps is not None:
-            deprecate(
-                "callback_steps",
-                "1.0.0",
-                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
-            )
-
         unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
 
         # 1. Check inputs. Raise error if not correct
@@ -896,7 +871,6 @@ def __call__(
             prompt,
             prompt_2,
             image,
-            callback_steps,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds,
@@ -1084,9 +1058,6 @@
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(step_idx, t, latents)
 
         # manually for max memory savings
         if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py
index 5ac78129ef34..42795807792b 100644
--- a/tests/pipelines/controlnet_xs/test_controlnetxs.py
+++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py
@@ -69,6 +69,13 @@
 enable_full_determinism()
 
 
+def to_np(tensor):
+    if isinstance(tensor, torch.Tensor):
+        tensor = tensor.detach().cpu().numpy()
+
+    return tensor
+
+
 # Will be run via run_test_in_subprocess
 def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
     error = None
@@ -299,6 +306,31 @@ def test_multi_vae(self):
 
         assert out_vae_np.shape == out_np.shape
 
+    @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+    def test_to_device(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.to("cpu")
+        # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        self.assertTrue(all(device == "cpu" for device in model_devices))
+
+        output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+        self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+        pipe.to("cuda")
+        model_devices = [
+            component.device.type for component in pipe.components.values() if hasattr(component, "device")
+        ]
+        self.assertTrue(all(device == "cuda" for device in model_devices))
+
+        output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+        self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
 
 @slow
 @require_torch_gpu
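
For context on the model change: `set_default_attn_processor` resets every attention module to a stock processor, choosing `AttnAddedKVProcessor` when all current processors are added-KV variants and `AttnProcessor` when they are all regular cross-attention variants. A minimal usage sketch follows; the adapter checkpoint id is an illustrative placeholder, not a reference to a real repo:

```python
# Sketch (not part of the diff): restore default attention processors on a
# UNetControlNetXSModel after experimenting with custom ones.
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
# Placeholder checkpoint id; substitute a real ControlNet-XS adapter.
adapter = ControlNetXSAdapter.from_pretrained("some-org/controlnet-xs-sdxl-canny")
model = UNetControlNetXSModel.from_unet(unet, adapter)

# ... after swapping in custom processors, e.g. for attention-map logging ...
model.set_default_attn_processor()  # back to AttnProcessor everywhere
```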
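The pipeline change removes the long-deprecated `callback`/`callback_steps` arguments; per-step hooks now go through `callback_on_step_end`. Note that although the annotation in the diff reads `Callable[[int, int, Dict], None]`, diffusers pipelines invoke the hook as `callback_on_step_end(pipe, step_index, timestep, callback_kwargs)` and expect the (possibly modified) kwargs dict back. A migration sketch, assuming `pipe` is a `StableDiffusionXLControlNetXSPipeline` and `canny_image` is a prepared control image:

```python
# Sketch: the old `callback(step_idx, t, latents)` hook expressed through
# the `callback_on_step_end` API that replaces it.
def on_step_end(pipe, step_index, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]  # available via the tensor-inputs list below
    print(f"step {step_index}: latent norm {latents.float().norm().item():.3f}")
    return callback_kwargs  # must return the (possibly modified) dict

image = pipe(
    prompt="an astronaut riding a horse",
    image=canny_image,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```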
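Finally, the new `test_to_device` checks device placement through `pipe.components` rather than the modules passed to the constructor, because the pipeline builds a fresh `UNetControlNetXSModel` internally. The same pattern works outside the test suite; a sketch assuming `pipe` as above:

```python
# Sketch: verify where a ControlNet-XS pipeline actually lives after .to(...).
pipe.to("cpu")
devices = {
    name: component.device.type
    for name, component in pipe.components.items()
    if hasattr(component, "device")
}
assert all(d == "cpu" for d in devices.values()), devices
```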