|
8 | 8 | import tempfile |
9 | 9 | import unittest |
10 | 10 | import uuid |
11 | | -from typing import Callable, Union |
| 11 | +from typing import Any, Callable, Dict, Union |
12 | 12 |
|
13 | 13 | import numpy as np |
14 | 14 | import PIL.Image |
@@ -85,46 +85,51 @@ def test_pipeline_signature(self): |
85 | 85 | def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): |
86 | 86 | return torch.zeros((2, 1, cross_attention_dim), device=torch_device) |
87 | 87 |
|
| 88 | + def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): |
| 89 | + inputs["output_type"] = "np" |
| 90 | + inputs["return_dict"] = False |
| 91 | + return inputs |
| 92 | + |
88 | 93 | def test_ip_adapter(self, expected_max_diff: float = 1e-4): |
89 | 94 | components = self.get_dummy_components() |
90 | 95 | pipe = self.pipeline_class(**components).to(torch_device) |
91 | 96 | pipe.set_progress_bar_config(disable=None) |
92 | 97 | cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) |
93 | 98 |
|
94 | 99 | # forward pass without ip adapter |
95 | | - inputs = self.get_dummy_inputs(torch_device) |
96 | | - output_without_adapter = pipe(**inputs, return_dict=False)[0] |
| 100 | + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) |
| 101 | + output_without_adapter = pipe(**inputs)[0] |
97 | 102 |
|
98 | 103 | adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet) |
99 | 104 | adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet) |
100 | 105 |
|
101 | 106 | pipe.unet._load_ip_adapter_weights(adapter_state_dict_1) |
102 | 107 |
|
103 | 108 | # forward pass with single ip adapter, but scale=0 which should have no effect |
104 | | - inputs = self.get_dummy_inputs(torch_device) |
| 109 | + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) |
105 | 110 | inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] |
106 | 111 | pipe.set_ip_adapter_scale(0.0) |
107 | | - output_without_adapter_scale = pipe(**inputs, return_dict=False)[0] |
| 112 | + output_without_adapter_scale = pipe(**inputs)[0] |
108 | 113 |
|
109 | 114 | # forward pass with single ip adapter, but with scale of adapter weights |
110 | | - inputs = self.get_dummy_inputs(torch_device) |
| 115 | + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) |
111 | 116 | inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] |
112 | 117 | pipe.set_ip_adapter_scale(1.0) |
113 | | - output_with_adapter_scale = pipe(**inputs, return_dict=False)[0] |
| 118 | + output_with_adapter_scale = pipe(**inputs)[0] |
114 | 119 |
|
115 | 120 | pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2]) |
116 | 121 |
|
117 | 122 | # forward pass with multi ip adapter, but scale=0 which should have no effect |
118 | | - inputs = self.get_dummy_inputs(torch_device) |
| 123 | + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) |
119 | 124 | inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 |
120 | 125 | pipe.set_ip_adapter_scale([0.0, 0.0]) |
121 | | - output_without_multi_adapter_scale = pipe(**inputs, return_dict=False)[0] |
| 126 | + output_without_multi_adapter_scale = pipe(**inputs)[0] |
122 | 127 |
|
123 | 128 | # forward pass with multi ip adapter, but with scale of adapter weights |
124 | | - inputs = self.get_dummy_inputs(torch_device) |
| 129 | + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) |
125 | 130 | inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2 |
126 | 131 | pipe.set_ip_adapter_scale([0.5, 0.5]) |
127 | | - output_with_multi_adapter_scale = pipe(**inputs, return_dict=False)[0] |
| 132 | + output_with_multi_adapter_scale = pipe(**inputs)[0] |
128 | 133 |
|
129 | 134 | max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() |
130 | 135 | max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() |
|
0 commit comments