From 70abbb7be55be2c7dfc76cc0396530c924f65048 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 14 Sep 2022 15:11:51 +0000
Subject: [PATCH 1/7] Removing `autocast` for `25-35% speedup`.

---
 README.md                                 | 29 +++++----------
 src/diffusers/models/resnet.py            |  4 ---
 src/diffusers/models/unet_2d_condition.py |  9 +++--
 src/diffusers/pipelines/README.md         | 36 ++-----------------
 .../pipelines/stable_diffusion/README.md  | 12 ++-----
 .../pipeline_stable_diffusion.py          | 10 +++++-
 6 files changed, 29 insertions(+), 71 deletions(-)

diff --git a/README.md b/README.md
index 7be7f9f84257..25448472df9d 100644
--- a/README.md
+++ b/README.md
@@ -76,15 +76,13 @@ You need to accept the model license before downloading or using the Stable Diff
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
-from torch import autocast
 from diffusers import StableDiffusionPipeline
 
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 **Note**: If you don't want to use the token, you can also simply download the model weights
@@ -104,8 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4")
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 If you are limited by GPU memory, you might want to consider using the model in `fp16` as
@@ -123,8 +120,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_attention_slicing()
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 ```
 
 Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -149,8 +145,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
-    image = pipe(prompt).images[0]
+image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -160,7 +155,6 @@ image.save("astronaut_rides_horse.png")
 
 The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
```python -from torch import autocast import requests import torch from PIL import Image @@ -190,8 +184,7 @@ init_image = init_image.resize((768, 512)) prompt = "A fantasy landscape, trending on artstation" -with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images +images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images images[0].save("fantasy_landscape.png") ``` @@ -204,7 +197,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by ```python from io import BytesIO -from torch import autocast import torch import requests import PIL @@ -234,8 +226,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = pipe.to(device) prompt = "a cat sitting on a bench" -with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images +images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images images[0].save("cat_on_bench.png") ``` @@ -258,7 +249,6 @@ If you want to run the code yourself 💻, you can try out: - [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256) ```python # !pip install diffusers transformers -from torch import autocast from diffusers import DiffusionPipeline device = "cuda" @@ -270,8 +260,7 @@ ldm = ldm.to(device) # run pipeline in inference (sample random noise and denoise) prompt = "A painting of a squirrel eating a burger" -with autocast(device): - image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0] +image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0] # save image image.save("squirrel.png") @@ -279,7 +268,6 @@ image.save("squirrel.png") - [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256) ```python # !pip install diffusers -from torch import autocast from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline model_id = "google/ddpm-celebahq-256" @@ -290,8 +278,7 @@ ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline wi ddpm.to(device) # run pipeline in inference (sample random noise and denoise) -with autocast("cuda"): - image = ddpm().images[0] +image = ddpm().images[0] # save image image.save("ddpm_generated_image.png") diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 507ca8632de3..d97f62cb40f9 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -331,8 +331,6 @@ def __init__( def forward(self, x, temb): hidden_states = x - # make sure hidden states is in float32 - # when running in half-precision hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) @@ -349,8 +347,6 @@ def forward(self, x, temb): temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] hidden_states = hidden_states + temb - # make sure hidden states is in float32 - # when running in half-precision hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 92caaca92e24..e96bbebb4749 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -222,6 +222,11 @@ def forward( timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) + + # timesteps does not 
contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb) # 2. pre-process @@ -258,9 +263,7 @@ def forward( sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples) # 6. post-process - # make sure hidden states is in float32 - # when running in half-precision - sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_norm_out(sample).type(sample.dtype) sample = self.conv_act(sample) sample = self.conv_out(sample) diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 328c37dc765f..f529697dc136 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -86,15 +86,13 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing ```python # make sure you're logged in with `huggingface-cli login` -from torch import autocast from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -with autocast("cuda"): - image = pipe(prompt).images[0] +image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -104,33 +102,7 @@ image.save("astronaut_rides_horse.png") The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. ```python -from torch import autocast -import requests -from PIL import Image -from io import BytesIO - -from diffusers import StableDiffusionImg2ImgPipeline - -# load the pipeline -device = "cuda" -pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - revision="fp16", - torch_dtype=torch.float16, - use_auth_token=True -).to(device) - -# let's download an initial image -url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - -response = requests.get(url) -init_image = Image.open(BytesIO(response.content)).convert("RGB") -init_image = init_image.resize((768, 512)) - -prompt = "A fantasy landscape, trending on artstation" - -with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images +images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images images[0].save("fantasy_landscape.png") ``` @@ -148,7 +120,6 @@ The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by ```python from io import BytesIO -from torch import autocast import requests import PIL @@ -173,8 +144,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained( ).to(device) prompt = "a cat sitting on a bench" -with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images +images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images images[0].save("cat_on_bench.png") ``` diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md index 63ad90fad2f5..a8428896585f 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -59,15 +59,13 @@ pipe = 
StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-4") ```python # make sure you're logged in with `huggingface-cli login` -from torch import autocast from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -with autocast("cuda"): - image = pipe(prompt).sample[0] +image = pipe(prompt).sample[0] image.save("astronaut_rides_horse.png") ``` @@ -76,7 +74,6 @@ image.save("astronaut_rides_horse.png") ```python # make sure you're logged in with `huggingface-cli login` -from torch import autocast from diffusers import StableDiffusionPipeline, DDIMScheduler scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) @@ -88,8 +85,7 @@ pipe = StableDiffusionPipeline.from_pretrained( ).to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -with autocast("cuda"): - image = pipe(prompt).sample[0] +image = pipe(prompt).sample[0] image.save("astronaut_rides_horse.png") ``` @@ -98,7 +94,6 @@ image.save("astronaut_rides_horse.png") ```python # make sure you're logged in with `huggingface-cli login` -from torch import autocast from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler lms = LMSDiscreteScheduler( @@ -114,8 +109,7 @@ pipe = StableDiffusionPipeline.from_pretrained( ).to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -with autocast("cuda"): - image = pipe(prompt).sample[0] +image = pipe(prompt).sample[0] image.save("astronaut_rides_horse.png") ``` diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f02fa114a8e1..b3897a13a46d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -205,16 +205,18 @@ def __call__( # However this currently doesn't work in `mps`. 
         latents_device = "cpu" if self.device.type == "mps" else self.device
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
             latents = torch.randn(
                 latents_shape,
                 generator=generator,
                 device=latents_device,
+                dtype=latents_dtype
             )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            latents = latents.to(self.device)
 
         # set timesteps
         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
@@ -268,6 +270,12 @@ def __call__(
 
         # run safety checker
         safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+
+        # XXX: it might be better to check against the actual dtype of the safety checker since
+        # it might want to run in a different precision, but the safety checker does not expose
+        # a `dtype`/`precision` itself, so this is a good enough proxy for running pipelines in
+        # both f16 / f32
+        safety_cheker_input.pixel_values = safety_cheker_input.pixel_values.to(dtype=latents_dtype)
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":

From 6334170c07da0027d395c1b41301ba8defd4575f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 14 Sep 2022 15:24:49 +0000
Subject: [PATCH 2/7] Quality.

---
 .../stable_diffusion/pipeline_stable_diffusion.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index b3897a13a46d..39557bce13e1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -207,12 +207,7 @@ def __call__(
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=latents_dtype
-            )
+            latents = torch.randn(latents_shape, generator=generator, device=latents_device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

From e4f93880a1a089d62a690e5c02730a0e950996b9 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 15 Sep 2022 08:35:55 +0000
Subject: [PATCH 3/7] Adding a slow test.
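
The test runs the same prompt twice from a fixed seed: once with the fp16
weights directly, and once under `torch.autocast`. The two outputs are not
bitwise identical, since individual ops do not always execute at the same
precision in the two setups, so the test only asserts that the mean absolute
difference stays below 2e-2.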
---
 tests/test_pipelines.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 3d691368104e..a2ec5ef59add 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1137,6 +1137,37 @@ def test_stable_diffusion_memory_chunking(self):
         assert mem_bytes > 3.75 * 10**9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
 
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_f16_no_autocast(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        ).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output_chunked = pipe(
+            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+        )
+        image_chunked = output_chunked.images
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+            image = output.images
+
+        # Make sure results are close enough
+        diff = np.abs(image_chunked.flatten() - image.flatten())
+        # They ARE different since ops are not always run at the same precision
+        # however, they should be extremely close.
+        assert diff.mean() < 2e-2
+
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_img2img_pipeline(self):

From 6b853af1254850c7de4b20fb79e0bf301a2b5fb2 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 15 Sep 2022 08:42:00 +0000
Subject: [PATCH 4/7] Fixing mps noise generation.

---
 .../stable_diffusion/pipeline_stable_diffusion.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 39557bce13e1..bbcb3a25cf14 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -203,15 +203,17 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
- latents_device = "cpu" if self.device.type == "mps" else self.device latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) latents_dtype = text_embeddings.dtype if latents is None: - latents = torch.randn(latents_shape, generator=generator, device=latents_device, dtype=latents_dtype) + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn(latents_shape, generator=generator, device='cpu', dtype=latents_dtype).to(self.device) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) From ce66d6096d1abeee50ee7e4b47c07937003698d1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Sep 2022 08:47:25 +0000 Subject: [PATCH 5/7] Raising error on wrong device, instead of just casting on behalf of user. --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index bbcb3a25cf14..34aae7d7974f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -214,6 +214,8 @@ def __call__( else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if latents.device != self.device: + raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) From 43c2d1796f1e9e6f9fc9e8cf06dc2396add1f0cc Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Sep 2022 08:49:19 +0000 Subject: [PATCH 6/7] Quality. 
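
Formatting pass only (quote style and line wrapping on the `mps` branch); no
functional change.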
---
 .../pipeline_stable_diffusion.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 34aae7d7974f..d2127fb6b628 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -206,9 +206,11 @@ def __call__(
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            if self.device.type == "mps":
+            if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device='cpu', dtype=latents_dtype).to(self.device)
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
             else:
                 latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:

From cf393ebdae7ef85541175538acf6a0938243944d Mon Sep 17 00:00:00 2001
From: Nouamane Tazi
Date: Wed, 5 Oct 2022 12:56:06 +0000
Subject: [PATCH 7/7] fix merge

---
 src/diffusers/pipelines/README.md | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md
index df2dabe46f7d..71841e023372 100644
--- a/src/diffusers/pipelines/README.md
+++ b/src/diffusers/pipelines/README.md
@@ -102,6 +102,31 @@ image.save("astronaut_rides_horse.png")
 
 The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
 
 ```python
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+from diffusers import StableDiffusionImg2ImgPipeline
+
+# load the pipeline
+device = "cuda"
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    revision="fp16",
+    torch_dtype=torch.float16,
+    use_auth_token=True
+).to(device)
+
+# let's download an initial image
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((768, 512))
+
+prompt = "A fantasy landscape, trending on artstation"
+
 images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
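
Net effect of the series for users: fp16 Stable Diffusion pipelines now run
without an `autocast` wrapper, since the pipeline itself casts the latents,
the timestep embeddings, and the safety-checker inputs to the model dtype.
A minimal sketch of the resulting fp16 usage, assembled from the README
changes in this series (model id and prompt are the ones used throughout;
the sketch is illustrative and not part of the patches):

```python
import torch
from diffusers import StableDiffusionPipeline

# fp16 weights; no `with torch.autocast("cuda"):` block is needed anymore,
# since after this series the pipeline handles dtype casting internally.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```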