File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -240,15 +240,15 @@ def _offload_to_memory(self):
240240
241241 for group_module in self .modules :
242242 for param in group_module .parameters ():
243- if self .record_stream and param .device .type == 'cuda ' :
243+ if self .record_stream and param .device .type != 'cpu ' :
244244 param .data .record_stream (current_stream )
245245 param .data = self .cpu_param_dict [param ]
246246 for param in self .parameters :
247- if self .record_stream and param .device .type == 'cuda ' :
247+ if self .record_stream and param .device .type != 'cpu ' :
248248 param .data .record_stream (current_stream )
249249 param .data = self .cpu_param_dict [param ]
250250 for buffer in self .buffers :
251- if self .record_stream and buffer .device .type == 'cuda ' :
251+ if self .record_stream and buffer .device .type != 'cpu ' :
252252 buffer .data .record_stream (current_stream )
253253 buffer .data = self .cpu_param_dict [buffer ]
254254 else :
You can’t perform that action at this time.
0 commit comments