diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx
index 63085151ec2f..a6a40469e97b 100644
--- a/docs/source/en/optimization/torch2.0.mdx
+++ b/docs/source/en/optimization/torch2.0.mdx
@@ -18,11 +18,10 @@ Starting from version `0.13.0`, Diffusers supports the latest optimization from
 
 ## Installation
 
-To benefit from the accelerated transformers implementation and `torch.compile`, we will need to install the nightly version of PyTorch, as the stable version is yet to be released. The first step is to install CUDA 11.7 or CUDA 11.8,
-as PyTorch 2.0 does not support the previous versions. Once CUDA is installed, torch nightly can be installed using:
+To benefit from the accelerated attention implementation and `torch.compile`, you just need to install the latest version of PyTorch 2.0 from `pip` and make sure you are on diffusers 0.13.0 or later. As explained below, `diffusers` automatically uses the attention optimizations (but not `torch.compile`) when available.
 
 ```bash
-pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
+pip install --upgrade torch torchvision diffusers
 ```
 
 ## Using accelerated transformers and torch.compile.
@@ -91,8 +90,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 
 We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`. For the benchmark we used the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 steps.
 The `xFormers` benchmark is done using the `torch==1.13.1` version, while the accelerated transformers optimizations are tested using nightly versions of PyTorch 2.0.
 The tables below summarize the results we got.
-The `Speed over xformers` columns denote the speed-up gained over `xFormers` using the `torch.compile+torch.nn.functional.scaled_dot_product_attention`.
-
+Please refer to [our featured blog post on the PyTorch site](https://pytorch.org/blog/accelerated-diffusers-pt-20/) for more details.
 
 ### FP16 benchmark
@@ -103,10 +101,14 @@ ___The time reported is in seconds.___
 
 | GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
 | --- | --- | --- | --- | --- | --- | --- |
-| A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
-| A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
-| A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
-| A100 | 64 | | 52.51 | 53.03 | 47.81 | 8.95 |
+| A100 | 1 | 2.69 | 2.7 | 1.98 | 2.47 | 8.52 |
+| A100 | 2 | 3.21 | 3.04 | 2.38 | 2.78 | 8.55 |
+| A100 | 4 | 5.27 | 3.91 | 3.89 | 3.53 | 9.72 |
+| A100 | 8 | 9.74 | 7.03 | 7.04 | 6.62 | 5.83 |
+| A100 | 10 | 12.02 | 8.7 | 8.67 | 8.45 | 2.87 |
+| A100 | 16 | 18.95 | 13.57 | 13.55 | 13.20 | 2.73 |
+| A100 | 32 (1) | OOM | 26.56 | 26.68 | 25.85 | 2.67 |
+| A100 | 64 | | 52.51 | 53.03 | 50.93 | 3.01 |
 | | | | | | | |
 | A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
 | A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
@@ -125,25 +127,28 @@ ___The time reported is in seconds.___
 | V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
 | V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
 | | | | | | | |
-| 3090 | 4 | 10.04 | 7.82 | 7.89 | 7.47 | 4.48 |
-| 3090 | 8 | 19.27 | 14.97 | 15.04 | 14.22 | 5.01 |
-| 3090 | 10| 24.08 | 18.7 | 18.7 | 17.69 | 5.40 |
-| 3090 | 16 | OOM | 29.06 | 29.06 | 28.2 | 2.96 |
-| 3090 | 32 (1) | | 58.05 | 58 | 54.88 | 5.46 |
-| 3090 | 64 (1) | | 126.54 | 126.03 | 117.33 | 7.28 |
+| 3090 | 1 | 2.94 | 2.5 | 2.42 | 2.33 | 6.80 |
+| 3090 | 4 | 10.04 | 7.82 | 7.72 | 7.38 | 5.63 |
+| 3090 | 8 | 19.27 | 14.97 | 14.88 | 14.15 | 5.48 |
+| 3090 | 10 | 24.08 | 18.7 | 18.62 | 18.12 | 3.10 |
+| 3090 | 16 | OOM | 29.06 | 28.88 | 28.2 | 2.96 |
+| 3090 | 32 (1) | | 58.05 | 57.42 | 56.28 | 3.05 |
+| 3090 | 64 (1) | | 126.54 | 114.27 | 112.21 | 11.32 |
 | | | | | | | |
-| 3090 Ti | 4 | 9.07 | 7.14 | 7.15 | 6.81 | 4.62 |
-| 3090 Ti | 8 | 17.51 | 13.65 | 13.72 | 12.99 | 4.84 |
-| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.93 | 16.02 | 4.93 |
-| 3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
-| 3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
-| 3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |
+| 3090 Ti | 1 | 2.7 | 2.26 | 2.19 | 2.12 | 6.19 |
+| 3090 Ti | 4 | 9.07 | 7.14 | 7.00 | 6.71 | 6.02 |
+| 3090 Ti | 8 | 17.51 | 13.65 | 13.53 | 12.94 | 5.20 |
+| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.77 | 16.44 | 2.43 |
+| 3090 Ti | 16 | OOM | 26.1 | 26.04 | 25.53 | 2.18 |
+| 3090 Ti | 32 (1) | | 51.78 | 51.71 | 50.91 | 1.68 |
+| 3090 Ti | 64 (1) | | 112.02 | 102.78 | 100.89 | 9.94 |
 | | | | | | | |
-| 4090 | 4 | 10.48 | 8.37 | 8.32 | 8.01 | 4.30 |
-| 4090 | 8 | 14.33 | 10.22 | 10.42 | 9.78 | 4.31 |
-| 4090 | 16 | | 17.07 | 17.46 | 17.15 | -0.47 |
-| 4090 | 32 (1) | | 39.03 | 39.86 | 37.97 | 2.72 |
-| 4090 | 64 (1) | | 77.29 | 79.44 | 77.67 | -0.49 |
+| 4090 | 1 | 4.47 | 3.98 | 1.28 | 1.21 | 69.60 |
+| 4090 | 4 | 10.48 | 8.37 | 3.76 | 3.56 | 57.47 |
+| 4090 | 8 | 14.33 | 10.22 | 7.43 | 6.99 | 31.60 |
+| 4090 | 16 | | 17.07 | 14.98 | 14.58 | 14.59 |
+| 4090 | 32 (1) | | 39.03 | 30.18 | 29.49 | 24.44 |
+| 4090 | 64 (1) | | 77.29 | 61.34 | 59.96 | 22.42 |
 
 ### FP32 benchmark
 
@@ -155,11 +160,13 @@ Using `torch.compile` in addition to the accelerated transformers implementation
 
 | GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
 | --- | --- | --- | --- | --- | --- | --- | --- |
-| A100 | 4 | 16.56 | 12.42 | 12.2 | 11.84 | 4.67 | 28.50 |
-| A100 | 10 | OOM | 29.93 | 29.44 | 28.5 | 4.78 | |
-| A100 | 16 | | 47.08 | 46.27 | 44.8 | 4.84 | |
-| A100 | 32 | | 92.89 | 91.34 | 88.35 | 4.89 | |
-| A100 | 64 | | 185.3 | 182.71 | 176.48 | 4.76 | |
+| A100 | 1 | 4.97 | 3.86 | 2.6 | 2.86 | 25.91 | 42.45 |
+| A100 | 2 | 9.03 | 6.76 | 4.41 | 4.21 | 37.72 | 53.38 |
+| A100 | 4 | 16.70 | 12.42 | 7.94 | 7.54 | 39.29 | 54.85 |
+| A100 | 10 | OOM | 29.93 | 18.70 | 18.46 | 38.32 | |
+| A100 | 16 | | 47.08 | 29.41 | 29.04 | 38.32 | |
+| A100 | 32 | | 92.89 | 57.55 | 56.67 | 38.99 | |
+| A100 | 64 | | 185.3 | 114.8 | 112.98 | 39.03 | |
 | | | | | | | | |
 | A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
 | A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
@@ -179,30 +186,27 @@ Using `torch.compile` in addition to the accelerated transformers implementation
 | V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | |
 | V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | |
 | | | | | | | | |
-| 3090 | 1 | 7.09 | 6.78 | 6.11 | 6.03 | 11.06 | 14.95 |
-| 3090 | 4 | 22.69 | 21.45 | 18.67 | 18.09 | 15.66 | 20.27 |
-| 3090 | 8 | | 42.59 | 36.75 | 35.59 | 16.44 | |
-| 3090 | 16 | | 85.35 | 72.37 | 70.25 | 17.69 | |
-| 3090 | 32 (1) | | 162.05 | 138.99 | 134.53 | 16.98 | |
-| 3090 | 48 | | 241.91 | 207.75 | | 14.12 | |
+| 3090 | 1 | 7.09 | 6.78 | 5.34 | 5.35 | 21.09 | 24.54 |
+| 3090 | 4 | 22.69 | 21.45 | 18.56 | 18.18 | 15.24 | 19.88 |
+| 3090 | 8 | | 42.59 | 36.68 | 35.61 | 16.39 | |
+| 3090 | 16 | | 85.35 | 72.93 | 70.18 | 17.77 | |
+| 3090 | 32 (1) | | 162.05 | 143.46 | 138.67 | 14.43 | |
 | | | | | | | | |
-| 3090 Ti | 1 | 6.45 | 6.19 | 5.64 | 5.49 | 11.31 | 14.88 |
-| 3090 Ti | 4 | 20.32 | 19.31 | 16.9 | 16.37 | 15.23 | 19.44 |
-| 3090 Ti | 8 (2) | | 37.93 | 33.05 | 31.99 | 15.66 | |
-| 3090 Ti | 16 | | 75.37 | 65.25 | 64.32 | 14.66 | |
-| 3090 Ti | 32 (1) | | 142.55 | 124.44 | 120.74 | 15.30 | |
-| 3090 Ti | 48 | | 213.19 | 186.55 | | 12.50 | |
+| 3090 Ti | 1 | 6.45 | 6.19 | 4.99 | 4.89 | 21.00 | 24.19 |
+| 3090 Ti | 4 | 20.32 | 19.31 | 17.02 | 16.48 | 14.66 | 18.90 |
+| 3090 Ti | 8 | | 37.93 | 33.21 | 32.24 | 15.00 | |
+| 3090 Ti | 16 | | 75.37 | 66.63 | 64.5 | 14.42 | |
+| 3090 Ti | 32 (1) | | 142.55 | 128.89 | 124.92 | 12.37 | |
 | | | | | | | | |
-| 4090 | 1 | 5.54 | 4.99 | 4.51 | 4.44 | 11.02 | 19.86 |
-| 4090 | 4 | 13.67 | 11.4 | 10.3 | 9.84 | 13.68 | 28.02 |
-| 4090 | 8 | | 19.79 | 17.13 | 16.19 | 18.19 | |
-| 4090 | 16 | | 38.62 | 33.14 | 32.31 | 16.34 | |
-| 4090 | 32 (1) | | 76.57 | 65.96 | 62.05 | 18.96 | |
-| 4090 | 48 | | 114.44 | 98.78 | | 13.68 | |
-
+| 4090 | 1 | 5.54 | 4.99 | 2.66 | 2.58 | 48.30 | 53.43 |
+| 4090 | 4 | 13.67 | 11.4 | 8.81 | 8.46 | 25.79 | 38.11 |
+| 4090 | 8 | | 19.79 | 17.55 | 16.62 | 16.02 | |
+| 4090 | 16 | | 38.62 | 35.65 | 34.07 | 11.78 | |
+| 4090 | 32 (1) | | 76.57 | 69.48 | 65.35 | 14.65 | |
+| 4090 | 48 | | 114.44 | 106.3 | | 7.11 | |
 
-(1) Batch Size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665.
-This is required for PyTorch 1.13.1, and also for PyTorch 2.0 and batch size of 64.
+(1) Batch Size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665.
+This is required for PyTorch 1.13.1, and also for PyTorch 2.0 with large batch sizes.
 
-For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303).
+For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303) and to [the blog post](https://pytorch.org/blog/accelerated-diffusers-pt-20/).
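
For reference, the configuration benchmarked above boils down to a few lines. The sketch below assumes PyTorch 2.0 and diffusers 0.13.0+, where the pipeline picks up `torch.nn.functional.scaled_dot_product_attention` automatically; wrapping the UNet in `torch.compile` follows the common pattern, though the exact invocation used for these numbers lives in the linked PR.

```python
import torch
from diffusers import StableDiffusionPipeline

# On PyTorch 2.0, diffusers >= 0.13 uses scaled_dot_product_attention
# automatically, so no extra flag is needed for the "PyTorch2.0 SDPA" column.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,  # the FP16 benchmark; drop this for FP32
).to("cuda")

# torch.compile is opt-in: compiling the UNet (the bulk of the compute) is
# what the "SDPA + torch.compile" columns measure. The first call pays the
# compilation cost; subsequent calls run the optimized graph.
pipe.unet = torch.compile(pipe.unet)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=50).images[0]
```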
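The per-row timings could be collected with a harness along these lines. This is only a sketch (the helper name `time_pipeline`, the warm-up pass, and the synchronization points are assumptions; the actual harness is in the linked PR), but it illustrates why a warm-up run matters: without it, the first call would fold `torch.compile`'s one-time compilation cost into the measurement.

```python
import time

import torch


def time_pipeline(pipe, prompt: str, batch_size: int, steps: int = 50) -> float:
    """Return the wall-clock seconds for one batch, after a warm-up pass."""
    prompts = [prompt] * batch_size
    pipe(prompts, num_inference_steps=steps)  # warm-up; triggers compilation
    torch.cuda.synchronize()                  # drain pending GPU work
    start = time.perf_counter()
    pipe(prompts, num_inference_steps=steps)
    torch.cuda.synchronize()                  # wait for the GPU to finish
    return time.perf_counter() - start
```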
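Footnote (1) refers to diffusers' VAE slicing switch, which decodes the latents one image at a time instead of as a single large batch, sidestepping the kernel limitation in the linked PyTorch issue at a small throughput cost. Continuing the sketch above (batch size 32 is shown as an example):

```python
# For batch sizes >= 32 the one-shot VAE decode can hit
# https://github.com/pytorch/pytorch/issues/81665; slicing avoids it.
pipe.enable_vae_slicing()
images = pipe(["a photo of an astronaut riding a horse"] * 32,
              num_inference_steps=50).images
```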