Skip to content

Commit 4e9d60a

Browse files
committed
fix AnimateDiff-based tests
1 parent ceabed5 commit 4e9d60a

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,7 @@ def __call__(
905905
if self.do_classifier_free_guidance:
906906
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
907907

908-
if ip_adapter_image is not None:
908+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
909909
image_embeds = self.prepare_ip_adapter_image_embeds(
910910
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
911911
)

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def create_ip_adapter_state_dict(model):
6262
key_id = 1
6363

6464
for name in model.attn_processors.keys():
65-
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
65+
cross_attention_dim = (
66+
None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
67+
)
68+
6669
if name.startswith("mid_block"):
6770
hidden_size = model.config.block_out_channels[-1]
6871
elif name.startswith("up_blocks"):
@@ -71,6 +74,7 @@ def create_ip_adapter_state_dict(model):
7174
elif name.startswith("down_blocks"):
7275
block_id = int(name[len("down_blocks.")])
7376
hidden_size = model.config.block_out_channels[block_id]
77+
7478
if cross_attention_dim is not None:
7579
sd = IPAdapterAttnProcessor(
7680
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0

tests/pipelines/test_pipelines_common.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import tempfile
99
import unittest
1010
import uuid
11-
from typing import Callable, Union
11+
from typing import Any, Callable, Dict, Union
1212

1313
import numpy as np
1414
import PIL.Image
@@ -85,46 +85,51 @@ def test_pipeline_signature(self):
8585
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
8686
return torch.zeros((2, 1, cross_attention_dim), device=torch_device)
8787

88+
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
89+
inputs["output_type"] = "np"
90+
inputs["return_dict"] = False
91+
return inputs
92+
8893
def test_ip_adapter(self, expected_max_diff: float = 1e-4):
8994
components = self.get_dummy_components()
9095
pipe = self.pipeline_class(**components).to(torch_device)
9196
pipe.set_progress_bar_config(disable=None)
9297
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
9398

9499
# forward pass without ip adapter
95-
inputs = self.get_dummy_inputs(torch_device)
96-
output_without_adapter = pipe(**inputs, return_dict=False)[0]
100+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
101+
output_without_adapter = pipe(**inputs)[0]
97102

98103
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
99104
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
100105

101106
pipe.unet._load_ip_adapter_weights(adapter_state_dict_1)
102107

103108
# forward pass with single ip adapter, but scale=0 which should have no effect
104-
inputs = self.get_dummy_inputs(torch_device)
109+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
105110
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
106111
pipe.set_ip_adapter_scale(0.0)
107-
output_without_adapter_scale = pipe(**inputs, return_dict=False)[0]
112+
output_without_adapter_scale = pipe(**inputs)[0]
108113

109114
# forward pass with single ip adapter, but with scale of adapter weights
110-
inputs = self.get_dummy_inputs(torch_device)
115+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
111116
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
112117
pipe.set_ip_adapter_scale(1.0)
113-
output_with_adapter_scale = pipe(**inputs, return_dict=False)[0]
118+
output_with_adapter_scale = pipe(**inputs)[0]
114119

115120
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
116121

117122
# forward pass with multi ip adapter, but scale=0 which should have no effect
118-
inputs = self.get_dummy_inputs(torch_device)
123+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
119124
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
120125
pipe.set_ip_adapter_scale([0.0, 0.0])
121-
output_without_multi_adapter_scale = pipe(**inputs, return_dict=False)[0]
126+
output_without_multi_adapter_scale = pipe(**inputs)[0]
122127

123128
# forward pass with multi ip adapter, but with scale of adapter weights
124-
inputs = self.get_dummy_inputs(torch_device)
129+
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
125130
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
126131
pipe.set_ip_adapter_scale([0.5, 0.5])
127-
output_with_multi_adapter_scale = pipe(**inputs, return_dict=False)[0]
132+
output_with_multi_adapter_scale = pipe(**inputs)[0]
128133

129134
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
130135
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

0 commit comments

Comments (0)