Skip to content

Commit 639f645

Browse files
fix pipeline __setattr__ value == None (#3063)
* fix pipeline __setattr__ * add test --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 9d7c08f commit 639f645

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def __setattr__(self, name: str, value: Any):
510510
if hasattr(self, name) and hasattr(self.config, name):
511511
# We need to overwrite the config if name exists in config
512512
if isinstance(getattr(self.config, name), (tuple, list)):
513-
if self.config[name][0] is not None:
513+
if value is not None and self.config[name][0] is not None:
514514
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
515515
else:
516516
class_library_tuple = (None, None)

tests/test_pipelines.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,46 @@ def test_set_scheduler(self):
929929
sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
930930
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
931931

932+
def test_set_component_to_none(self):
933+
unet = self.dummy_cond_unet()
934+
scheduler = PNDMScheduler(skip_prk_steps=True)
935+
vae = self.dummy_vae
936+
bert = self.dummy_text_encoder
937+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
938+
939+
pipeline = StableDiffusionPipeline(
940+
unet=unet,
941+
scheduler=scheduler,
942+
vae=vae,
943+
text_encoder=bert,
944+
tokenizer=tokenizer,
945+
safety_checker=None,
946+
feature_extractor=self.dummy_extractor,
947+
)
948+
949+
generator = torch.Generator(device="cpu").manual_seed(0)
950+
951+
prompt = "This is a flower"
952+
953+
out_image = pipeline(
954+
prompt=prompt,
955+
generator=generator,
956+
num_inference_steps=1,
957+
output_type="np",
958+
).images
959+
960+
pipeline.feature_extractor = None
961+
generator = torch.Generator(device="cpu").manual_seed(0)
962+
out_image_2 = pipeline(
963+
prompt=prompt,
964+
generator=generator,
965+
num_inference_steps=1,
966+
output_type="np",
967+
).images
968+
969+
assert out_image.shape == (1, 64, 64, 3)
970+
assert np.abs(out_image - out_image_2).max() < 1e-3
971+
932972
def test_set_scheduler_consistency(self):
933973
unet = self.dummy_cond_unet()
934974
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")

0 commit comments

Comments
 (0)