@@ -512,8 +512,13 @@ def is_saveable_module(name, value):
 
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
-        if torch_device is None:
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        silence_dtype_warnings: bool = False,
+    ):
+        if torch_device is None and torch_dtype is None:
             return self
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -550,6 +555,7 @@ def module_is_offloaded(module):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                module.to(torch_device, torch_dtype)
                 if (
                     module.dtype == torch.float16
                     and str(torch_device) in ["cpu"]
@@ -563,7 +569,6 @@ def module_is_offloaded(module):
                         " support for`float16` operations on this device in PyTorch. Please, remove the"
                         " `torch_dtype=torch.float16` argument, or use another device for inference."
                     )
-                module.to(torch_device)
         return self
 
     @property
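A minimal usage sketch of the patched `to()` (the checkpoint id below is illustrative, not part of this diff): device and dtype can now be set in a single call, with each `torch.nn.Module` component moved and cast via `module.to(torch_device, torch_dtype)`.

import torch

from diffusers import DiffusionPipeline

# Placeholder checkpoint id; any diffusers pipeline checkpoint works here.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Before this patch, to() accepted only a device; now device and dtype can
# be passed together, casting every nn.Module component of the pipeline.
pipe = pipe.to("cuda", torch_dtype=torch.float16)

# Passing only a dtype is also valid, since the early return now triggers
# only when both torch_device and torch_dtype are None.
pipe = pipe.to(torch_dtype=torch.float16)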