# See the License for the specific language governing permissions and
# limitations under the License.

+import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union

+import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available
@@ -59,6 +61,7 @@ def __init__(
        record_stream: Optional[bool] = False,
        low_cpu_mem_usage: bool = False,
        onload_self: bool = True,
+        offload_to_disk_path: Optional[str] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
@@ -72,7 +75,26 @@ def __init__(
        self.record_stream = record_stream
        self.onload_self = onload_self
        self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self.offload_to_disk_path = offload_to_disk_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk_path:
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()

        if self.stream is None and self.record_stream:
            raise ValueError("`record_stream` cannot be True when `stream` is None.")
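A note on the key maps built above: every unique tensor in the group gets a stable string key so it can round-trip through a single safetensors file, and `dict.fromkeys` dedupes tied weights by tensor identity while preserving order. A minimal standalone sketch of the same pattern (the modules here are hypothetical stand-ins, not part of this diff):

```python
import torch
import torch.nn as nn

# Two layers sharing one weight: dict.fromkeys dedupes by object identity
# (torch.Tensor hashes by id), so a tied weight is serialized only once.
linear = nn.Linear(4, 4)
tied = nn.Linear(4, 4)
tied.weight = linear.weight  # weight tying

all_tensors = list(linear.parameters()) + list(tied.parameters())
all_tensors = list(dict.fromkeys(all_tensors))  # 3 unique tensors, not 4

tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
key_to_tensor = {key: tensor for tensor, key in tensor_to_key.items()}
assert len(key_to_tensor) == 3
```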
@@ -124,6 +146,30 @@ def onload_(self):
        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

+        if self.offload_to_disk_path:
+            if self.stream is not None:
+                # Wait for previous Host->Device transfer to complete
+                self.stream.synchronize()
+
+            with context:
+                if self.stream is not None:
+                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
+                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
+                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            tensor_obj.data.record_stream(current_stream)
+                else:
+                    # Load directly to the target device (synchronous)
+                    onload_device = (
+                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                    )
+                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                    for key, tensor_obj in self.key_to_tensor.items():
+                        tensor_obj.data = loaded_tensors[key]
+            return
+
        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()
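The streamed branch above splits onloading into three steps: a blocking disk read into pageable CPU memory, a copy into pinned memory, and a non-blocking host-to-device copy issued under the transfer stream. A rough standalone sketch of that overlap pattern, assuming a CUDA device and a hypothetical `group.safetensors` file:

```python
import safetensors.torch
import torch

transfer_stream = torch.cuda.Stream()

# Blocking disk -> pageable CPU read, then pin so the H2D copy can be async.
state = safetensors.torch.load_file("group.safetensors", device="cpu")
pinned = {key: tensor.pin_memory() for key, tensor in state.items()}

device_tensors = {}
with torch.cuda.stream(transfer_stream):
    for key, cpu_tensor in pinned.items():
        # non_blocking=True only truly overlaps when the source is pinned
        device_tensors[key] = cpu_tensor.to("cuda", non_blocking=True)

# Keep `pinned` alive until compute has waited on the transfer stream.
torch.cuda.current_stream().wait_stream(transfer_stream)
```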
@@ -169,6 +215,26 @@ def onload_(self):
    @torch.compiler.disable()
    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensors files exist on disk and, if so, skip this step entirely, reducing IO
+            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+            # we perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+                tensors_to_save = {
+                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
+                }
+                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+            # The group is now considered offloaded to disk for the rest of the session.
+            self._is_offloaded_to_disk = True
+
+            # We do this to free up the RAM that is still holding on to the tensor data.
+            for tensor_obj in self.tensor_to_key.keys():
+                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+            return

        torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
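The disk write above is a one-time cost: once the group's safetensors file exists, later offloads skip serialization and only swap the live tensor storage for an uninitialized placeholder on the offload device, which releases the onload-device memory. A minimal sketch of that save-once-then-release idea (hypothetical path, single tensor, CUDA assumed):

```python
import os

import safetensors.torch
import torch

path = "/tmp/offload_dir/group_0.safetensors"
weight = torch.randn(1024, 1024, device="cuda")

if not os.path.exists(path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # safetensors expects contiguous CPU tensors
    safetensors.torch.save_file({"tensor_0": weight.detach().to("cpu").contiguous()}, path)

# Swap in an uninitialized CPU placeholder: no data is copied, and the
# CUDA storage previously referenced by `weight` becomes free.
weight.data = torch.empty_like(weight.data, device="cpu")
```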
@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):
    _is_stateful = False

-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
-    ) -> None:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
        self.group = group
        self.next_group = next_group
@@ -363,6 +425,7 @@ def apply_group_offloading(
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -401,6 +464,9 @@ def apply_group_offloading(
        offload_type (`str`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            limited-RAM environments where a reasonable speed-memory trade-off is desired.
        num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level".
@@ -458,6 +524,7 @@ def apply_group_offloading(
            num_blocks_per_group=num_blocks_per_group,
            offload_device=offload_device,
            onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
@@ -468,6 +535,7 @@ def apply_group_offloading(
            module=module,
            offload_device=offload_device,
            onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
            The module to which group offloading is applied.
        offload_device (`torch.device`):
            The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            limited-RAM environments where a reasonable speed-memory trade-off is desired.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        non_blocking (`bool`):
@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
            modules=current_modules,
            offload_device=offload_device,
            onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
            offload_leader=current_modules[-1],
            onload_leader=current_modules[0],
            non_blocking=non_blocking,
@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
        modules=unmatched_modules,
        offload_device=offload_device,
        onload_device=onload_device,
+        offload_to_disk_path=offload_to_disk_path,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
            The device to which the group of modules are offloaded. This should typically be the CPU.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory where parameters will be offloaded. Setting this option can be useful in
+            limited-RAM environments where a reasonable speed-memory trade-off is desired.
        non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
            modules=[submodule],
            offload_device=offload_device,
            onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=non_blocking,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
            onload_device=onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
+            offload_to_disk_path=offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=non_blocking,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
+            offload_to_disk_path=offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,