Optimize Stable Diffusion #371
Changes from all commits
```diff
@@ -37,10 +37,12 @@ def get_timestep_embedding(
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

     half_dim = embedding_dim // 2
-    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
```
Comment on lines +40 to +41

Member (Author): Do we need this operation to run in fp32 even when the pipeline runs in fp16?

Contributor: Happy to try it out - let's maybe do it in a follow-up PR? :-)
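A minimal sketch (not part of this diff) of the question above: it recomputes the sinusoidal embedding from the hunk with a selectable dtype and prints the worst-case deviation of an all-fp16 run from the fp32 reference. The function name and the default `embedding_dim`/`max_period`/`downscale_freq_shift` values are illustrative assumptions, not the library's API.

```python
import math
import torch

def sinusoidal_embedding(timesteps, embedding_dim=320, max_period=10000,
                         downscale_freq_shift=1.0, dtype=torch.float32):
    # same arithmetic as get_timestep_embedding above, dtype made a parameter
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=dtype, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = timesteps[:, None].to(dtype) * torch.exp(exponent)[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

timesteps = torch.arange(1000)  # Stable Diffusion's usual 0..999 range
ref = sinusoidal_embedding(timesteps, dtype=torch.float32)
# (older PyTorch builds may need a GPU for fp16 exp/sin/cos kernels)
half = sinusoidal_embedding(timesteps, dtype=torch.float16).float()
print((ref - half).abs().max())  # worst-case error of running entirely in fp16
```

Whether that deviation is acceptable is exactly what the deferred follow-up PR would need to decide.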
```diff
+    )
     exponent = exponent / (half_dim - downscale_freq_shift)

-    emb = torch.exp(exponent).to(device=timesteps.device)
+    emb = torch.exp(exponent)
     emb = timesteps[:, None].float() * emb[None, :]
```
Member (Author): Same as previous comment
```diff
     # scale embeddings
```
```diff
@@ -230,16 +230,16 @@ def forward(
         # 1. time
         timesteps = timestep
         if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
```
Member (Author): Do we need int64 tensors for timesteps, no matter the pipeline's precision?

Contributor: At least for all the Stable Diffusion applications I've seen so far, timesteps are ints in the range 0..1000. Even if other diffusion models did several orders of magnitude more than that, you'd think int32 would be plenty. Unless there's some byte-alignment optimization reason to specifically make them 64-bit?

Contributor: I think we can leave timesteps as int32
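A quick sketch (mine, not from the PR) of why int32 would suffice here: Stable Diffusion timesteps fit easily in 32 bits, the embedding math casts them to float anyway, and the element size halves.

```python
import torch

t64 = torch.arange(1000, dtype=torch.int64)  # current choice (torch.long)
t32 = t64.to(torch.int32)                    # proposed alternative

print(t64.numel() * t64.element_size())  # 8000 bytes
print(t32.numel() * t32.element_size())  # 4000 bytes

# the float values the sinusoidal embedding consumes are identical either way
assert torch.equal(t64.float(), t32.float())
```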
```diff
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
-            timesteps = timesteps.to(dtype=torch.float32)
-            timesteps = timesteps[None].to(device=sample.device)
+            timesteps = timesteps[None].to(sample.device)

         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])

         t_emb = self.time_proj(timesteps)
-        emb = self.time_embedding(t_emb)
+        emb = self.time_embedding(t_emb.to(self.dtype))

         # 2. pre-process
         sample = self.conv_in(sample)
```
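To illustrate the new TODO (a sketch under assumed names; `unet` and `sample` stand for any UNet-style module and its input): a plain Python int forces `forward()` to build the timestep tensor on the host and copy it to the GPU on every call, while an on-device 0-dim tensor skips that branch entirely.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

t_int = 981                                   # Python int: forward() must call
                                              # torch.tensor(...) -> H2D copy per step
t_tensor = torch.tensor(981, device=device)   # 0-dim tensor created once and reused

# unet(sample, t_int)     # hits the TODO branch on every denoising step
# unet(sample, t_tensor)  # takes the cheap elif branch instead
```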
```diff
@@ -279,7 +279,7 @@ def forward(
         # 6. post-process
         # make sure hidden states is in float32
         # when running in half-precision
-        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+        sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
```
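The removed line forced the final GroupNorm to run in fp32 even in half-precision pipelines. A standalone check (my sketch, with illustrative channel counts; it assumes a CUDA device, since CPU fp16 GroupNorm may be unsupported) compares the new direct-fp16 path against the old upcast path:

```python
import torch

norm = torch.nn.GroupNorm(num_groups=32, num_channels=320).cuda().half()
sample = torch.randn(1, 320, 64, 64, device="cuda", dtype=torch.float16)

y_new = norm(sample)                         # PR path: stay in fp16 throughout
y_old = norm.float()(sample.float()).half()  # old path: upcast, normalize, downcast
print((y_new - y_old).abs().max())           # expected to be at fp16 rounding level
```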
```diff
@@ -225,15 +225,23 @@ def __call__(
                 latents_shape,
                 generator=generator,
                 device=latents_device,
+                dtype=text_embeddings.dtype,
             )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+        latents = latents.to(latents_device)

         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to the correct device beforehand
+        if torch.is_tensor(self.scheduler.timesteps):
+            timesteps_tensor = self.scheduler.timesteps.to(self.device)
+        else:
+            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
```
Member (Author): Should we make the dtype for timesteps int32 here as well?
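A sketch of the win described in the comment above (using a stand-in numpy schedule instead of a real PNDM scheduler): one up-front transfer replaces a small host-to-device copy on every denoising step.

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = np.arange(999, -1, -20)  # stand-in for a scheduler's numpy timesteps

# before: each iteration wraps a host scalar, causing a tiny H2D copy per step
for t in timesteps:
    t = torch.tensor(t, device=device)

# after (the PR's pattern): a single transfer, then the loop stays on-device
timesteps_tensor = torch.tensor(timesteps.copy(), device=device)
for t in timesteps_tensor:
    pass  # `t` is already a 0-dim tensor on `device`
```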
```diff
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
             latents = latents * self.scheduler.sigmas[0]
```

```diff
@@ -247,7 +255,7 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
```
```diff
@@ -278,7 +286,9 @@ def __call__(
         # run safety checker
         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+        )

         if output_type == "pil":
             image = self.numpy_to_pil(image)
```
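The reason for the new cast, shown on a toy stand-in (a sketch; the `torch.nn.Linear` here is not the real safety checker): the feature extractor returns fp32 pixel values, and a module whose weights are fp16 rejects fp32 inputs.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checker = torch.nn.Linear(8, 2).to(device).half()  # fp16 module, like a .half() pipeline
pixel_values = torch.randn(1, 8, device=device)    # feature extractors emit fp32

# checker(pixel_values)  # RuntimeError: expected Half but found Float
# (older CPU builds may lack fp16 matmul; run on a GPU there)
out = checker(pixel_values.to(checker.weight.dtype))  # mirrors .to(text_embeddings.dtype)
print(out.dtype)  # torch.float16
```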