
Commit f427345

Device agnostic testing (#5612)
* utils and test modifications to enable device agnostic testing
* device for manual seed in unet1d
* fix generator condition in vae test
* consistency changes to testing
* make style
* add device agnostic testing changes to source and one model test
* make dtype check fns private, log cuda fp16 case
* remove dtype checks from import utils, move to testing_utils
* adding tests for most model classes and one pipeline
* fix vae import
1 parent 6e22133 commit f427345

11 files changed, +306 -67 lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 178 additions & 1 deletion
@@ -17,7 +17,7 @@
 from distutils.util import strtobool
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union

 import numpy as np
 import PIL.Image
@@ -58,6 +58,17 @@
 if is_torch_available():
     import torch

+    # Set a backend environment variable for any extra module import required for a custom accelerator
+    if "DIFFUSERS_TEST_BACKEND" in os.environ:
+        backend = os.environ["DIFFUSERS_TEST_BACKEND"]
+        try:
+            _ = importlib.import_module(backend)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
+                    to enable a specified backend.):\n{e}"
+            ) from e
+
     if "DIFFUSERS_TEST_DEVICE" in os.environ:
         torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
         try:
@@ -210,6 +221,36 @@ def require_torch_gpu(test_case):
     )


+# These decorators are for accelerator-specific behaviours that are not GPU-specific
+def require_torch_accelerator(test_case):
+    """Decorator marking a test that requires an accelerator backend and PyTorch."""
+    return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp16(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
+    return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp64(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
+    return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_training(test_case):
+    """Decorator marking a test that requires an accelerator with support for training."""
+    return unittest.skipUnless(
+        is_torch_available() and backend_supports_training(torch_device),
+        "test requires accelerator with training support",
+    )(test_case)
+
+
 def skip_mps(test_case):
     """Decorator marking a test to skip if torch_device is 'mps'"""
     return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
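These decorators drop into ordinary unittest classes in place of the earlier device-specific skipIf checks. A minimal sketch of the intended usage (the class and test bodies below are illustrative, not part of this commit):

import unittest

from diffusers.utils.testing_utils import (
    require_torch_accelerator,
    require_torch_accelerator_with_fp16,
    require_torch_accelerator_with_training,
    torch_device,
)


class ExampleModelTests(unittest.TestCase):
    @require_torch_accelerator
    def test_runs_on_any_accelerator(self):
        # Runs whenever torch_device is an accelerator, i.e. anything but "cpu".
        self.assertNotEqual(torch_device, "cpu")

    @require_torch_accelerator_with_fp16
    def test_fp16_inference(self):
        # Skipped unless a small fp16 matmul succeeds on torch_device.
        ...

    @require_torch_accelerator_with_training
    def test_training_step(self):
        # Skipped where the backend's training flag is False, e.g. on "mps".
        ...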
@@ -766,3 +807,139 @@ def disable_full_determinism():
     os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
     torch.use_deterministic_algorithms(False)
+
+
+# Utils for custom and alternative accelerator devices
+def _is_torch_fp16_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    device = torch.device(device)
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float16).to(device)
+        _ = x @ x
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+def _is_torch_fp64_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float64).to(device)
+        _ = x @ x
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
+if is_torch_available():
+    # Behaviour flags
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+
+    # Function definitions
+    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
+    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
+    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values. Need to guard against 'None' instead at
+    # user level
+    if fn is None:
+        return None
+
+    return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+
+    return BACKEND_SUPPORTS_TRAINING[device]
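The dispatch helpers above key on the device string and fall back to the "default" entry for unknown devices, with a None entry short-circuiting to a no-op. A quick sketch of the resulting behaviour, assuming only a stock PyTorch install (torch.cuda.manual_seed and torch.cuda.device_count are safe to call without a GPU; "xpu" here is just an illustrative unknown device string):

from diffusers.utils.testing_utils import (
    backend_device_count,
    backend_empty_cache,
    backend_manual_seed,
    backend_supports_training,
)

backend_manual_seed("cuda", 0)  # dispatches to torch.cuda.manual_seed(0)
backend_manual_seed("xpu", 0)   # not in the table -> "default": torch.manual_seed(0)
backend_empty_cache("cpu")      # table entry is None -> returns None (no-op)

print(backend_device_count("cuda"))      # torch.cuda.device_count(), 0 without CUDA
print(backend_supports_training("mps"))  # False, per BACKEND_SUPPORTS_TRAINING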
+
+
+# Guard for when Torch is not available
+if is_torch_available():
+    # Update device function dict mapping
+    def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+        try:
+            # Try to import the function directly
+            spec_fn = getattr(device_spec_module, attribute_name)
+            device_fn_dict[torch_device] = spec_fn
+        except AttributeError as e:
+            # If the function doesn't exist, and there is no default, throw an error
+            if "default" not in device_fn_dict:
+                raise AttributeError(
+                    f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+                ) from e
+
+    if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
+        device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
+        if not Path(device_spec_path).is_file():
+            raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
+
+        try:
+            import_name = device_spec_path[: device_spec_path.index(".py")]
+        except ValueError as e:
+            raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
+
+        device_spec_module = importlib.import_module(import_name)
+
+        try:
+            device_name = device_spec_module.DEVICE_NAME
+        except AttributeError:
+            raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
+
+        if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+            msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+            msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
+            raise ValueError(msg)
+
+        torch_device = device_name
+
+        # Add one entry here for each `BACKEND_*` dictionary.
+        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+        update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
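The spec file is imported as a regular module (its path is truncated at ".py" and handed to importlib), so it must be importable from the working directory. A minimal sketch of such a spec for a hypothetical Ascend NPU setup; only the attribute names queried by the loader above are fixed by this commit, while the torch_npu module and its functions are assumptions:

# spec_npu.py - hypothetical device spec, loaded via DIFFUSERS_TEST_DEVICE_SPEC="spec_npu.py"
import torch_npu  # assumed installed; pair with DIFFUSERS_TEST_BACKEND="torch_npu"

# Required by the loader
DEVICE_NAME = "npu"

# Optional: each BACKEND_* table has a "default" entry, so any of these may be omitted
MANUAL_SEED_FN = torch_npu.npu.manual_seed
EMPTY_CACHE_FN = torch_npu.npu.empty_cache
DEVICE_COUNT_FN = torch_npu.npu.device_count
SUPPORTS_TRAINING = True

Setting DIFFUSERS_TEST_DEVICE alongside the spec is optional; when present it must equal DEVICE_NAME, otherwise the loader raises the mismatch error shown above.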

tests/models/test_layers_utils.py

Lines changed: 11 additions & 12 deletions
@@ -25,7 +25,11 @@
 from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import (
+    backend_manual_seed,
+    require_torch_accelerator_with_fp64,
+    torch_device,
+)


 class EmbeddingsTests(unittest.TestCase):
@@ -315,8 +319,7 @@ def test_restnet_with_kernel_sde_vp(self):
 class Transformer2DModelTests(unittest.TestCase):
     def test_spatial_transformer_default(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         sample = torch.randn(1, 32, 64, 64).to(torch_device)
         spatial_transformer_block = Transformer2DModel(
@@ -339,8 +342,7 @@ def test_spatial_transformer_default(self):

     def test_spatial_transformer_cross_attention_dim(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         sample = torch.randn(1, 64, 64, 64).to(torch_device)
         spatial_transformer_block = Transformer2DModel(
@@ -363,8 +365,7 @@ def test_spatial_transformer_cross_attention_dim(self):

     def test_spatial_transformer_timestep(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_embeds_ada_norm = 5

@@ -401,8 +402,7 @@ def test_spatial_transformer_timestep(self):

     def test_spatial_transformer_dropout(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         sample = torch.randn(1, 32, 64, 64).to(torch_device)
         spatial_transformer_block = (
@@ -427,11 +427,10 @@ def test_spatial_transformer_dropout(self):
         )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

-    @unittest.skipIf(torch_device == "mps", "MPS does not support float64")
+    @require_torch_accelerator_with_fp64
     def test_spatial_transformer_discrete(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_embed = 5

tests/models/test_modeling_common.py

Lines changed: 4 additions & 3 deletions
@@ -35,6 +35,7 @@
     CaptureLogger,
     require_python39_or_higher,
     require_torch_2,
+    require_torch_accelerator_with_training,
     require_torch_gpu,
     run_test_in_subprocess,
     torch_device,
@@ -536,7 +537,7 @@ def test_model_from_pretrained(self):

         self.assertEqual(output_1.shape, output_2.shape)

-    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+    @require_torch_accelerator_with_training
     def test_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -553,7 +554,7 @@ def test_training(self):
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()

-    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+    @require_torch_accelerator_with_training
     def test_ema_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

@@ -624,7 +625,7 @@ def recursive_check(tuple_object, dict_object):

         recursive_check(outputs_tuple, outputs_dict)

-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+    @require_torch_accelerator_with_training
     def test_enable_disable_gradient_checkpointing(self):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing

tests/models/test_models_prior.py

Lines changed: 9 additions & 2 deletions
@@ -21,7 +21,14 @@
 from parameterized import parameterized

 from diffusers import PriorTransformer
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, slow, torch_all_close, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    floats_tensor,
+    slow,
+    torch_all_close,
+    torch_device,
+)

 from .test_modeling_common import ModelTesterMixin

@@ -157,7 +164,7 @@ def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache()

     @parameterized.expand(
         [

tests/models/test_models_unet_1d.py

Lines changed: 8 additions & 5 deletions
@@ -18,7 +18,12 @@
 import torch

 from diffusers import UNet1DModel
-from diffusers.utils.testing_utils import floats_tensor, slow, torch_device
+from diffusers.utils.testing_utils import (
+    backend_manual_seed,
+    floats_tensor,
+    slow,
+    torch_device,
+)

 from .test_modeling_common import ModelTesterMixin, UNetTesterMixin

@@ -103,8 +108,7 @@ def test_from_pretrained_hub(self):
     def test_output_pretrained(self):
         model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_features = model.config.in_channels
         seq_len = 16
@@ -244,8 +248,7 @@ def test_output_pretrained(self):
             "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
         )
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)

         num_features = value_function.config.in_channels
         seq_len = 14

tests/models/test_models_unet_2d.py

Lines changed: 3 additions & 2 deletions
@@ -24,6 +24,7 @@
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    require_torch_accelerator,
     slow,
     torch_all_close,
     torch_device,
@@ -153,15 +154,15 @@ def test_from_pretrained_hub(self):

         assert image is not None, "Make sure output is not None"

-    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    @require_torch_accelerator
     def test_from_pretrained_accelerate(self):
         model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model.to(torch_device)
         image = model(**self.dummy_input).sample

         assert image is not None, "Make sure output is not None"

-    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    @require_torch_accelerator
     def test_from_pretrained_accelerate_wont_change_results(self):
         # by defautl model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
