
Add option to set dtype in pipeline.to() method #2317


Merged · 1 commit merged into huggingface:main on Mar 21, 2023

Conversation

@1lint (Contributor) commented on Feb 11, 2023

Adds an option to set the pipeline dtype along with the device via the pipeline's to() method. I added some basic tests to check that this does not affect the results, and that the device and dtype can be set individually as well as together in a single to() call.
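A minimal sketch of the usage this adds (the model ID and prompt are illustrative, not from this PR):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline with its default fp32 weights.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# With this PR, device and dtype can be set in a single call,
# instead of pipe.to("cuda") followed by pipe.half().
pipe = pipe.to("cuda", torch.float16)

image = pipe("an astronaut riding a horse").images[0]
```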

@HuggingFaceDocBuilderDev commented on Feb 11, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) commented

@williamberman could you take a look here?

@pcuenca (Member) left a comment

The code looks good to me!

May I ask what the expected use cases for this improvement are? Loading pipelines in float16 for inference can already be done very efficiently by passing torch_dtype to from_pretrained, and it will soon be possible to select 16-bit weights for download using the upcoming variant infrastructure: #2305.
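For reference, a sketch of the existing fp16 loading path mentioned above (model ID illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline

# torch_dtype loads the weights directly in half precision,
# avoiding an fp32 load followed by a cast.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
```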

@@ -436,16 +442,20 @@ def test_to_device(self):
         pipe = self.pipeline_class(**components)
         pipe.set_progress_bar_config(disable=None)

-        pipe.to("cpu")
+        pipe.to("cpu", dtype=torch.float32)
A Member commented:

nit: maybe also change the name of this test to something like test_to_device_dtype or test_pipeline_to?

1lint (Author) replied:

Sounds good, I'll update the test name!

I was going to use the to() dtype argument in the checkpoint loading code, so that an original Stable Diffusion checkpoint saved with 32-bit weights could be conveniently loaded into a 16-bit pipeline after the convert_from_ckpt code runs (I have a lot of these checkpoints from stable-diffusion-webui, and don't have enough VRAM to run them with 32-bit weights).
There are probably not too many use cases, but I thought it would be nice to save a loop through the pipeline components when setting the device and dtype together. The underlying components already support this call signature, so it felt natural to extend it to the pipeline.
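A sketch of that workflow, assuming the download_from_original_stable_diffusion_ckpt helper from diffusers' convert_from_ckpt module and illustrative file paths:

```python
import torch
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Convert an original fp32 .ckpt (e.g. saved by stable-diffusion-webui)
# into a diffusers pipeline...
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="path/to/model.ckpt",  # hypothetical path
    original_config_file="path/to/v1-inference.yaml",  # hypothetical path
)

# ...then move and downcast in a single call to fit in limited VRAM.
pipe = pipe.to("cuda", torch.float16)
```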

@pcuenca (Member) left a comment

Looks good. I agree that for some uses, like the one you described, it would be natural to expect dtype to be supported too.

@patrickvonplaten (Contributor) left a comment

Good for me!

@patrickvonplaten (Contributor) commented

@williamberman can you also take a look here?

Comment on lines 360 to 365
pipe_to_fp16 = pipe.to(dtype=torch.float16)
output_to_fp16 = pipe_to_fp16(**self.get_dummy_inputs(torch_device))[0]

max_diff = np.abs(output_fp16 - output_to_fp16).max()
self.assertLess(max_diff, 1e-4, "The outputs of the fp16 and to_fp16 pipelines are too different.")

A Contributor commented:

A third pipeline here is extraneous. Can we instead do pipe_fp16.to(torch_device, torch.float16) and remove the manual call to half()?
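A sketch of the suggested simplification (pipe_fp16 and the surrounding scaffolding follow the test snippet above):

```python
# Build the fp16 pipeline by moving and casting in a single to() call,
# instead of constructing a third pipeline or calling .half() manually.
pipe_fp16 = self.pipeline_class(**components)
pipe_fp16.to(torch_device, torch.float16)
output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]
```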

1lint (Author) replied:

Yes, made the change!

Comment on lines 431 to 442

-    def test_to_device(self):
+    def test_to_device_dtype(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.set_progress_bar_config(disable=None)

-        pipe.to("cpu")
+        pipe.to("cpu", dtype=torch.float32)
A Contributor commented:

Calling to() with fp32 when pipelines are generally loaded in fp32 by default is a bit extraneous. I'd rather we not touch this test and instead add a separate test checking that the module dtypes are fp32 when loaded, and fp16 after calling to(fp16).

A Contributor commented:

+1 here. Could we maybe add a new test for this, @1lint? That would be very nice!

1lint (Author) replied:

Sounds good. I added a test_to_dtype test to check that pipeline components load as fp32 and change to fp16 after calling to(fp16).
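A minimal sketch of what such a test might look like (the exact assertions in the PR may differ):

```python
def test_to_dtype(self):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.set_progress_bar_config(disable=None)

    # Components with weights should start out in fp32.
    model_dtypes = [
        component.dtype
        for component in pipe.components.values()
        if hasattr(component, "dtype")
    ]
    self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

    # After casting, the same components should report fp16.
    pipe.to(dtype=torch.float16)
    model_dtypes = [
        component.dtype
        for component in pipe.components.values()
        if hasattr(component, "dtype")
    ]
    self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
```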

@patrickvonplaten (Contributor) commented

@williamberman feel free to merge once you're happy with the PR

@ghost commented on Mar 21, 2023

Will you merge this pull request?

@patrickvonplaten (Contributor) commented

Thanks for the reminder - let's merge!

@patrickvonplaten merged commit b33bd91 into huggingface:main on Mar 21, 2023
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024