Merged
Commits
21 commits
0ea501b
add accelerate to load models with smaller memory footprint
piEsposito Sep 5, 2022
7631dd6
remove low_cpu_mem_usage as it is redundant
piEsposito Sep 12, 2022
973eb23
Merge branch 'main' of github.com:huggingface/diffusers into main
piEsposito Sep 12, 2022
8592e23
move accelerate init weights context to modelling utils
piEsposito Sep 16, 2022
76b8e4a
add test to ensure results are the same when loading with accelerate
piEsposito Sep 16, 2022
dd7f9b9
add tests to ensure ram usage gets lower when using accelerate
piEsposito Sep 16, 2022
ec5f7aa
move accelerate logic to single snippet under modelling utils and rem…
piEsposito Sep 16, 2022
ae5f56d
Merge branch 'huggingface:main' into main
piEsposito Sep 16, 2022
8392e3f
format code to pass quality check
piEsposito Sep 16, 2022
615054a
fix imports with isort
piEsposito Sep 16, 2022
75c08a9
add accelerate to test extra deps
piEsposito Sep 16, 2022
7e06f3d
Merge branch 'main' into main
piEsposito Sep 16, 2022
6189b86
only import accelerate if device_map is set to auto
piEsposito Sep 21, 2022
02818b5
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Sep 21, 2022
dc14ace
Merge branch 'main' of github.com:huggingface/diffusers into main
piEsposito Sep 21, 2022
bc51061
move accelerate availability check to diffusers import utils
piEsposito Sep 22, 2022
ad1b55d
Merge remote-tracking branch 'upstream/main' into main
piEsposito Sep 22, 2022
e020d73
format code
piEsposito Sep 22, 2022
c3778bb
Merge branch 'main' into main
piEsposito Sep 22, 2022
0e2319d
Merge branch 'main' into main
piEsposito Oct 3, 2022
25e07d8
Merge branch 'main' into main
patrickvonplaten Oct 4, 2022
1 change: 1 addition & 0 deletions setup.py
@@ -104,6 +104,7 @@
"torch>=1.4",
"torchvision",
"transformers>=4.21.0",
"accelerate>=0.12.0"
]

# this is a lookup table with items like:
101 changes: 69 additions & 32 deletions src/diffusers/modeling_utils.py
@@ -21,6 +21,7 @@
import torch
from torch import Tensor, device

from diffusers.utils import is_accelerate_available
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
@@ -293,33 +294,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
from_auto_class = kwargs.pop("_from_auto", False)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
Contributor:

Sorry, I overlooked this the first time.
@piEsposito could you also add a docstring here?

Contributor:

E.g. 3-4 lines under line 264.
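A possible shape for that docstring, sketched here purely as an illustration (the exact wording was left to the PR author and is not part of this diff):

    device_map (`str`, *optional*):
        A map specifying where each submodule should go. If set to `"auto"`, the model
        is initialized with empty weights through `accelerate` and the checkpoint is
        dispatched directly onto the available device(s), which lowers the peak RAM
        needed while loading. Requires `accelerate` to be installed
        (`pip install accelerate`).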


user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}

# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
)

if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
elif torch_dtype is not None:
model = model.to(torch_dtype)

model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -391,25 +372,81 @@
)

# restore default dtype
state_dict = load_state_dict(model_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
model_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if device_map == "auto":
if is_accelerate_available():
import accelerate
else:
raise ImportError("Please install accelerate via `pip install accelerate`")

with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)

accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)

loading_info = {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
else:
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)

state_dict = load_state_dict(model_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
model_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)

if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}

if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
elif torch_dtype is not None:
model = model.to(torch_dtype)

model.register_to_config(_name_or_path=pretrained_model_name_or_path)

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
return model, loading_info

return model
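Taken together, the changes above give callers an opt-in low-memory loading path. Below is a minimal usage sketch, not part of the diff; it reuses the dummy checkpoint from the tests further down and assumes `accelerate>=0.12.0` is installed.

import torch
from diffusers import UNet2DModel

# device_map="auto" routes loading through accelerate.init_empty_weights() and
# accelerate.load_checkpoint_and_dispatch(), so the weights are not duplicated
# in CPU RAM before being placed on a device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", device_map="auto")
model.to(device)
model.eval()

# Run a forward pass the same way the tests below do.
noise = torch.randn(
    1, model.config.in_channels, model.config.sample_size, model.config.sample_size
).to(device)
timestep = torch.tensor([10]).to(device)
with torch.no_grad():
    sample = model(noise, timestep).sample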
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -23,6 +23,7 @@
USE_TF,
USE_TORCH,
DummyObject,
is_accelerate_available,
is_flax_available,
is_inflect_available,
is_modelcards_available,
11 changes: 11 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -159,6 +159,13 @@
except importlib_metadata.PackageNotFoundError:
_scipy_available = False

_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
_accelerate_version = importlib_metadata.version("accelerate")
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False


def is_torch_available():
return _torch_available
@@ -196,6 +203,10 @@ def is_scipy_available():
return _scipy_available


def is_accelerate_available():
return _accelerate_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
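The new helper follows the existing `is_*_available` pattern in import_utils.py. The guard below is the same one used in the modeling_utils.py change above, shown in isolation:

from diffusers.utils import is_accelerate_available

# Only import accelerate when the optional dependency is actually installed.
if is_accelerate_available():
    import accelerate
else:
    raise ImportError("Please install accelerate via `pip install accelerate`")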
70 changes: 70 additions & 0 deletions tests/test_models_unet.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import tracemalloc
import unittest

import torch
@@ -133,6 +135,74 @@ def test_from_pretrained_hub(self):

assert image is not None, "Make sure output is not None"

@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model.to(torch_device)
image = model(**self.dummy_input).sample

assert image is not None, "Make sure output is not None"

@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_from_pretrained_accelerate_wont_change_results(self):
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_accelerate.to(torch_device)
model_accelerate.eval()

noise = torch.randn(
1,
model_accelerate.config.in_channels,
model_accelerate.config.sample_size,
model_accelerate.config.sample_size,
generator=torch.manual_seed(0),
)
noise = noise.to(torch_device)
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

arr_accelerate = model_accelerate(noise, time_step)["sample"]

# two models don't need to stay in the device at the same time
del model_accelerate
torch.cuda.empty_cache()
gc.collect()

model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_normal_load.to(torch_device)
model_normal_load.eval()
arr_normal_load = model_normal_load(noise, time_step)["sample"]

assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3)

@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_memory_footprint_gets_reduced(self):
torch.cuda.empty_cache()
gc.collect()

tracemalloc.start()
model_accelerate, _ = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
model_accelerate.to(torch_device)
model_accelerate.eval()
_, peak_accelerate = tracemalloc.get_traced_memory()

del model_accelerate
torch.cuda.empty_cache()
gc.collect()

model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model_normal_load.to(torch_device)
model_normal_load.eval()
_, peak_normal = tracemalloc.get_traced_memory()

tracemalloc.stop()

assert peak_accelerate < peak_normal

def test_output_pretrained(self):
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
model.eval()