Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def load_ip_adapter(
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

# load CLIP image encoer here if it has not been registered to the pipeline yet
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
Expand All @@ -141,12 +141,14 @@ def load_ip_adapter(
subfolder=os.path.join(subfolder, "image_encoder"),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
self.feature_extractor = CLIPImageProcessor()
self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])

# load ip-adapter into unet
self.unet._load_ip_adapter_weights(state_dict)
Expand All @@ -155,3 +157,32 @@ def set_ip_adapter_scale(self, scale):
for attn_processor in self.unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = scale

def unload_ip_adapter(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me!

"""
Unloads the IP Adapter weights

Examples:

```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])

# remove feature extractor
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])

# remove hidden encoder
self.unet.encoder_hid_proj = None
self.config.encoder_hid_dim_type = None

# restore original Unet attention processors layers
self.unet.set_default_attn_processor()
20 changes: 20 additions & 0 deletions tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
Expand Down Expand Up @@ -228,6 +229,25 @@ def test_text_to_image_full_face(self):

assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)

def test_unload(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)

pipeline.unload_ip_adapter()

assert getattr(pipeline, "image_encoder") is None
assert getattr(pipeline, "feature_extractor") is None
processors = [
isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0))
for name, attn_proc in pipeline.unet.attn_processors.items()
]
assert processors == [True] * len(processors)


@slow
@require_torch_gpu
Expand Down