Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/convert_gligen_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,6 @@ def convert_gligen_to_diffusers(
)

if args.half:
pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)

pipe.save_pretrained(args.dump_path)
2 changes: 1 addition & 1 deletion scripts/convert_original_stable_diffusion_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@
)

if args.half:
pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)

if args.controlnet:
# only save the controlnet model
Expand Down
2 changes: 1 addition & 1 deletion scripts/convert_zero123_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,6 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
)

if args.half:
pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)

pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
30 changes: 4 additions & 26 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,32 +775,10 @@ def to(self, *args, **kwargs):
Returns:
[`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
"""

torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
deprecate("torch_dtype", "0.27.0", "")
torch_device = kwargs.pop("torch_device", None)
if torch_device is not None:
deprecate("torch_device", "0.27.0", "")

dtype_kwarg = kwargs.pop("dtype", None)
device_kwarg = kwargs.pop("device", None)
dtype = kwargs.pop("dtype", None)
device = kwargs.pop("device", None)
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

if torch_dtype is not None and dtype_kwarg is not None:
raise ValueError(
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
)

dtype = torch_dtype or dtype_kwarg

if torch_device is not None and device_kwarg is not None:
raise ValueError(
"You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
)

device = torch_device or device_kwarg

dtype_arg = None
device_arg = None
if len(args) == 1:
Expand Down Expand Up @@ -873,12 +851,12 @@ def module_is_offloaded(module):

if is_loaded_in_8bit and dtype is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
)

if is_loaded_in_8bit and device is not None:
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
)
else:
module.to(device, dtype)
Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/animatediff/test_animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/audioldm2/test_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def test_to_dtype(self):
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

# Once we send to fp16, all params are in half-precision, including the logit scale
pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/musicldm/test_musicldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def test_to_dtype(self):
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

# Once we send to fp16, all params are in half-precision, including the logit scale
pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/pia/test_pia.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ def test_pipe_to(self):
sd1 = sd.to(torch.float16)
sd2 = sd.to(None, torch.float16)
sd3 = sd.to(dtype=torch.float16)
sd4 = sd.to(torch_dtype=torch.float16)
sd4 = sd.to(dtype=torch.float16)
sd5 = sd.to(None, dtype=torch.float16)
sd6 = sd.to(None, torch_dtype=torch.float16)

Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

pipe.to(torch_dtype=torch.float16)
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

Expand Down