From da98bb765a7c891efbcb2318bfd1c7af92b3c76e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Apr 2024 09:06:24 +0530 Subject: [PATCH 1/4] debug --- src/diffusers/pipelines/pipeline_loading_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index fd646f8e0d95..8b48fd323078 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -475,6 +475,7 @@ def _assign_components_to_devices( device_ids = list(device_memory.keys()) device_cycle = device_ids + device_ids[::-1] device_memory = device_memory.copy() + print(f"{device_memory=}, {device_ids=}, {device_cycle=}") device_id_component_mapping = {} current_device_index = 0 From d840cdae3c08c2bf83903726ae79e40dec0dfd38 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Apr 2024 09:07:16 +0530 Subject: [PATCH 2/4] debug --- src/diffusers/pipelines/pipeline_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 8b48fd323078..90bdc4507132 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -475,7 +475,7 @@ def _assign_components_to_devices( device_ids = list(device_memory.keys()) device_cycle = device_ids + device_ids[::-1] device_memory = device_memory.copy() - print(f"{device_memory=}, {device_ids=}, {device_cycle=}") + print(f"{device_memory=}, {device_ids=}, {device_cycle=}, {module_sizes=}") device_id_component_mapping = {} current_device_index = 0 From e5f988cafa83d7414b1e95bc1c5522a9c0794f67 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Apr 2024 09:13:16 +0530 Subject: [PATCH 3/4] fix --- .../pipelines/pipeline_loading_utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 90bdc4507132..5f99c2728219 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -569,15 +569,17 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic # Obtain a dictionary mapping the model-level components to the available # devices based on the maximum memory and the model sizes. - device_id_component_mapping = _assign_components_to_devices( - module_sizes, max_memory, device_mapping_strategy=device_map - ) + final_device_map = None + if len(max_memory) > 0: + device_id_component_mapping = _assign_components_to_devices( + module_sizes, max_memory, device_mapping_strategy=device_map + ) - # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}` - final_device_map = {} - for device_id, components in device_id_component_mapping.items(): - for component in components: - final_device_map[component] = device_id + # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}` + final_device_map = {} + for device_id, components in device_id_component_mapping.items(): + for component in components: + final_device_map[component] = device_id return final_device_map From 3ea4baf257287e95bbc8f5dc80cec1ba39caf17b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Apr 2024 09:16:17 +0530 Subject: [PATCH 4/4] fix --- src/diffusers/pipelines/pipeline_loading_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 5f99c2728219..03a61636991c 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -475,7 +475,6 @@ def _assign_components_to_devices( device_ids = list(device_memory.keys()) device_cycle = device_ids + device_ids[::-1] device_memory = device_memory.copy() - print(f"{device_memory=}, {device_ids=}, {device_cycle=}, {module_sizes=}") device_id_component_mapping = {} current_device_index = 0