diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 06912a1464eb..2e20c21aaf38 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -510,7 +510,7 @@ def __setattr__(self, name: str, value: Any):
         if hasattr(self, name) and hasattr(self.config, name):
             # We need to overwrite the config if name exists in config
             if isinstance(getattr(self.config, name), (tuple, list)):
-                if self.config[name][0] is not None:
+                if value is not None and self.config[name][0] is not None:
                     class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
                 else:
                     class_library_tuple = (None, None)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 048030d98371..a5d70b01d453 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -929,6 +929,46 @@ def test_set_scheduler(self):
         sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
         assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
 
+    def test_set_component_to_none(self):
+        unet = self.dummy_cond_unet()
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        pipeline = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        prompt = "This is a flower"
+
+        out_image = pipeline(
+            prompt=prompt,
+            generator=generator,
+            num_inference_steps=1,
+            output_type="np",
+        ).images
+
+        pipeline.feature_extractor = None
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        out_image_2 = pipeline(
+            prompt=prompt,
+            generator=generator,
+            num_inference_steps=1,
+            output_type="np",
+        ).images
+
+        assert out_image.shape == (1, 64, 64, 3)
+        assert np.abs(out_image - out_image_2).max() < 1e-3
+
     def test_set_scheduler_consistency(self):
         unet = self.dummy_cond_unet()
         pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
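
For context (not part of the patch): a minimal sketch of the code path the new `value is not None` guard protects. The pipeline's `__setattr__` in pipeline_utils.py intercepts component assignments and rebuilds the config entry from `value.__module__`; with `value` set to `None`, that attribute access fails (NoneType has no `__module__`), which is what the new test exercises by clearing `feature_extractor` after construction. The checkpoint id below is reused from the test purely for illustration and is assumed to resolve to a full tiny pipeline.

# Minimal sketch, not part of the patch; assumes the tiny test checkpoint is reachable.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")

# __setattr__ intercepts these assignments; before the fix, `value.__module__`
# was evaluated with value=None for a previously non-None component and raised
# AttributeError. With the guard, the entry is stored as (None, None) and
# inference still runs.
pipe.feature_extractor = None
pipe.safety_checker = None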