Commit b33bd91

Add option to set dtype in pipeline.to() method (#2317)
add test_to_dtype to check pipe.to(fp16)
1 parent 1fcf279 commit b33bd91
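
For context, the call pattern this commit enables looks like the following. A minimal sketch, assuming a typical diffusers pipeline (the checkpoint name is illustrative):

    import torch
    from diffusers import DiffusionPipeline

    # Illustrative checkpoint; any DiffusionPipeline behaves the same way.
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Before this commit, to() accepted only a device. It now also takes a dtype:
    pipe = pipe.to("cuda", torch.float16)      # move and cast in one call
    pipe = pipe.to(torch_dtype=torch.float16)  # cast only; device placement unchanged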

2 files changed: +21 -7 lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 8 additions & 3 deletions
@@ -512,8 +512,13 @@ def is_saveable_module(name, value):

             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
-        if torch_device is None:
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        silence_dtype_warnings: bool = False,
+    ):
+        if torch_device is None and torch_dtype is None:
             return self

         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -550,6 +555,7 @@ def module_is_offloaded(module):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                module.to(torch_device, torch_dtype)
                 if (
                     module.dtype == torch.float16
                     and str(torch_device) in ["cpu"]
@@ -563,7 +569,6 @@ def module_is_offloaded(module):
                     " support for`float16` operations on this device in PyTorch. Please, remove the"
                     " `torch_dtype=torch.float16` argument, or use another device for inference."
                 )
-                module.to(torch_device)
         return self

     @property
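
One detail worth noting: torch.nn.Module.to() accepts None for either the device or the dtype, which is why the single unconditional module.to(torch_device, torch_dtype) call can replace the old device-only call at the end of the loop. A minimal sketch of that behavior, standalone and outside the pipeline:

    import torch

    module = torch.nn.Linear(4, 4)

    # dtype given, device None: parameters are cast, placement is unchanged.
    # This is the path exercised by pipe.to(torch_dtype=torch.float16).
    module.to(None, torch.float16)
    assert next(module.parameters()).dtype == torch.float16

    # device given, dtype None: the module is moved without recasting.
    module.to("cpu", None)
    assert next(module.parameters()).dtype == torch.float16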

tests/test_pipelines_common.py

Lines changed: 13 additions & 4 deletions
@@ -344,11 +344,8 @@ def test_float16_inference(self):
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.half()
         pipe_fp16 = self.pipeline_class(**components)
-        pipe_fp16.to(torch_device)
+        pipe_fp16.to(torch_device, torch.float16)
         pipe_fp16.set_progress_bar_config(disable=None)

         output = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -447,6 +444,18 @@ def test_to_device(self):
         output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
         self.assertTrue(np.isnan(output_cuda).sum() == 0)

+    def test_to_dtype(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+        pipe.to(torch_dtype=torch.float16)
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass()