Skip to content

Commit bde7046

Browse files
patrickvonplaten authored and Prathik Rao committed
[Stable Diffusion] Add components function (huggingface#889)
* [Stable Diffusion] Add components function * uP
1 parent 23579b9 commit bde7046

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

docs/source/api/diffusion_pipeline.mdx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrain
3232
[[autodoc]] DiffusionPipeline
3333
- from_pretrained
3434
- save_pretrained
35+
- to
36+
- device
37+
- components
3538

3639
## ImagePipelineOutput
3740
By default diffusion pipelines return an object of class

docs/source/api/pipelines/stable_diffusion.mdx

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,26 @@ For more details about how Stable Diffusion works and how it differs from the ba
1717
| [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest)
1818
| [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon
1919

20+
## Tips
21+
22+
If you want to cover all possible use cases with a single `DiffusionPipeline`, you can either:
23+
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
24+
- Make use of the `components` functionality to instantiate all components in the most memory-efficient way:
25+
26+
```python
27+
>>> from diffusers import (
28+
... StableDiffusionPipeline,
29+
... StableDiffusionImg2ImgPipeline,
30+
... StableDiffusionInpaintPipeline,
31+
... )
32+
33+
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
34+
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
35+
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
36+
37+
>>> # now you can use img2text(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline
38+
```
39+
2040
## StableDiffusionPipelineOutput
2141
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
2242

src/diffusers/pipeline_utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import inspect
1919
import os
2020
from dataclasses import dataclass
21-
from typing import List, Optional, Union
21+
from typing import Any, Dict, List, Optional, Union
2222

2323
import numpy as np
2424
import torch
@@ -564,6 +564,41 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
564564
model = pipeline_class(**init_kwargs)
565565
return model
566566

567+
@property
568+
def components(self) -> Dict[str, Any]:
569+
r"""
570+
571+
The `self.components` property can be useful to run different pipelines with the same weights and
572+
configurations to not have to re-allocate memory.
573+
574+
Examples:
575+
576+
```py
577+
>>> from diffusers import (
578+
... StableDiffusionPipeline,
579+
... StableDiffusionImg2ImgPipeline,
580+
... StableDiffusionInpaintPipeline,
581+
... )
582+
583+
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
584+
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
585+
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
586+
```
587+
588+
Returns:
589+
A dictionary containing all the modules needed to initialize the pipeline.
590+
"""
591+
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
592+
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
593+
594+
if set(components.keys()) != expected_modules:
595+
raise ValueError(
596+
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
597+
f" {expected_modules} to be defined, but {components} are defined."
598+
)
599+
600+
return components
601+
567602
@staticmethod
568603
def numpy_to_pil(images):
569604
"""

tests/test_pipelines.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,59 @@ def test_stable_diffusion_inpaint_fp16(self):
13911391

13921392
assert image.shape == (1, 128, 128, 3)
13931393

1394+
def test_components(self):
1395+
"""Test that components property works correctly"""
1396+
unet = self.dummy_cond_unet
1397+
scheduler = PNDMScheduler(skip_prk_steps=True)
1398+
vae = self.dummy_vae
1399+
bert = self.dummy_text_encoder
1400+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1401+
1402+
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
1403+
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
1404+
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
1405+
1406+
# make sure here that pndm scheduler skips prk
1407+
inpaint = StableDiffusionInpaintPipeline(
1408+
unet=unet,
1409+
scheduler=scheduler,
1410+
vae=vae,
1411+
text_encoder=bert,
1412+
tokenizer=tokenizer,
1413+
safety_checker=self.dummy_safety_checker,
1414+
feature_extractor=self.dummy_extractor,
1415+
)
1416+
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components)
1417+
text2img = StableDiffusionPipeline(**inpaint.components)
1418+
1419+
prompt = "A painting of a squirrel eating a burger"
1420+
generator = torch.Generator(device=torch_device).manual_seed(0)
1421+
image_inpaint = inpaint(
1422+
[prompt],
1423+
generator=generator,
1424+
num_inference_steps=2,
1425+
output_type="np",
1426+
init_image=init_image,
1427+
mask_image=mask_image,
1428+
).images
1429+
image_img2img = img2img(
1430+
[prompt],
1431+
generator=generator,
1432+
num_inference_steps=2,
1433+
output_type="np",
1434+
init_image=init_image,
1435+
).images
1436+
image_text2img = text2img(
1437+
[prompt],
1438+
generator=generator,
1439+
num_inference_steps=2,
1440+
output_type="np",
1441+
).images
1442+
1443+
assert image_inpaint.shape == (1, 32, 32, 3)
1444+
assert image_img2img.shape == (1, 32, 32, 3)
1445+
assert image_text2img.shape == (1, 128, 128, 3)
1446+
13941447

13951448
class PipelineTesterMixin(unittest.TestCase):
13961449
def tearDown(self):

0 commit comments

Comments
 (0)