diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index b5ac6cc3012f..2c611aa2c033 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -613,6 +613,9 @@ def _assign_components_to_devices(
 
 
 def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+    # TODO: separate out the different device_map methods when needed.
+    if device_map != "balanced":
+        return device_map
     # To avoid circular import problem.
     from diffusers import pipelines
 
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 22efaccec140..d231989973e4 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -108,7 +108,7 @@ for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)
 
 
-SUPPORTED_DEVICE_MAP = ["balanced"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
 
 logger = logging.get_logger(__name__)
 
@@ -988,12 +988,15 @@ def load_module(name, value):
 
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
             # 7.1 device_map shenanigans
-            if final_device_map is not None and len(final_device_map) > 0:
-                component_device = final_device_map.get(name, None)
-                if component_device is not None:
-                    current_device_map = {"": component_device}
-                else:
-                    current_device_map = None
+            if final_device_map is not None:
+                if isinstance(final_device_map, dict) and len(final_device_map) > 0:
+                    component_device = final_device_map.get(name, None)
+                    if component_device is not None:
+                        current_device_map = {"": component_device}
+                    else:
+                        current_device_map = None
+                elif isinstance(final_device_map, str):
+                    current_device_map = final_device_map
 
             # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
             class_name = class_name[4:] if class_name.startswith("Flax") else class_name
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index dd54cb2b9186..5bc708a60c29 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -15,6 +15,7 @@
 PyTorch utilities: Utilities related to PyTorch
 """
 
+import functools
 from typing import List, Optional, Tuple, Union
 
 from . import logging
@@ -168,6 +169,7 @@ def get_torch_cuda_device_capability():
     return None
 
 
+@functools.lru_cache
 def get_device():
     if torch.cuda.is_available():
         return "cuda"
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 387eb6a614f9..ed6a56c5faf5 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -2339,6 +2339,29 @@ def test_torch_dtype_dict(self):
                 f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
             )
 
+    @require_torch_accelerator
+    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        torch.manual_seed(0)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = torch.manual_seed(0)
+        out = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
+            for component in loaded_pipe.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            inputs["generator"] = torch.manual_seed(0)
+            loaded_out = loaded_pipe(**inputs)[0]
+            max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
+            self.assertLess(max_diff, expected_max_difference)
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
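
Taken together, these changes let `from_pretrained` accept a plain accelerator string (for example `"cuda"`) as `device_map`, in addition to `"balanced"`. Below is a minimal usage sketch of that behavior; the checkpoint id is an assumption for illustration and is not part of this diff.

```python
# Minimal sketch of the behavior this diff enables: passing a plain accelerator
# string as `device_map` instead of only "balanced".
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # hypothetical example checkpoint, not from the diff
    torch_dtype=torch.float16,
    device_map="cuda",  # accepted now that SUPPORTED_DEVICE_MAP includes get_device()
)
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```

With a string `device_map`, `_get_final_device_map` returns the string unchanged and each component is loaded with that device as its `device_map`, which is what the new `test_pipeline_with_accelerator_device_map` test exercises via `device_map=torch_device`.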