
Commit 146419f

All in one Stable Diffusion Pipeline (#821)
* uP
* correct
* make style
* small change
1 parent ad0e9ac commit 146419f

1 file changed · 224 additions, 0 deletions
@@ -0,0 +1,224 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch

import PIL.Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class StableDiffusionMegaPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @property
    def components(self) -> Dict[str, Any]:
        # Expose the registered modules so the task-specific sub-pipelines below can be
        # constructed on the fly without reloading any weights.
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def inpaint(
        self,
        prompt: Union[str, List[str]],
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline
        return StableDiffusionInpaintPipeline(**self.components)(
            prompt=prompt,
            init_image=init_image,
            mask_image=mask_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
        return StableDiffusionImg2ImgPipeline(**self.components)(
            prompt=prompt,
            init_image=init_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def text2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
        return StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
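
For readers who want to try the pipeline above, here is a minimal usage sketch (not part of the commit). It assumes the file can be loaded as a diffusers community pipeline via `custom_pipeline`; the "stable_diffusion_mega" identifier, the CUDA device, and the placeholder mask below are illustrative assumptions, while the CompVis/stable-diffusion-v1-4 checkpoint is the one referenced in the class docstring.

import torch
from PIL import Image

from diffusers import DiffusionPipeline

# Load the checkpoint and route it through the mega pipeline defined above.
# "stable_diffusion_mega" is an assumed community-pipeline identifier.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="stable_diffusion_mega",
)
pipe = pipe.to("cuda")  # assumes a CUDA device is available
pipe.enable_attention_slicing()  # optional: save memory at a small speed cost

# Text-to-image: each call builds a throwaway sub-pipeline from self.components,
# so all three modes share the same loaded weights.
image = pipe.text2img("An astronaut riding a horse on Mars").images[0]

# Image-to-image, starting from the generated image.
init_image = image.resize((512, 512))
image = pipe.img2img(
    prompt="A watercolor painting of an astronaut",
    init_image=init_image,
    strength=0.75,
).images[0]

# Inpainting: mask_image marks the region to repaint (white = repaint).
# The all-white mask here is only a placeholder.
mask_image = Image.new("L", init_image.size, 255)
image = pipe.inpaint(
    prompt="A red spacesuit",
    init_image=init_image,
    mask_image=mask_image,
    strength=0.75,
).images[0]
image.save("result.png")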
