Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f51cccc
add method to enable cuda with minimal gpu usage to stable diffusion
piEsposito Oct 14, 2022
88acaf3
add test to minimal cuda memory usage
piEsposito Oct 14, 2022
c07834f
Merge branch 'main' into main
piEsposito Oct 14, 2022
a065897
ensure all models but unet are on torch.float32
piEsposito Oct 17, 2022
4b4c69a
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Oct 17, 2022
0b601e6
Merge branch 'main' into main
piEsposito Oct 17, 2022
1501171
Merge branch 'main' into main
piEsposito Oct 18, 2022
9af3535
Merge branch 'main' into main
piEsposito Oct 19, 2022
965dfe1
move to cpu_offload along with minor internal changes to make it work
piEsposito Oct 20, 2022
6e81ac5
Merge branch 'main' of github.com:piEsposito/diffusers into main
piEsposito Oct 20, 2022
c7b3646
make it test against accelerate master branch
piEsposito Oct 20, 2022
45a61d8
coming back, it's official: I don't know how to make it test against th…
piEsposito Oct 20, 2022
3af822b
make it install accelerate from master on tests
piEsposito Oct 20, 2022
358f59a
go back to accelerate>=0.11
piEsposito Oct 20, 2022
9793bc8
undo prettier formatting on yml files
piEsposito Oct 20, 2022
d1f00ba
undo prettier formatting on yml files again
piEsposito Oct 20, 2022
c728184
Merge branch 'main' into main
piEsposito Oct 20, 2022
b6ebfb4
Merge branch 'main' into main
piEsposito Oct 21, 2022
49a1711
Merge branch 'main' into main
piEsposito Oct 23, 2022
b00fe03
Merge branch 'main' into main
piEsposito Oct 24, 2022
ce745c9
Merge branch 'main' into main
piEsposito Oct 25, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
run: |
Expand Down Expand Up @@ -80,6 +81,7 @@ jobs:
${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
shell: arch -arch arm64 bash {0}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
python -m pip uninstall -y torch torchvision torchtext
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
run: |
Expand All @@ -58,8 +59,6 @@ jobs:
name: torch_test_reports
path: reports



run_examples_single_gpu:
name: Examples tests
runs-on: [ self-hosted, docker-gpu, single-gpu ]
Expand All @@ -83,6 +82,7 @@ jobs:
python -m pip uninstall -y torch torchvision torchtext
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate

- name: Environment
run: |
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def device(self) -> torch.device:
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
if module.device == torch.device("meta"):
return torch.device("cpu")
return module.device
return torch.device("cpu")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
Expand Down Expand Up @@ -118,6 +119,18 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def cuda_with_minimal_gpu_usage(self):
    r"""
    Run the pipeline on CUDA while keeping GPU memory usage as low as possible.

    Enables attention slicing with a slice size of 1 and then passes `unet`, `text_encoder`, `vae` and
    `safety_checker` to accelerate's `cpu_offload`, using the `cuda` device as the execution device.
    Components that are `None` are skipped.

    Raises:
        ImportError: If `accelerate` is not installed.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    # Imported lazily so that accelerate stays an optional dependency.
    from accelerate import cpu_offload

    device = torch.device("cuda")
    self.enable_attention_slicing(1)

    for cpu_offloaded_model in (self.unet, self.text_encoder, self.vae, self.safety_checker):
        # A component may be unset (e.g. a disabled safety checker); `cpu_offload` needs a real module.
        if cpu_offloaded_model is not None:
            cpu_offload(cpu_offloaded_model, device)

@torch.no_grad()
def __call__(
self,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,23 @@ def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
tracemalloc.stop()

assert peak_accelerate < peak_normal

@slow
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
    """Check that `cuda_with_minimal_gpu_usage()` keeps peak allocated CUDA memory below 0.8 GB."""
    # Start from a clean slate so the peak reading reflects this test only.
    torch.cuda.empty_cache()
    # `reset_max_memory_allocated` is deprecated and merely forwards to this call.
    torch.cuda.reset_peak_memory_stats()

    pipeline_id = "CompVis/stable-diffusion-v1-4"
    prompt = "Andromeda galaxy in a bottle"

    pipeline = StableDiffusionPipeline.from_pretrained(
        pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True
    )
    pipeline.cuda_with_minimal_gpu_usage()

    _ = pipeline(prompt)

    mem_bytes = torch.cuda.max_memory_allocated()
    # make sure that less than 0.8 GB is allocated
    assert mem_bytes < 0.8 * 10**9, f"peak CUDA memory was {mem_bytes} bytes"