Skip to content

Commit 9159ed7

Browse files
committed
Fixes #12673.
Wrong default_stream is used, leading to wrong execution order when record_stream is enabled.
1 parent c8656ed commit 9159ed7

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,27 +153,27 @@ def _pinned_memory_tensors(self):
153153
finally:
154154
pinned_dict = None
155155

156-
def _transfer_tensor_to_device(self, tensor, source_tensor):
156+
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
157157
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
158158
if self.record_stream:
159-
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
159+
tensor.data.record_stream(default_stream)
160160

161-
def _process_tensors_from_modules(self, pinned_memory=None):
161+
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
162162
for group_module in self.modules:
163163
for param in group_module.parameters():
164164
source = pinned_memory[param] if pinned_memory else param.data
165-
self._transfer_tensor_to_device(param, source)
165+
self._transfer_tensor_to_device(param, source, default_stream)
166166
for buffer in group_module.buffers():
167167
source = pinned_memory[buffer] if pinned_memory else buffer.data
168-
self._transfer_tensor_to_device(buffer, source)
168+
self._transfer_tensor_to_device(buffer, source, default_stream)
169169

170170
for param in self.parameters:
171171
source = pinned_memory[param] if pinned_memory else param.data
172-
self._transfer_tensor_to_device(param, source)
172+
self._transfer_tensor_to_device(param, source, default_stream)
173173

174174
for buffer in self.buffers:
175175
source = pinned_memory[buffer] if pinned_memory else buffer.data
176-
self._transfer_tensor_to_device(buffer, source)
176+
self._transfer_tensor_to_device(buffer, source, default_stream)
177177

178178
def _onload_from_disk(self):
179179
if self.stream is not None:
@@ -208,12 +208,13 @@ def _onload_from_memory(self):
208208
self.stream.synchronize()
209209

210210
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
211+
default_stream = self._torch_accelerator_module.current_stream()
211212
with context:
212213
if self.stream is not None:
213214
with self._pinned_memory_tensors() as pinned_memory:
214-
self._process_tensors_from_modules(pinned_memory)
215+
self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
215216
else:
216-
self._process_tensors_from_modules(None)
217+
self._process_tensors_from_modules(None, default_stream=default_stream)
217218

218219
def _offload_to_disk(self):
219220
# TODO: we can potentially optimize this code path by checking if the _all_ the desired

0 commit comments

Comments
 (0)