Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,27 @@ def _pinned_memory_tensors(self):
finally:
pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor):
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
tensor.data.record_stream(default_stream)

def _process_tensors_from_modules(self, pinned_memory=None):
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source)
self._transfer_tensor_to_device(param, source, default_stream)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source)
self._transfer_tensor_to_device(buffer, source, default_stream)

for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source)
self._transfer_tensor_to_device(param, source, default_stream)

for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source)
self._transfer_tensor_to_device(buffer, source, default_stream)

def _onload_from_disk(self):
if self.stream is not None:
Expand Down Expand Up @@ -208,10 +208,12 @@ def _onload_from_memory(self):
self.stream.synchronize()

context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None

with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
else:
self._process_tensors_from_modules(None)

Expand Down
3 changes: 0 additions & 3 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,9 +1814,6 @@ def _run_forward(model, inputs_dict):
torch.manual_seed(0)
return model(**inputs_dict)[0]

if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict)
Expand Down
Loading