Commit 85a916b
make group offloading work with disk/nvme transfers (#11682)
* start implementing disk offloading in group.
* delete diff file.
* updates.patch
* offload_to_disk_path
* check if safetensors already exist.
* add test and clarify.
* updates
* update todos.
* update more docs.
* update docs
1 parent 3287ce2 commit 85a916b

File tree: 4 files changed (+134, -14 lines)

docs/source/en/optimization/memory.md (7 additions, 0 deletions)

@@ -302,6 +302,13 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o
 
 </Tip>
 
+### Offloading to disk
+
+Group offloading can consume significant system RAM depending on the model size. In limited RAM environments,
+it can be useful to offload to disk instead. You can do this by setting the `offload_to_disk_path`
+argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer to [here](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
+[here](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
+
 ## Layerwise casting
 
 Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
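
In practice, the doc addition above corresponds to a call like the following. This is a minimal usage sketch rather than part of the commit; the checkpoint id, device, and offload directory are placeholder assumptions, and any diffusers model that supports group offloading can stand in for the one shown.

import torch
from diffusers import AutoModel

# Placeholder checkpoint id; substitute a real repository.
model = AutoModel.from_pretrained("org/some-diffusion-model", subfolder="transformer")
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    # New in this commit: parameter groups are written to this directory as safetensors
    # files instead of being held in system RAM between uses.
    offload_to_disk_path="/path/to/nvme/offload_dir",
)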

src/diffusers/hooks/group_offloading.py (87 additions, 6 deletions)
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union
 
+import safetensors.torch
 import torch
 
 from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ def __init__(
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,7 +75,26 @@ def __init__(
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self.offload_to_disk_path = offload_to_disk_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk_path:
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()
 
         if self.stream is None and self.record_stream:
             raise ValueError("`record_stream` cannot be True when `stream` is None.")
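
The disk branch above gives every parameter and buffer in the group a stable string key so the whole group can round-trip through a single safetensors file. The following standalone sketch (plain torch.nn modules, no hook machinery) illustrates that dedup-and-key step; the modules are arbitrary examples, not taken from the commit.

import torch.nn as nn

# Arbitrary example modules standing in for a ModuleGroup's `self.modules`.
modules = [nn.Linear(4, 4), nn.LayerNorm(4)]

all_tensors = []
for module in modules:
    all_tensors.extend(list(module.parameters()))
    all_tensors.extend(list(module.buffers()))
# dict.fromkeys preserves order while dropping duplicates (e.g. tied or shared tensors).
all_tensors = list(dict.fromkeys(all_tensors))

tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
key_to_tensor = {key: tensor for tensor, key in tensor_to_key.items()}
print(list(tensor_to_key.values()))  # ['tensor_0', 'tensor_1', 'tensor_2', 'tensor_3']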
@@ -124,6 +146,30 @@ def onload_(self):
         context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
         current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
 
+        if self.offload_to_disk_path:
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+            with context:
+                if self.stream is not None:
+                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
+                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
+                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            tensor_obj.data.record_stream(current_stream)
+                else:
+                    # Load directly to the target device (synchronous)
+                    onload_device = (
+                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                    )
+                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        tensor_obj.data = loaded_tensors[key]
+            return
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
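
The streamed branch added above follows a standard overlap pattern: deserialize to CPU, pin the host copy, then issue a non-blocking host-to-device copy on a side stream so the transfer can overlap with compute. Below is a stripped-down CUDA sketch of that pattern outside the hook machinery; the file path and tensor are placeholders.

import safetensors.torch
import torch

# Placeholder file standing in for one offloaded group.
path = "/tmp/group_example.safetensors"
safetensors.torch.save_file({"w": torch.randn(1024, 1024)}, path)

transfer_stream = torch.cuda.Stream()
with torch.cuda.stream(transfer_stream):
    cpu_tensors = safetensors.torch.load_file(path, device="cpu")
    pinned = cpu_tensors["w"].pin_memory()          # page-locked staging copy
    w_gpu = pinned.to("cuda", non_blocking=True)    # async H2D copy issued on transfer_stream
# Make the default stream wait for the transfer before the tensor is used for compute.
torch.cuda.current_stream().wait_stream(transfer_stream)
print(w_gpu.device)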
@@ -169,6 +215,26 @@ def onload_(self):
     @torch.compiler.disable()
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+            # we perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+                tensors_to_save = {
+                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
+                }
+                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+            # The group is now considered offloaded to disk for the rest of the session.
+            self._is_offloaded_to_disk = True
+
+            # We do this to free up the RAM which is still holding the tensor data.
+            for tensor_obj in self.tensor_to_key.keys():
+                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+            return
 
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
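
The write path above only hits the disk the first time a group is offloaded; afterwards the live tensors are swapped for uninitialized placeholders so their storage can be reclaimed while the values remain recoverable from the safetensors file. A small standalone sketch of that write-once-then-release idea (directory and tensors are placeholders, not the hook itself):

import os

import safetensors.torch
import torch

offload_dir = "/tmp/offload_example"                    # placeholder directory
file_path = os.path.join(offload_dir, "group_example.safetensors")
params = {"tensor_0": torch.randn(2048, 2048)}          # stand-in for one group's tensors

if not os.path.exists(file_path):                       # pay the disk write only once per session
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    safetensors.torch.save_file({k: v.data.to("cpu") for k, v in params.items()}, file_path)

# Swap in uninitialized placeholders: the original storage is released and the values
# now live only in the safetensors file, which is what keeps memory usage low.
for tensor in params.values():
    tensor.data = torch.empty_like(tensor.data, device="cpu")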
@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
-    ) -> None:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
         self.group = group
         self.next_group = next_group
 
@@ -363,6 +425,7 @@ def apply_group_offloading(
     use_stream: bool = False,
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -401,6 +464,9 @@ def apply_group_offloading(
         offload_type (`str`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            limited RAM environments where a reasonable speed-memory trade-off is desired.
         num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level".
@@ -458,6 +524,7 @@ def apply_group_offloading(
            num_blocks_per_group=num_blocks_per_group,
            offload_device=offload_device,
            onload_device=onload_device,
+           offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
@@ -468,6 +535,7 @@ def apply_group_offloading(
            module=module,
            offload_device=offload_device,
            onload_device=onload_device,
+           offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
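
Because both call sites above now forward offload_to_disk_path into the module groups, the hooks-level entry point can be used directly on any torch.nn.Module. A hedged sketch follows (toy module and placeholder path, not from the commit):

import torch
import torch.nn as nn
from diffusers.hooks import apply_group_offloading

# Toy module: "block_level" groups the children of nn.ModuleList / nn.Sequential containers.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = ToyModel()
apply_group_offloading(
    module=model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,                        # required for "block_level"
    offload_to_disk_path="/path/to/offload_dir",   # placeholder directory
)
out = model(torch.randn(1, 64, device="cuda"))     # groups are read back from disk as needed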
@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
            The module to which group offloading is applied.
         offload_device (`torch.device`):
            The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            limited RAM environments where a reasonable speed-memory trade-off is desired.
         onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
         non_blocking (`bool`):
@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
            modules=current_modules,
            offload_device=offload_device,
            onload_device=onload_device,
+           offload_to_disk_path=offload_to_disk_path,
            offload_leader=current_modules[-1],
            onload_leader=current_modules[0],
            non_blocking=non_blocking,
@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
            modules=unmatched_modules,
            offload_device=offload_device,
            onload_device=onload_device,
+           offload_to_disk_path=offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=parameters,
@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
     stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
            The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            limited RAM environments where a reasonable speed-memory trade-off is desired.
         non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
            modules=[submodule],
            offload_device=offload_device,
            onload_device=onload_device,
+           offload_to_disk_path=offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
            onload_device=onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
+           offload_to_disk_path=offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
+           offload_to_disk_path=offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,

src/diffusers/models/modeling_utils.py (10 additions, 8 deletions)
@@ -548,6 +548,7 @@ def enable_group_offload(
         use_stream: bool = False,
         record_stream: bool = False,
         low_cpu_mem_usage=False,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -588,15 +589,16 @@ def enable_group_offload(
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self,
-            onload_device,
-            offload_device,
-            offload_type,
-            num_blocks_per_group,
-            non_blocking,
-            use_stream,
-            record_stream,
+            module=self,
+            onload_device=onload_device,
+            offload_device=offload_device,
+            offload_type=offload_type,
+            num_blocks_per_group=num_blocks_per_group,
+            non_blocking=non_blocking,
+            use_stream=use_stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            offload_to_disk_path=offload_to_disk_path,
         )
 
     def save_pretrained(
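
With everything forwarded by keyword, the new option composes with the other ModelMixin memory helpers, for example the layerwise casting mentioned in the docs above (the two are exercised together in the test suite). A sketch assuming a CUDA device; the checkpoint id and offload directory are placeholders:

import torch
from diffusers import AutoModel

model = AutoModel.from_pretrained("org/some-diffusion-model", subfolder="transformer")  # placeholder id
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk_path="/path/to/offload_dir",   # placeholder directory
)
# Store weights in fp8 on top of the disk-backed offloading, upcasting for compute.
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)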

tests/models/test_modeling_common.py (30 additions, 0 deletions)
@@ -15,6 +15,7 @@
 
 import copy
 import gc
+import glob
 import inspect
 import json
 import os
@@ -1693,6 +1694,35 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
         model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
         _ = model(**inputs_dict)[0]
 
+    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+    @require_torch_accelerator
+    @torch.no_grad()
+    def test_group_offloading_with_disk(self, record_stream, offload_type):
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.eval()
+        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.enable_group_offload(
+                torch_device,
+                offload_type=offload_type,
+                offload_to_disk_path=tmpdir,
+                use_stream=True,
+                record_stream=record_stream,
+                **additional_kwargs,
+            )
+            has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
+            assert has_safetensors, "No safetensors found in the directory."
+            _ = model(**inputs_dict)[0]
+
     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
             model = self.model_class(**self.init_dict)
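
Outside the test suite, the same check is a quick way to confirm the feature is active: after enable_group_offload with offload_to_disk_path, the directory already contains group_*.safetensors files, one per offloaded group, before any forward pass. A user-facing sketch along the same lines (the checkpoint id is a placeholder and a CUDA device is assumed):

import glob
import tempfile

import torch
from diffusers import AutoModel

model = AutoModel.from_pretrained("org/some-diffusion-model", subfolder="transformer").eval()  # placeholder id
with tempfile.TemporaryDirectory() as tmpdir:
    model.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_type="leaf_level",
        offload_to_disk_path=tmpdir,
    )
    # The files are written when offloading is applied, before any forward pass.
    print(glob.glob(f"{tmpdir}/group_*.safetensors"))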
