diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx
index 064bc58f8c2b..bb58c19b49a4 100644
--- a/docs/source/optimization/fp16.mdx
+++ b/docs/source/optimization/fp16.mdx
@@ -14,7 +14,64 @@ specific language governing permissions and limitations under the License.
We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed.
-## CUDA `autocast`
+
+
+ |
+ | Latency
+ | Speedup
+ |
+
+ | original
+ | 9.50s
+ | x1
+ |
+
+ | cuDNN auto-tuner
+ | 9.37s
+ | x1.01
+ |
+ | autocast (fp16)
+ | 5.47s
+ | x1.91
+ |
+ | fp16
+ | 3.61s
+ | x2.91
+ |
+ | channels last
+ | 3.30s
+ | x2.87
+ |
+
+ | traced UNet
+ | 3.21s
+ | x2.96
+ |
+obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps.
+
+## Enable cuDNN auto-tuner
+
+[NVIDIA cuDNN](https://developer.nvidia.com/cudnn)Â supports many algorithms to compute a convolution. Autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size.
+
+Since we’re using **convolutional networks** (other types currently not supported), we can enable cuDNN autotuner before launching the inference by setting:
+
+```python
+import torch
+
+torch.backends.cudnn.benchmark = True
+```
+
+### Use tf32 instead of fp32 (on Ampere and later CUDA devices)
+
+On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is to add this before your inference:
+
+```python
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+```
+
+## Automatic mixed precision (AMP)
If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference roughly twice as fast at the cost of slightly lower precision. All you need to do is put your inference call inside an `autocast` context manager. The following example shows how to do it using Stable Diffusion text-to-image generation as an example:
@@ -47,7 +104,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
## Sliced attention for additional memory savings
-For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
+For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
Attention slicing is useful even if a batch size of just 1 is used - as long as the model uses more than one attention head. If there is more than one attention head the *QK^T* attention matrix can be computed sequentially for each head which can save a significant amount of memory.
@@ -73,4 +130,139 @@ with torch.autocast("cuda"):
image = pipe(prompt).images[0]
```
-There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
+There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
+
+## Using Channels Last memory format
+
+Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model.
+
+For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following:
+
+```python
+print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
+pipe.unet.to(memory_format=torch.channels_last) # in-place operation
+print(
+ pipe.unet.conv_out.state_dict()["weight"].stride()
+) # (2880, 1, 960, 320) haveing a stride of 1 for the 2nd dimension proves that it works
+```
+
+## Tracing
+
+Tracing runs an example input tensor through your model, and captures the operations that are invoked as that input makes its way through the model's layers so that an executable or `ScriptFunction` is returned that will be optimized using just-in-time compilation.
+
+To trace our UNet model, we can use the following:
+
+```python
+import time
+import torch
+from diffusers import StableDiffusionPipeline
+import functools
+
+# torch disable grad
+torch.set_grad_enabled(False)
+
+# set variables
+n_experiments = 2
+unet_runs_per_experiment = 50
+
+# load inputs
+def generate_inputs():
+ sample = torch.randn(2, 4, 64, 64).half().cuda()
+ timestep = torch.rand(1).half().cuda() * 999
+ encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
+ return sample, timestep, encoder_hidden_states
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ # scheduler=scheduler,
+ use_auth_token=True,
+ revision="fp16",
+ torch_dtype=torch.float16,
+).to("cuda")
+unet = pipe.unet
+unet.eval()
+unet.to(memory_format=torch.channels_last) # use channels_last memory format
+unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
+
+# warmup
+for _ in range(3):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet(*inputs)
+
+# trace
+print("tracing..")
+unet_traced = torch.jit.trace(unet, inputs)
+unet_traced.eval()
+print("done tracing")
+
+
+# warmup and optimize graph
+for _ in range(5):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet_traced(*inputs)
+
+
+# benchmarking
+with torch.inference_mode():
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet_traced(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet inference took {time.time() - start_time:.2f} seconds")
+
+# save the model
+unet_traced.save("unet_traced.pt")
+```
+
+Then we can replace the `unet` attribute of the pipeline with the traced model like the following
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+from dataclasses import dataclass
+
+
+@dataclass
+class UNet2DConditionOutput:
+ sample: torch.FloatTensor
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ # scheduler=scheduler,
+ use_auth_token=True,
+ revision="fp16",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+# use jitted unet
+unet_traced = torch.jit.load("unet_traced.pt")
+# del pipe.unet
+class TracedUNet(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.in_channels = pipe.unet.in_channels
+ self.device = pipe.unet.device
+
+ def forward(self, latent_model_input, t, encoder_hidden_states):
+ sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+ return UNet2DConditionOutput(sample=sample)
+
+
+pipe.unet = TracedUNet()
+
+with torch.inference_mode():
+ image = pipe([prompt] * 1, num_inference_steps=50).images[0]
+```
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index f963310f12eb..b4e5f2e07f7d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -72,8 +72,7 @@ def forward(self, hidden_states):
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
-
- attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
@@ -275,7 +274,13 @@ def forward(self, hidden_states, context=None, mask=None):
return self.to_out(hidden_states)
def _attention(self, query, key, value):
- attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)
@@ -292,7 +297,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
- attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+ attn_slice = (
+ torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+ ) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 86ac074c1d0e..06b814e2bbcd 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -37,10 +37,12 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
- exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
exponent = exponent / (half_dim - downscale_freq_shift)
- emb = torch.exp(exponent).to(device=timesteps.device)
+ emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 43c00fdf7003..2a6b2971aae5 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -331,7 +331,7 @@ def forward(self, x, temb):
# make sure hidden states is in float32
# when running in half-precision
- hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
@@ -349,7 +349,7 @@ def forward(self, x, temb):
# make sure hidden states is in float32
# when running in half-precision
- hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
+ hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 5e3ee091c311..3ea8829b48e1 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -230,16 +230,16 @@ def forward(
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
- timesteps = timesteps.to(dtype=torch.float32)
- timesteps = timesteps[None].to(device=sample.device)
+ timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
- emb = self.time_embedding(t_emb)
+ emb = self.time_embedding(t_emb.to(self.dtype))
# 2. pre-process
sample = self.conv_in(sample)
@@ -279,7 +279,7 @@ def forward(
# 6. post-process
# make sure hidden states is in float32
# when running in half-precision
- sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+ sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 77f25ef1b9c5..5c6890db82fd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -225,15 +225,23 @@ def __call__(
latents_shape,
generator=generator,
device=latents_device,
+ dtype=text_embeddings.dtype,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
- latents = latents.to(self.device)
+ latents = latents.to(latents_device)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimzed to move all timesteps to correct device beforehand
+ if torch.is_tensor(self.scheduler.timesteps):
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+ else:
+ timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]
@@ -247,7 +255,7 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta
- for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -278,7 +286,9 @@ def __call__(
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
- image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
if output_type == "pil":
image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index f2ccee71c024..8f2bb67b8de1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -265,7 +265,11 @@ def __call__(
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
- for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimzed to move all timesteps to correct device beforehand
+ timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index a95f9152279a..2e792df1803e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -298,7 +298,11 @@ def __call__(
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
- for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimzed to move all timesteps to correct device beforehand
+ timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
+
+ for i, t in tqdm(enumerate(timesteps_tensor)):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index ec445e729fbe..8fd8c2b844a8 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -131,13 +131,15 @@ def lms_derivative(tau):
return integrated_coeff
- def set_timesteps(self, num_inference_steps: int):
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
@@ -145,8 +147,8 @@ def set_timesteps(self, num_inference_steps: int):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
- self.sigmas = torch.from_numpy(sigmas)
- self.timesteps = torch.from_numpy(timesteps)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.derivatives = []