@@ -23,7 +23,6 @@
 from diffusers.utils import (
     floats_tensor,
     logging,
-    torch_all_close,
     torch_device,
 )
 from diffusers.utils.import_utils import is_xformers_available
@@ -120,47 +119,6 @@ def test_xformers_enable_works(self):
             == "XFormersAttnProcessor"
         ), "xformers is not enabled"
 
-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-4)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-4))
-
     # Overriding because `block_out_channels` needs to be different for this model.
     def test_forward_with_norm_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -209,44 +167,6 @@ def test_determinism(self):
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)
 
-    def test_model_with_attention_head_dim_tuple(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["attention_head_dim"] = (8, 16, 16, 16)
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-        model.eval()
-
-        with torch.no_grad():
-            output = model(**inputs_dict)
-
-            if isinstance(output, dict):
-                output = output.sample
-
-        self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
-    def test_model_with_use_linear_projection(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["use_linear_projection"] = True
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-        model.eval()
-
-        with torch.no_grad():
-            output = model(**inputs_dict)
-
-            if isinstance(output, dict):
-                output = output.sample
-
-        self.assertIsNotNone(output)
-        expected_shape = inputs_dict["sample"].shape
-        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -271,81 +191,6 @@ def test_model_attention_slicing(self):
             output = model(**inputs_dict)
         assert output is not None
 
-    def test_model_slicable_head_dim(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["attention_head_dim"] = (8, 16, 16, 16)
-
-        model = self.model_class(**init_dict)
-
-        def check_slicable_dim_attr(module: torch.nn.Module):
-            if hasattr(module, "set_attention_slice"):
-                assert isinstance(module.sliceable_head_dim, int)
-
-            for child in module.children():
-                check_slicable_dim_attr(child)
-
-        # retrieve number of attention layers
-        for module in model.children():
-            check_slicable_dim_attr(module)
-
-    def test_special_attn_proc(self):
-        class AttnEasyProc(torch.nn.Module):
-            def __init__(self, num):
-                super().__init__()
-                self.weight = torch.nn.Parameter(torch.tensor(num))
-                self.is_run = False
-                self.number = 0
-                self.counter = 0
-
-            def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-                batch_size, sequence_length, _ = hidden_states.shape
-                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-                query = attn.to_q(hidden_states)
-
-                encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-                key = attn.to_k(encoder_hidden_states)
-                value = attn.to_v(encoder_hidden_states)
-
-                query = attn.head_to_batch_dim(query)
-                key = attn.head_to_batch_dim(key)
-                value = attn.head_to_batch_dim(value)
-
-                attention_probs = attn.get_attention_scores(query, key, attention_mask)
-                hidden_states = torch.bmm(attention_probs, value)
-                hidden_states = attn.batch_to_head_dim(hidden_states)
-
-                # linear proj
-                hidden_states = attn.to_out[0](hidden_states)
-                # dropout
-                hidden_states = attn.to_out[1](hidden_states)
-
-                hidden_states += self.weight
-
-                self.is_run = True
-                self.counter += 1
-                self.number = number
-
-                return hidden_states
-
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["attention_head_dim"] = (8, 16, 16, 16)
-
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        processor = AttnEasyProc(5.0)
-
-        model.set_attn_processor(processor)
-        model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
-
-        assert processor.counter == 12
-        assert processor.is_run
-        assert processor.number == 123
-
     # (`attn_processors`) needs to be implemented in this model for this test.
     # def test_lora_processors(self):