
Commit e567401

yiyixuxu and sayakpaul authored
adding back test_conversion_when_using_device_map (#7704)
* style

* Fix device map nits (#7705)

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent b5c8b55 commit e567401

File tree

5 files changed (+54, -49 lines)

setup.py
src/diffusers/dependency_versions_table.py
src/diffusers/models/modeling_utils.py
src/diffusers/pipelines/pipeline_loading_utils.py
tests/models/test_attention_processor.py


setup.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@
 # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
 _deps = [
     "Pillow",  # keep the PIL.Image.Resampling deprecation away
-    "accelerate>=0.11.0",
+    "accelerate>=0.29.3",
     "compel==0.1.8",
     "datasets",
     "filelock",

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 # 2. run `make deps_table_update`
 deps = {
     "Pillow": "Pillow",
-    "accelerate": "accelerate>=0.11.0",
+    "accelerate": "accelerate>=0.29.3",
     "compel": "compel==0.1.8",
     "datasets": "datasets",
     "filelock": "filelock",

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -700,6 +700,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     offload_state_dict=offload_state_dict,
                     dtype=torch_dtype,
                     force_hooks=True,
+                    strict=True,
                 )
             except AttributeError as e:
                 # When using accelerate loading, we do not have the ability to load the state
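
The new keyword lands in the accelerate checkpoint-loading call inside `from_pretrained`. A standalone sketch of strict loading, assuming accelerate>=0.29.3 exposes `strict` on `load_checkpoint_and_dispatch`; the model and checkpoint path below are placeholders, not taken from this commit:

import torch.nn as nn
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

class TinyModel(nn.Module):  # hypothetical stand-in for a diffusers model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

with init_empty_weights():
    model = TinyModel()

# With strict=True, checkpoint keys that do not match the model raise an
# error instead of being silently skipped or loaded partially.
model = load_checkpoint_and_dispatch(
    model, "tiny_model.safetensors", device_map="auto", strict=True  # placeholder path
)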

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 10 additions & 8 deletions
@@ -571,15 +571,17 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
 
     # Obtain a dictionary mapping the model-level components to the available
     # devices based on the maximum memory and the model sizes.
-    device_id_component_mapping = _assign_components_to_devices(
-        module_sizes, max_memory, device_mapping_strategy=device_map
-    )
+    final_device_map = None
+    if len(max_memory) > 0:
+        device_id_component_mapping = _assign_components_to_devices(
+            module_sizes, max_memory, device_mapping_strategy=device_map
+        )
 
-    # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
-    final_device_map = {}
-    for device_id, components in device_id_component_mapping.items():
-        for component in components:
-            final_device_map[component] = device_id
+        # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
+        final_device_map = {}
+        for device_id, components in device_id_component_mapping.items():
+            for component in components:
+                final_device_map[component] = device_id
 
     return final_device_map
 
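
With the guard, `_get_final_device_map` now returns None instead of calling `_assign_components_to_devices` on an empty `max_memory` dict (e.g. when no accelerator is visible). A toy, self-contained sketch of that control flow; the first-fit placement is only a stand-in for the library's actual assignment, shown to make the None path concrete:

# Illustration only: naive first-fit placement stands in for
# _assign_components_to_devices; sizes and budgets are arbitrary units.
def final_device_map_sketch(module_sizes, max_memory):
    final_device_map = None
    if len(max_memory) > 0:  # the guard this commit introduces
        final_device_map = {}
        # place the largest components first
        for component, size in sorted(module_sizes.items(), key=lambda kv: -kv[1]):
            for device_id in max_memory:
                if size <= max_memory[device_id]:
                    final_device_map[component] = device_id
                    max_memory[device_id] -= size
                    break
    return final_device_map

print(final_device_map_sketch({"unet": 8, "vae": 2}, {0: 10}))  # {'unet': 0, 'vae': 0}
print(final_device_map_sketch({"unet": 8, "vae": 2}, {}))       # None, instead of a crash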

tests/models/test_attention_processor.py

Lines changed: 41 additions & 39 deletions
@@ -1,7 +1,10 @@
+import tempfile
 import unittest
 
+import numpy as np
 import torch
 
+from diffusers import DiffusionPipeline
 from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
 
 
@@ -77,42 +80,41 @@ def test_only_cross_attention(self):
 
 class DeprecatedAttentionBlockTests(unittest.TestCase):
     def test_conversion_when_using_device_map(self):
-        # To-DO for Sayak: enable this test again and to test `device_map='balanced'` once we have this in accelerate https://github.com/huggingface/accelerate/pull/2641
-        pass
-        # pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)
-
-        # pre_conversion = pipe(
-        #     "foo",
-        #     num_inference_steps=2,
-        #     generator=torch.Generator("cpu").manual_seed(0),
-        #     output_type="np",
-        # ).images
-
-        # # the initial conversion succeeds
-        # pipe = DiffusionPipeline.from_pretrained(
-        #     "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
-        # )
-
-        # conversion = pipe(
-        #     "foo",
-        #     num_inference_steps=2,
-        #     generator=torch.Generator("cpu").manual_seed(0),
-        #     output_type="np",
-        # ).images
-
-        # with tempfile.TemporaryDirectory() as tmpdir:
-        #     # save the converted model
-        #     pipe.save_pretrained(tmpdir)
-
-        #     # can also load the converted weights
-        #     pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)
-
-        #     after_conversion = pipe(
-        #         "foo",
-        #         num_inference_steps=2,
-        #         generator=torch.Generator("cpu").manual_seed(0),
-        #         output_type="np",
-        #     ).images
-
-        # self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-5))
-        # self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-5))
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+
+        pre_conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        # the initial conversion succeeds
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", device_map="balanced", safety_checker=None
+        )
+
+        conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # save the converted model
+            pipe.save_pretrained(tmpdir)
+
+            # can also load the converted weights
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="balanced", safety_checker=None)
+            after_conversion = pipe(
+                "foo",
+                num_inference_steps=2,
+                generator=torch.Generator("cpu").manual_seed(0),
+                output_type="np",
+            ).images
+
+        self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
+        self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))
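
The re-enabled test round-trips a tiny checkpoint through `device_map="balanced"` and checks that outputs match before conversion, after conversion, and after reloading the saved weights. For a quick interactive look at the same flow, a usage sketch, assuming at least one visible accelerator and that the loaded pipeline exposes its placement via an `hf_device_map` attribute (an assumption mirroring transformers, not shown in this diff):

import torch
from diffusers import DiffusionPipeline

# Load the same tiny test checkpoint with pipeline-level device placement.
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch",
    device_map="balanced",
    safety_checker=None,
)
print(pipe.hf_device_map)  # e.g. {'unet': 0, 'text_encoder': 0, 'vae': 0}

images = pipe(
    "foo",
    num_inference_steps=2,
    generator=torch.Generator("cpu").manual_seed(0),
    output_type="np",
).images
print(images.shape)  # a (1, H, W, 3) numpy array for this tiny pipeline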
