Skip to content

Commit 05f7b88

Browse files
[Stable Diffusion] Add components function
1 parent 100e094 commit 05f7b88

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import inspect
1919
import os
2020
from dataclasses import dataclass
21-
from typing import List, Optional, Union
21+
from typing import Any, Dict, List, Optional, Union
2222

2323
import numpy as np
2424
import torch
@@ -541,6 +541,41 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
541541
model = pipeline_class(**init_kwargs)
542542
return model
543543

544+
@property
545+
def components(self) -> Dict[str, Any]:
546+
r"""
547+
548+
The `self.compenents` property can be useful to run different pipelines with the same weights and
549+
configurations to not have to re-allocate memory.
550+
551+
Examples:
552+
553+
```py
554+
>>> from diffusers import (
555+
... StableDiffusionPipeline,
556+
... StableDiffusionImg2ImgPipeline,
557+
... StableDiffusionInpaintPipeline,
558+
... )
559+
560+
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
561+
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
562+
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
563+
```
564+
565+
Returns:
566+
A dictionaly containing all the modules needed to initialize the pipleline.
567+
"""
568+
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
569+
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
570+
571+
if set(components.keys()) != expected_modules:
572+
raise ValueError(
573+
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
574+
f" {expected_modules} to be defined, but {components} are defined."
575+
)
576+
577+
return components
578+
544579
@staticmethod
545580
def numpy_to_pil(images):
546581
"""

tests/test_pipelines.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,59 @@ def test_stable_diffusion_inpaint_fp16(self):
13141314

13151315
assert image.shape == (1, 32, 32, 3)
13161316

1317+
def test_components(self):
1318+
"""Test that components property works correctly"""
1319+
unet = self.dummy_cond_unet
1320+
scheduler = PNDMScheduler(skip_prk_steps=True)
1321+
vae = self.dummy_vae
1322+
bert = self.dummy_text_encoder
1323+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1324+
1325+
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
1326+
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
1327+
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
1328+
1329+
# make sure here that pndm scheduler skips prk
1330+
inpaint = StableDiffusionInpaintPipeline(
1331+
unet=unet,
1332+
scheduler=scheduler,
1333+
vae=vae,
1334+
text_encoder=bert,
1335+
tokenizer=tokenizer,
1336+
safety_checker=self.dummy_safety_checker,
1337+
feature_extractor=self.dummy_extractor,
1338+
)
1339+
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components)
1340+
text2img = StableDiffusionPipeline(**inpaint.components)
1341+
1342+
prompt = "A painting of a squirrel eating a burger"
1343+
generator = torch.Generator(device=torch_device).manual_seed(0)
1344+
image_inpaint = inpaint(
1345+
[prompt],
1346+
generator=generator,
1347+
num_inference_steps=2,
1348+
output_type="np",
1349+
init_image=init_image,
1350+
mask_image=mask_image,
1351+
).images
1352+
image_img2img = img2img(
1353+
[prompt],
1354+
generator=generator,
1355+
num_inference_steps=2,
1356+
output_type="np",
1357+
init_image=init_image,
1358+
).images
1359+
image_text2img = text2img(
1360+
[prompt],
1361+
generator=generator,
1362+
num_inference_steps=2,
1363+
output_type="np",
1364+
).images
1365+
1366+
assert image_inpaint.shape == (1, 32, 32, 3)
1367+
assert image_img2img.shape == (1, 32, 32, 3)
1368+
assert image_text2img.shape == (1, 128, 128, 3)
1369+
13171370

13181371
class PipelineTesterMixin(unittest.TestCase):
13191372
def tearDown(self):

0 commit comments

Comments
 (0)