Commit 12004bf

[Community Pipelines] Accelerate inference of Stable Diffusion XL (SDXL) by IPEX on CPU (#6683)
* add stable_diffusion_xl_ipex community pipeline
* make style for code quality check
* update docs as suggested

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent d2fc5eb commit 12004bf

2 files changed: +1596 -0 lines changed

examples/community/README.md

Lines changed: 106 additions & 0 deletions
@@ -63,6 +63,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |

To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
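
For example, a minimal sketch of loading the new IPEX pipeline this way (the `custom_pipeline` name is assumed to match the `pipeline_stable_diffusion_xl_ipex.py` file added in this commit; the model id is illustrative):

```python
from diffusers import DiffusionPipeline

# Fetch a community pipeline by name; `custom_pipeline` refers to a file in
# diffusers/examples/community (file name assumed for the new IPEX pipeline).
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    custom_pipeline="pipeline_stable_diffusion_xl_ipex",
    use_safetensors=True,
)
```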

@@ -1707,6 +1708,111 @@ print("Latency of StableDiffusionPipeline--fp32",latency)

```

### Stable Diffusion XL on IPEX

This diffusion pipeline aims to accelerate the inference of Stable Diffusion XL on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).

To use this pipeline, you need to:

1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)

**Note:** Each PyTorch release has a corresponding IPEX release; the mapping is shown below. Installing PyTorch/IPEX 2.0 is recommended for the best performance.

|PyTorch Version|IPEX Version|
|--|--|
|[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|

You can simply use pip to install the latest version of IPEX:

```bash
python -m pip install intel_extension_for_pytorch
```

**Note:** To install a specific version, run the following command:

```bash
python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
```
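
After installation, a quick sanity check (a minimal sketch, not part of the original instructions) can confirm that the installed PyTorch and IPEX versions match one of the pairs in the table above:

```python
# Print the installed versions; they should come from the same row of the
# compatibility table above (e.g. PyTorch 2.0.x with IPEX 2.0.x).
import torch
import intel_extension_for_pytorch as ipex

print("PyTorch:", torch.__version__)
print("IPEX:", ipex.__version__)
```
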
2. After pipeline initialization, call `prepare_for_ipex()` to enable IPEX acceleration. Supported inference datatypes are Float32 and BFloat16.

**Note:** The values of `height` and `width` used during preparation with `prepare_for_ipex()` should be the same when running inference with the prepared pipeline.

```python
import torch

from pipeline_stable_diffusion_xl_ipex import StableDiffusionXLPipelineIpex

prompt = "sailing ship in storm by Rembrandt"

pipe = StableDiffusionXLPipelineIpex.from_pretrained("stabilityai/sdxl-turbo", low_cpu_mem_usage=True, use_safetensors=True)
# value of image height/width should be consistent with the pipeline inference
# For Float32
pipe.prepare_for_ipex(torch.float32, prompt, height=512, width=512)
# For BFloat16
pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)
```

Then you can use the IPEX pipeline in a similar way to the default Stable Diffusion XL pipeline:

```python
# value of image height/width should be consistent with 'prepare_for_ipex()'
num_inference_steps = 4  # sdxl-turbo needs only a few steps (see the benchmark below)
guidance_scale = 0.0     # sdxl-turbo is run without classifier-free guidance

# For Float32
image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
# For BFloat16
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]
```
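
The pipeline returns standard PIL images, so the result can be handled exactly as with the default pipeline; for example (file name chosen for illustration):

```python
# Save the generated image to disk.
image.save("sdxl_ipex_result.png")
```
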
The following code compares the performance of the original Stable Diffusion XL pipeline with the IPEX-optimized pipeline.
With the optimized pipeline, we get roughly a 1.4-2x speedup with BFloat16 on fourth-generation Intel Xeon CPUs (code-named Sapphire Rapids).

```python
import torch
from diffusers import StableDiffusionXLPipeline
from pipeline_stable_diffusion_xl_ipex import StableDiffusionXLPipelineIpex
import time

prompt = "sailing ship in storm by Rembrandt"
model_id = "stabilityai/sdxl-turbo"
steps = 4

# Helper function for time evaluation
def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
    # warmup
    for _ in range(2):
        images = pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0).images
    # time evaluation
    start = time.time()
    for _ in range(nb_pass):
        pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0)
    end = time.time()
    return (end - start) / nb_pass

############## bf16 inference performance ###############

# 1. IPEX Pipeline initialization
pipe = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)

# 2. Original Pipeline initialization
pipe2 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)

# 3. Compare performance between Original Pipeline and IPEX Pipeline
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    latency = elapsed_time(pipe, num_inference_steps=steps)
    print("Latency of StableDiffusionXLPipelineIpex--bf16", latency, "s for total", steps, "steps")
    latency = elapsed_time(pipe2, num_inference_steps=steps)
    print("Latency of StableDiffusionXLPipeline--bf16", latency, "s for total", steps, "steps")

############## fp32 inference performance ###############

# 1. IPEX Pipeline initialization
pipe3 = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)
pipe3.prepare_for_ipex(torch.float32, prompt, height=512, width=512)

# 2. Original Pipeline initialization
pipe4 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)

# 3. Compare performance between Original Pipeline and IPEX Pipeline
latency = elapsed_time(pipe3, num_inference_steps=steps)
print("Latency of StableDiffusionXLPipelineIpex--fp32", latency, "s for total", steps, "steps")
latency = elapsed_time(pipe4, num_inference_steps=steps)
print("Latency of StableDiffusionXLPipeline--fp32", latency, "s for total", steps, "steps")
```

### CLIP Guided Images Mixing With Stable Diffusion

![clip_guided_images_mixing_examples](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/main.png)
