diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 83a336319c32..14d22db3ef36 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1423,6 +1423,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
 
         device_type = torch_device.type
         device = torch.device(f"{device_type}:{self._offload_gpu_id}")
+        self._offload_device = device
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
@@ -1472,7 +1473,7 @@ def maybe_free_model_hooks(self):
             hook.remove()
 
         # make sure the model is in the same state as before calling it
-        self.enable_model_cpu_offload()
+        self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
 
     def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
@@ -1508,6 +1509,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
 
         device_type = torch_device.type
         device = torch.device(f"{device_type}:{self._offload_gpu_id}")
+        self._offload_device = device
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
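
The diff records the device resolved in `enable_model_cpu_offload` / `enable_sequential_cpu_offload` as `self._offload_device`, so that `maybe_free_model_hooks` can re-enable offload on the same device instead of unconditionally falling back to the default `"cuda"`. A minimal sketch of the behavior this changes, assuming a machine with at least two CUDA devices (the checkpoint id is only illustrative):

```python
# Sketch of the scenario this patch addresses; assumes two CUDA devices
# are available and uses an illustrative checkpoint id.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Request model offload targeting the second GPU; with this patch the
# resolved device is remembered in `pipe._offload_device`.
pipe.enable_model_cpu_offload(device="cuda:1")

# Removing the hooks and restoring offload state: before this patch,
# `maybe_free_model_hooks` called `enable_model_cpu_offload()` with no
# arguments, so offload silently moved back to the default "cuda"
# (i.e. cuda:0). With the patch it re-enables offload on the recorded
# device, preserving cuda:1.
pipe.maybe_free_model_hooks()
assert pipe._offload_device == torch.device("cuda:1")
```

Pipelines that never called one of the offload methods have no `_offload_device` attribute, which is why the call site uses `getattr(self, "_offload_device", "cuda")` to keep the previous default behavior as a fallback.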