Skip to content

Commit c30d472

Browse files
authored
Add option to set dtype in pipeline.to() method (huggingface#2317)
add test_to_dtype to check pipe.to(fp16)
1 parent 3e82810 commit c30d472

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

pipelines/pipeline_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -512,8 +512,13 @@ def is_saveable_module(name, value):
512512

513513
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
514514

515-
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
516-
if torch_device is None:
515+
def to(
516+
self,
517+
torch_device: Optional[Union[str, torch.device]] = None,
518+
torch_dtype: Optional[torch.dtype] = None,
519+
silence_dtype_warnings: bool = False,
520+
):
521+
if torch_device is None and torch_dtype is None:
517522
return self
518523

519524
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -550,6 +555,7 @@ def module_is_offloaded(module):
550555
for name in module_names.keys():
551556
module = getattr(self, name)
552557
if isinstance(module, torch.nn.Module):
558+
module.to(torch_device, torch_dtype)
553559
if (
554560
module.dtype == torch.float16
555561
and str(torch_device) in ["cpu"]
@@ -563,7 +569,6 @@ def module_is_offloaded(module):
563569
" support for`float16` operations on this device in PyTorch. Please, remove the"
564570
" `torch_dtype=torch.float16` argument, or use another device for inference."
565571
)
566-
module.to(torch_device)
567572
return self
568573

569574
@property

0 commit comments

Comments (0)