diff --git a/scripts/convert_gligen_to_diffusers.py b/scripts/convert_gligen_to_diffusers.py index 30d789b60634..83c1f928e407 100644 --- a/scripts/convert_gligen_to_diffusers.py +++ b/scripts/convert_gligen_to_diffusers.py @@ -576,6 +576,6 @@ def convert_gligen_to_diffusers( ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) pipe.save_pretrained(args.dump_path) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 2ca70963d132..980446179cfd 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -179,7 +179,7 @@ ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) if args.controlnet: # only save the controlnet model diff --git a/scripts/convert_zero123_to_diffusers.py b/scripts/convert_zero123_to_diffusers.py index f016312b8bb6..3bb6f6c041c9 100644 --- a/scripts/convert_zero123_to_diffusers.py +++ b/scripts/convert_zero123_to_diffusers.py @@ -801,6 +801,6 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 06187645f000..769fcd2e832a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -775,32 +775,10 @@ def to(self, *args, **kwargs): Returns: [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. """ - - torch_dtype = kwargs.pop("torch_dtype", None) - if torch_dtype is not None: - deprecate("torch_dtype", "0.27.0", "") - torch_device = kwargs.pop("torch_device", None) - if torch_device is not None: - deprecate("torch_device", "0.27.0", "") - - dtype_kwarg = kwargs.pop("dtype", None) - device_kwarg = kwargs.pop("device", None) + dtype = kwargs.pop("dtype", None) + device = kwargs.pop("device", None) silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) - if torch_dtype is not None and dtype_kwarg is not None: - raise ValueError( - "You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`." - ) - - dtype = torch_dtype or dtype_kwarg - - if torch_device is not None and device_kwarg is not None: - raise ValueError( - "You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`." - ) - - device = torch_device or device_kwarg - dtype_arg = None device_arg = None if len(args) == 1: @@ -873,12 +851,12 @@ def module_is_offloaded(module): if is_loaded_in_8bit and dtype is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." ) if is_loaded_in_8bit and device is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." ) else: module.to(device, dtype) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 80a8fd19f5a0..525ca24bbd9a 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -218,7 +218,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 3226bdb3ca6e..767fc30b4eb5 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -224,7 +224,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index 60ef86518e35..e2655515bc40 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -483,7 +483,7 @@ def test_to_dtype(self): self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values())) # Once we send to fp16, all params are in half-precision, including the logit scale - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index 4bf03569bbf3..fe78ab6acbb1 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -400,7 +400,7 @@ def test_to_dtype(self): self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values())) # Once we send to fp16, all params are in half-precision, including the logit scale - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index eb76457abc9d..edd129560c63 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -231,7 +231,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index 871266fb9c24..60c411283803 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -396,7 +396,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 32ae81ddc2d8..bd9f42f185e6 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1623,7 +1623,7 @@ def test_pipe_to(self): sd1 = sd.to(torch.float16) sd2 = sd.to(None, torch.float16) sd3 = sd.to(dtype=torch.float16) - sd4 = sd.to(torch_dtype=torch.float16) + sd4 = sd.to(dtype=torch.float16) sd5 = sd.to(None, dtype=torch.float16) sd6 = sd.to(None, torch_dtype=torch.float16) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index e3c8a4ef503f..7f51847caf07 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -716,7 +716,7 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))