Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,15 +568,17 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic

# Obtain a dictionary mapping the model-level components to the available
# devices based on the maximum memory and the model sizes.
device_id_component_mapping = _assign_components_to_devices(
module_sizes, max_memory, device_mapping_strategy=device_map
)
final_device_map = None
if len(max_memory) > 0:
device_id_component_mapping = _assign_components_to_devices(
module_sizes, max_memory, device_mapping_strategy=device_map
)

# Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
final_device_map = {}
for device_id, components in device_id_component_mapping.items():
for component in components:
final_device_map[component] = device_id
# Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
final_device_map = {}
for device_id, components in device_id_component_mapping.items():
for component in components:
final_device_map[component] = device_id

return final_device_map

Expand Down