From eba6f827b934bb662eb3d19b501e1180967f7875 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Aug 2025 11:37:22 +0530 Subject: [PATCH 1/5] feat: cuda device_map for pipelines. --- .../pipelines/pipeline_loading_utils.py | 3 +++ src/diffusers/pipelines/pipeline_utils.py | 18 +++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index b5ac6cc3012f..2c611aa2c033 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -613,6 +613,9 @@ def _assign_components_to_devices( def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs): + # TODO: separate out different device_map methods when it gets to it. + if device_map != "balanced": + return device_map # To avoid circular import problem. from diffusers import pipelines diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 22efaccec140..a2811aa1e2dc 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -108,7 +108,8 @@ for library in LOADABLE_CLASSES: LIBRARIES.append(library) -SUPPORTED_DEVICE_MAP = ["balanced"] +# TODO: support single-device namings +SUPPORTED_DEVICE_MAP = ["balanced", "cuda"] logger = logging.get_logger(__name__) @@ -988,12 +989,15 @@ def load_module(name, value): _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config) for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): # 7.1 device_map shenanigans - if final_device_map is not None and len(final_device_map) > 0: - component_device = final_device_map.get(name, None) - if component_device is not None: - current_device_map = {"": component_device} - else: - current_device_map = None + if final_device_map is not None: + 
if isinstance(final_device_map, dict) and len(final_device_map) > 0: + component_device = final_device_map.get(name, None) + if component_device is not None: + current_device_map = {"": component_device} + else: + current_device_map = None + elif isinstance(final_device_map, str): + current_device_map = final_device_map # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names class_name = class_name[4:] if class_name.startswith("Flax") else class_name From a31e59ed79a07852801038791acf88d1a5d4181a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Aug 2025 21:46:17 +0530 Subject: [PATCH 2/5] up --- src/diffusers/pipelines/pipeline_utils.py | 3 ++- tests/pipelines/test_pipelines_common.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a2811aa1e2dc..b1af81cb1e85 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -67,6 +67,7 @@ numpy_to_pil, ) from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card +from ..utils.testing_utils import torch_device from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module @@ -109,7 +110,7 @@ LIBRARIES.append(library) # TODO: support single-device namings -SUPPORTED_DEVICE_MAP = ["balanced", "cuda"] +SUPPORTED_DEVICE_MAP = ["balanced"] + [torch_device] logger = logging.get_logger(__name__) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 387eb6a614f9..945819895c06 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2339,6 +2339,26 @@ def test_torch_dtype_dict(self): f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}", ) + @require_torch_accelerator + def 
test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + inputs["generator"] = torch.manual_seed(0) + out = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device) + inputs = self.get_dummy_inputs(torch_device) + loaded_out = loaded_pipe(**inputs)[0] + max_diff = np.abs(to_np(out) - to_np(loaded_out)).max() + self.assertLess(max_diff, expected_max_difference) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): From 30c575be1f1d2d31098783a561e76f7f71dda5b6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 12 Aug 2025 22:00:16 +0530 Subject: [PATCH 3/5] up --- tests/pipelines/test_pipelines_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 945819895c06..ed6a56c5faf5 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2354,7 +2354,10 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4 with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device) - inputs = self.get_dummy_inputs(torch_device) + for component in loaded_pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + inputs["generator"] = torch.manual_seed(0) loaded_out = loaded_pipe(**inputs)[0] max_diff = np.abs(to_np(out) - to_np(loaded_out)).max() self.assertLess(max_diff, expected_max_difference) From 3dd4fb23cd55256e3aae46d7c4a5e33d2a18f974 Mon Sep 
17 00:00:00 2001 From: sayakpaul Date: Wed, 13 Aug 2025 08:09:04 +0530 Subject: [PATCH 4/5] empty From f657b6bd5858a35c55fab42ca3f47d79bfa34f08 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 13 Aug 2025 14:00:36 +0530 Subject: [PATCH 5/5] up --- src/diffusers/pipelines/pipeline_utils.py | 4 +--- src/diffusers/utils/torch_utils.py | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index b1af81cb1e85..d231989973e4 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -67,7 +67,6 @@ numpy_to_pil, ) from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card -from ..utils.testing_utils import torch_device from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module @@ -109,8 +108,7 @@ for library in LOADABLE_CLASSES: LIBRARIES.append(library) -# TODO: support single-device namings -SUPPORTED_DEVICE_MAP = ["balanced"] + [torch_device] +SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()] logger = logging.get_logger(__name__) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index dd54cb2b9186..5bc708a60c29 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -15,6 +15,7 @@ PyTorch utilities: Utilities related to PyTorch """ +import functools from typing import List, Optional, Tuple, Union from . import logging @@ -168,6 +169,7 @@ def get_torch_cuda_device_capability(): return None +@functools.lru_cache def get_device(): if torch.cuda.is_available(): return "cuda"