@@ -153,27 +153,27 @@ def _pinned_memory_tensors(self):
153153 finally :
154154 pinned_dict = None
155155
156- def _transfer_tensor_to_device (self , tensor , source_tensor ):
156+ def _transfer_tensor_to_device (self , tensor , source_tensor , default_stream ):
157157 tensor .data = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
158158 if self .record_stream :
159- tensor .data .record_stream (self . _torch_accelerator_module . current_stream () )
159+ tensor .data .record_stream (default_stream )
160160
161- def _process_tensors_from_modules (self , pinned_memory = None ):
161+ def _process_tensors_from_modules (self , pinned_memory = None , default_stream = None ):
162162 for group_module in self .modules :
163163 for param in group_module .parameters ():
164164 source = pinned_memory [param ] if pinned_memory else param .data
165- self ._transfer_tensor_to_device (param , source )
165+ self ._transfer_tensor_to_device (param , source , default_stream )
166166 for buffer in group_module .buffers ():
167167 source = pinned_memory [buffer ] if pinned_memory else buffer .data
168- self ._transfer_tensor_to_device (buffer , source )
168+ self ._transfer_tensor_to_device (buffer , source , default_stream )
169169
170170 for param in self .parameters :
171171 source = pinned_memory [param ] if pinned_memory else param .data
172- self ._transfer_tensor_to_device (param , source )
172+ self ._transfer_tensor_to_device (param , source , default_stream )
173173
174174 for buffer in self .buffers :
175175 source = pinned_memory [buffer ] if pinned_memory else buffer .data
176- self ._transfer_tensor_to_device (buffer , source )
176+ self ._transfer_tensor_to_device (buffer , source , default_stream )
177177
178178 def _onload_from_disk (self ):
179179 if self .stream is not None :
@@ -208,12 +208,13 @@ def _onload_from_memory(self):
208208 self .stream .synchronize ()
209209
210210 context = nullcontext () if self .stream is None else self ._torch_accelerator_module .stream (self .stream )
211+ default_stream = self ._torch_accelerator_module .current_stream ()
211212 with context :
212213 if self .stream is not None :
213214 with self ._pinned_memory_tensors () as pinned_memory :
214- self ._process_tensors_from_modules (pinned_memory )
215+ self ._process_tensors_from_modules (pinned_memory , default_stream = default_stream )
215216 else :
216- self ._process_tensors_from_modules (None )
217+ self ._process_tensors_from_modules (None , default_stream = default_stream )
217218
218219 def _offload_to_disk (self ):
219220 # TODO: we can potentially optimize this code path by checking if the _all_ the desired
0 commit comments