|
17 | 17 | from distutils.util import strtobool |
18 | 18 | from io import BytesIO, StringIO |
19 | 19 | from pathlib import Path |
20 | | -from typing import List, Optional, Union |
| 20 | +from typing import Callable, Dict, List, Optional, Union |
21 | 21 |
|
22 | 22 | import numpy as np |
23 | 23 | import PIL.Image |
|
58 | 58 | if is_torch_available(): |
59 | 59 | import torch |
60 | 60 |
|
| 61 | + # Set a backend environment variable for any extra module import required for a custom accelerator |
| 62 | + if "DIFFUSERS_TEST_BACKEND" in os.environ: |
| 63 | + backend = os.environ["DIFFUSERS_TEST_BACKEND"] |
| 64 | + try: |
| 65 | + _ = importlib.import_module(backend) |
| 66 | + except ModuleNotFoundError as e: |
| 67 | + raise ModuleNotFoundError( |
| 68 | + f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module " |
| 69 | + f"needed to enable the specified backend:\n{e}" |
| 70 | + ) from e |
| 71 | + |
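As a rough sketch of how this new hook is meant to be driven (the module name `my_torch_backend` and device string `my_device` below are placeholders, not names this diff prescribes): export the variable before the test session starts, and `testing_utils` imports the extension up front so a missing install fails loudly instead of midway through a run.

```python
import importlib
import os

# Hypothetical CI invocation (placeholder names):
#   DIFFUSERS_TEST_BACKEND="my_torch_backend" DIFFUSERS_TEST_DEVICE="my_device" python -m pytest tests/
#
# With the variable set, testing_utils effectively performs this import before resolving torch_device.
backend = os.environ.get("DIFFUSERS_TEST_BACKEND", "my_torch_backend")
try:
    importlib.import_module(backend)
except ModuleNotFoundError:
    print(f"Backend module '{backend}' is not installed, so the custom accelerator cannot be enabled.")
```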
61 | 72 | if "DIFFUSERS_TEST_DEVICE" in os.environ: |
62 | 73 | torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] |
63 | 74 | try: |
@@ -210,6 +221,36 @@ def require_torch_gpu(test_case): |
210 | 221 | ) |
211 | 222 |
|
212 | 223 |
|
| 224 | +# These decorators are for accelerator-specific behaviours that are not GPU-specific |
| 225 | +def require_torch_accelerator(test_case): |
| 226 | + """Decorator marking a test that requires an accelerator backend and PyTorch.""" |
| 227 | + return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")( |
| 228 | + test_case |
| 229 | + ) |
| 230 | + |
| 231 | + |
| 232 | +def require_torch_accelerator_with_fp16(test_case): |
| 233 | + """Decorator marking a test that requires an accelerator with support for the FP16 data type.""" |
| 234 | + return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")( |
| 235 | + test_case |
| 236 | + ) |
| 237 | + |
| 238 | + |
| 239 | +def require_torch_accelerator_with_fp64(test_case): |
| 240 | + """Decorator marking a test that requires an accelerator with support for the FP64 data type.""" |
| 241 | + return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")( |
| 242 | + test_case |
| 243 | + ) |
| 244 | + |
| 245 | + |
| 246 | +def require_torch_accelerator_with_training(test_case): |
| 247 | + """Decorator marking a test that requires an accelerator with support for training.""" |
| 248 | + return unittest.skipUnless( |
| 249 | + is_torch_available() and backend_supports_training(torch_device), |
| 250 | + "test requires accelerator with training support", |
| 251 | + )(test_case) |
| 252 | + |
| 253 | + |
213 | 254 | def skip_mps(test_case): |
214 | 255 | """Decorator marking a test to skip if torch_device is 'mps'""" |
215 | 256 | return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case) |
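A sketch of how the new decorators could be used from a test module; the test class and its assertions are invented for illustration, only the imported helpers come from this change:

```python
import unittest

import torch

from diffusers.utils.testing_utils import (
    require_torch_accelerator,
    require_torch_accelerator_with_fp16,
    torch_device,
)


class ExampleAcceleratorTests(unittest.TestCase):  # hypothetical test case
    @require_torch_accelerator
    def test_runs_on_configured_accelerator(self):
        # Skipped entirely when torch_device resolves to "cpu".
        x = torch.ones(2, 2, device=torch_device)
        self.assertEqual(x.sum().item(), 4.0)

    @require_torch_accelerator_with_fp16
    def test_half_precision_matmul(self):
        # Skipped when the fp16 probe reports no support for torch_device.
        x = torch.ones(2, 2, dtype=torch.float16, device=torch_device)
        self.assertEqual((x @ x).dtype, torch.float16)
```

Unlike `require_torch_gpu`, none of these name CUDA explicitly, so the same tests run unmodified on any device registered through the spec mechanism further down.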
@@ -766,3 +807,139 @@ def disable_full_determinism(): |
766 | 807 | os.environ["CUDA_LAUNCH_BLOCKING"] = "0" |
767 | 808 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" |
768 | 809 | torch.use_deterministic_algorithms(False) |
| 810 | + |
| 811 | + |
| 812 | +# Utils for custom and alternative accelerator devices |
| 813 | +def _is_torch_fp16_available(device): |
| 814 | + if not is_torch_available(): |
| 815 | + return False |
| 816 | + |
| 817 | + import torch |
| 818 | + |
| 819 | + device = torch.device(device) |
| 820 | + |
| 821 | + try: |
| 822 | + x = torch.zeros((2, 2), dtype=torch.float16).to(device) |
| 823 | + _ = x @ x |
| 824 | + except Exception as e: |
| 825 | + if device.type == "cuda": |
| 826 | + raise ValueError( |
| 827 | + f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}" |
| 828 | + ) |
| 829 | + |
| 830 | + return False |
| 831 | + |
| 832 | + |
| 833 | +def _is_torch_fp64_available(device): |
| 834 | + if not is_torch_available(): |
| 835 | + return False |
| 836 | + |
| 837 | + import torch |
| 838 | + |
| 839 | + try: |
| 840 | + x = torch.zeros((2, 2), dtype=torch.float64).to(device) |
| 841 | + _ = x @ x |
| 842 | + except Exception as e: |
| 843 | + if torch.device(device).type == "cuda": |
| 844 | + raise ValueError( |
| 845 | + f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}" |
| 846 | + ) |
| 847 | + |
| 848 | + return False |
| 849 | + |
| 850 | + |
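Both probes follow the same pattern: allocate a tiny tensor in the target dtype on the device, run a matmul, and treat any failure as lack of support (re-raising only for CUDA, where a failure points at a broken install rather than a missing feature). They can also be called directly for ad-hoc checks, e.g. (sketch; these helpers are module-private, so this is only meant for debugging):

```python
from diffusers.utils.testing_utils import _is_torch_fp16_available, _is_torch_fp64_available

# Result depends on the PyTorch build: older CPU builds lack half-precision matmul.
print(_is_torch_fp16_available("cpu"))

# Expected False: the MPS backend does not support float64.
print(_is_torch_fp64_available("mps"))
```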
| 851 | +# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch |
| 852 | +if is_torch_available(): |
| 853 | + # Behaviour flags |
| 854 | + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True} |
| 855 | + |
| 856 | + # Function definitions |
| 857 | + BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None} |
| 858 | + BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": lambda: 0} |
| 859 | + BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} |
| 860 | + |
| 861 | + |
| 862 | +# Dispatch the entry registered for `device` in `dispatch_table`, falling back to the "default" entry for unknown devices. |
| 863 | +def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): |
| 864 | + if device not in dispatch_table: |
| 865 | + device = "default" |
| 866 | + |
| 867 | + fn = dispatch_table[device] |
| 868 | + |
| 869 | + # Some backends register `None` for operations they do not implement; treat those as no-ops |
| 870 | + # and return `None` so the caller can handle the missing feature. |
| 871 | + if fn is None: |
| 872 | + return None |
| 873 | + |
| 874 | + return fn(*args, **kwargs) |
| 875 | + |
| 876 | + |
| 877 | +# These are callables which automatically dispatch the function specific to the accelerator |
| 878 | +def backend_manual_seed(device: str, seed: int): |
| 879 | + return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) |
| 880 | + |
| 881 | + |
| 882 | +def backend_empty_cache(device: str): |
| 883 | + return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) |
| 884 | + |
| 885 | + |
| 886 | +def backend_device_count(device: str): |
| 887 | + return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) |
| 888 | + |
| 889 | + |
| 890 | +# These are callables which return boolean behaviour flags and can be used to specify some |
| 891 | +# device agnostic alternative where the feature is unsupported. |
| 892 | +def backend_supports_training(device: str): |
| 893 | + if not is_torch_available(): |
| 894 | + return False |
| 895 | + |
| 896 | + if device not in BACKEND_SUPPORTS_TRAINING: |
| 897 | + device = "default" |
| 898 | + |
| 899 | + return BACKEND_SUPPORTS_TRAINING[device] |
| 900 | + |
| 901 | + |
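In test bodies these helpers stand in for direct `torch.cuda.*` calls, so the same code path works on CUDA, MPS, CPU, or a device registered through the spec file handled below. A sketch (the test case itself is invented):

```python
import unittest

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_manual_seed,
    backend_supports_training,
    torch_device,
)


class ExampleDeviceAgnosticTests(unittest.TestCase):  # hypothetical test case
    def setUp(self):
        backend_manual_seed(torch_device, 0)  # torch.cuda.manual_seed on CUDA, torch.manual_seed otherwise

    def tearDown(self):
        backend_empty_cache(torch_device)  # no-op where the table entry is None, e.g. "cpu"

    def test_training_step(self):
        if not backend_supports_training(torch_device):
            self.skipTest(f"{torch_device} does not support training")  # e.g. "mps" is flagged False
        ...  # training-specific assertions would go here
```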
| 902 | +# Guard for when Torch is not available |
| 903 | +if is_torch_available(): |
| 904 | + # Update device function dict mapping |
| 905 | + def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): |
| 906 | + try: |
| 907 | + # Look the attribute up on the imported device spec module |
| 908 | + spec_fn = getattr(device_spec_module, attribute_name) |
| 909 | + device_fn_dict[torch_device] = spec_fn |
| 910 | + except AttributeError as e: |
| 911 | + # If the function doesn't exist, and there is no default, throw an error |
| 912 | + if "default" not in device_fn_dict: |
| 913 | + raise AttributeError( |
| 914 | + f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." |
| 915 | + ) from e |
| 916 | + |
| 917 | + if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ: |
| 918 | + device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"] |
| 919 | + if not Path(device_spec_path).is_file(): |
| 920 | + raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}") |
| 921 | + |
| 922 | + try: |
| 923 | + import_name = device_spec_path[: device_spec_path.index(".py")] |
| 924 | + except ValueError as e: |
| 925 | + raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e |
| 926 | + |
| 927 | + device_spec_module = importlib.import_module(import_name) |
| 928 | + |
| 929 | + try: |
| 930 | + device_name = device_spec_module.DEVICE_NAME |
| 931 | + except AttributeError: |
| 932 | + raise AttributeError("Device spec file did not contain `DEVICE_NAME`") |
| 933 | + |
| 934 | + if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name: |
| 935 | + msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" |
| 936 | + msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name." |
| 937 | + raise ValueError(msg) |
| 938 | + |
| 939 | + torch_device = device_name |
| 940 | + |
| 941 | + # Add one entry here for each `BACKEND_*` dictionary. |
| 942 | + update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") |
| 943 | + update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") |
| 944 | + update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") |
| 945 | + update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING") |
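Putting the pieces together, a device spec file for a hypothetical accelerator only needs to define the attribute names consumed above (`DEVICE_NAME`, `MANUAL_SEED_FN`, `EMPTY_CACHE_FN`, `DEVICE_COUNT_FN`, `SUPPORTS_TRAINING`); everything below except those names is a placeholder:

```python
# my_device_spec.py -- hypothetical spec file, selected via
#   DIFFUSERS_TEST_DEVICE_SPEC="my_device_spec.py" python -m pytest tests/
import torch

# Must match DIFFUSERS_TEST_DEVICE if that variable is also set; otherwise it becomes torch_device.
DEVICE_NAME = "my_device"

# Callables slotted into the BACKEND_* dispatch tables under the new device name.
MANUAL_SEED_FN = torch.manual_seed  # a real backend would point at its own seeding API
EMPTY_CACHE_FN = lambda: None       # placeholder: nothing to release for this fake device
DEVICE_COUNT_FN = lambda: 1         # placeholder: pretend a single device is present

# Plain flag slotted into BACKEND_SUPPORTS_TRAINING.
SUPPORTS_TRAINING = True
```

Because `update_mapping_from_spec` keys every table on `torch_device`, the dispatch helpers above pick these entries up automatically once the spec module has been imported.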