Skip to content

Commit 6c64741

Browse files
anton-l, pcuenca, patil-suraj
authored
Raise an error when moving an fp16 pipeline to CPU (#749)
* Raise an error when moving an fp16 pipeline to CPU
* Raise an error when moving an fp16 pipeline to CPU
* style
* Update src/diffusers/pipeline_utils.py (Co-authored-by: Pedro Cuenca <[email protected]>)
* Update src/diffusers/pipeline_utils.py (Co-authored-by: Suraj Patil <[email protected]>)
* Improve the message
* cuda
* Update tests/test_pipelines.py (Co-authored-by: Pedro Cuenca <[email protected]>)

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent 3383f77 commit 6c64741

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
166166
for name in module_names.keys():
167167
module = getattr(self, name)
168168
if isinstance(module, torch.nn.Module):
169+
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
170+
raise ValueError(
171+
"Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
172+
"due to the lack of support for `float16` operations on those devices in PyTorch. "
173+
"Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
174+
)
169175
module.to(torch_device)
170176
return self
171177

tests/test_pipelines.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,17 @@ def to(self, device):
188188

189189
return extract
190190

191+
def test_pipeline_fp16_cpu_error(self):
    """Moving an fp16 pipeline to cpu/mps must raise; moving it to cuda must succeed."""
    unet = self.dummy_uncond_unet
    noise_scheduler = DDPMScheduler(num_train_timesteps=10)
    pipeline = DDIMPipeline(unet.half(), noise_scheduler)

    device_lacks_fp16 = str(torch_device) in ["cpu", "mps"]
    if device_lacks_fp16:
        # half-precision weights are unsupported on these devices, so .to() must fail
        self.assertRaises(ValueError, pipeline.to, torch_device)
    else:
        # moving the pipeline to GPU should work
        pipeline.to(torch_device)
201+
191202
def test_ddim(self):
192203
unet = self.dummy_uncond_unet
193204
scheduler = DDIMScheduler()

0 commit comments

Comments (0)