Add option to set dtype in pipeline.to() method #2317
Conversation
The documentation is not available anymore as the PR was closed or merged.
@williamberman could you take a look here?
The code looks good to me!
May I ask what the expected use cases for this improvement are? Loading pipelines in float16 for inference can already be done very efficiently by passing torch_dtype to from_pretrained, and it will soon be possible to select 16-bit weights for download using the upcoming variant infrastructure: #2305.
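For context, the existing path referred to here might look like the following sketch (the model id is only illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline

# weights are loaded directly in half precision at load time
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
```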
tests/test_pipelines_common.py (Outdated)
@@ -436,16 +442,20 @@ def test_to_device(self):
     pipe = self.pipeline_class(**components)
     pipe.set_progress_bar_config(disable=None)

-    pipe.to("cpu")
+    pipe.to("cpu", dtype=torch.float32)
nit: maybe change the name of this test too, to something like test_to_device_dtype or test_pipeline_to?
Sounds good, I'll update the test name!
I was going to include the to() dtype option in the checkpoint loading code, so that if you had an original Stable Diffusion checkpoint saved with 32-bit weights, you could conveniently load it into a 16-bit pipeline after the convert_from_ckpt code runs (I have a lot of these checkpoints from stable-diffusion-webui, and don't have enough VRAM to use 32-bit weights).
There probably aren't too many use cases, but I thought it would be nice to save a loop through the pipeline components if you wanted to set the device and dtype together. The underlying components already support this call signature, so it felt natural to extend it to the pipeline.
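As a rough sketch of that use case (the conversion helper name and import path are assumptions and may differ between diffusers versions; the checkpoint filename is only an example):

```python
import torch

# assumed helper from diffusers' convert_from_ckpt module; the exact
# name and location may differ depending on the diffusers version
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# an original stable-diffusion-webui checkpoint saved with 32-bit weights
pipe = download_from_original_stable_diffusion_ckpt("v1-5-pruned.ckpt")

# cast the whole pipeline to fp16 and move it to the GPU in one call
pipe = pipe.to("cuda", dtype=torch.float16)
```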
Looks good. I agree that for use cases like the one you described, it would be natural to expect dtype to be supported too.
Good for me!
@williamberman can you also take a look here?
tests/test_pipelines_common.py (Outdated)
     pipe_to_fp16 = pipe.to(dtype=torch.float16)
     output_to_fp16 = pipe_to_fp16(**self.get_dummy_inputs(torch_device))[0]

     max_diff = np.abs(output_fp16 - output_to_fp16).max()
     self.assertLess(max_diff, 1e-4, "The outputs of the fp16 and to_fp16 pipelines are too different.")
A third pipeline here is extraneous. Can we instead do pipe_fp16.to(torch_device, torch.float16) and remove the manual call to half()?
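For illustration only, the suggested change might look like this sketch (pipe_fp16, output_fp16, and get_dummy_inputs come from the surrounding test):

```python
# reuse the existing fp16 pipeline instead of building a third one:
# set device and dtype in a single to() call, no manual .half() needed
pipe_fp16.to(torch_device, torch.float16)
output_to_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]

max_diff = np.abs(output_fp16 - output_to_fp16).max()
self.assertLess(max_diff, 1e-4, "The outputs of the fp16 and to_fp16 pipelines are too different.")
```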
Yes made the change!
tests/test_pipelines_common.py
Outdated
def test_to_device(self): | ||
def test_to_device_dtype(self): | ||
components = self.get_dummy_components() | ||
pipe = self.pipeline_class(**components) | ||
pipe.set_progress_bar_config(disable=None) | ||
|
||
pipe.to("cpu") | ||
pipe.to("cpu", dtype=torch.float32) |
Calling to() with fp32 when pipelines are generally loaded in fp32 by default is a bit extraneous. I'd rather we not touch this test and instead add a separate test that checks the module dtypes are fp32 when the pipeline is loaded, and fp16 after calling to(fp16).
+1 here - could we maybe add a new test for this @1lint? That would be very nice!
Sounds good, I added a test_to_dtype test to check that pipeline components load as fp32 and change to fp16 after calling to(fp16).
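For reference, a minimal sketch of what such a test could look like, assuming the dtype keyword shown in the diffs above and that each model component exposes a dtype attribute:

```python
def test_to_dtype(self):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.set_progress_bar_config(disable=None)

    # pipelines are loaded in fp32 by default
    model_dtypes = [c.dtype for c in components.values() if hasattr(c, "dtype")]
    self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

    # after casting, every dtype-bearing component should be fp16
    pipe.to(dtype=torch.float16)
    model_dtypes = [c.dtype for c in components.values() if hasattr(c, "dtype")]
    self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
```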
@williamberman feel free to merge once you're happy with the PR.
Will you merge this pull request?
Thanks for the reminder - let's merge!
add test_to_dtype to check pipe.to(fp16)
Adds an option to set the pipeline dtype along with the device using the pipeline to() method. I added some basic tests to check that it does not affect the results, and that the device and dtype can be set individually as well as together in a single to() call.
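As an illustration of the new call signature (a sketch only; the model id is just an example and the dtype keyword follows the diffs above):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# device and dtype can now be set in a single call instead of
# casting each pipeline component separately
pipe = pipe.to("cuda", dtype=torch.float16)
```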