diff --git a/README.md b/README.md
index 64cbd15aab26..8c6fd86326ac 100644
--- a/README.md
+++ b/README.md
@@ -160,6 +160,22 @@ pipe = StableDiffusionPipeline.from_pretrained(
     torch_dtype=torch.float16,
     scheduler=lms,
 )
+```
+
+Or, even easier, you can make use of the `set_scheduler` functionality:
+
+```python
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    revision="fp16",
+    torch_dtype=torch.float16,
+)
+pipe.set_scheduler("lms_discrete")
+```
+
+Then you can run the pipeline just as before.
+
+```python
 pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
diff --git a/docs/source/api/diffusion_pipeline.mdx b/docs/source/api/diffusion_pipeline.mdx
index b037b4e26dc1..e9fe1c126163 100644
--- a/docs/source/api/diffusion_pipeline.mdx
+++ b/docs/source/api/diffusion_pipeline.mdx
@@ -32,9 +32,13 @@ Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrain
 [[autodoc]] DiffusionPipeline
 	- from_pretrained
 	- save_pretrained
+	- set_scheduler
 	- to
 	- device
 	- components
+	- numpy_to_pil
+	- progress_bar
+	- set_progress_bar_config
 
 ## ImagePipelineOutput
 By default diffusion pipelines return an object of class
diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx
index 26d6a210adad..40226eb55520 100644
--- a/docs/source/api/pipelines/stable_diffusion.mdx
+++ b/docs/source/api/pipelines/stable_diffusion.mdx
@@ -31,16 +31,16 @@ For more details about how Stable Diffusion works and how it differs from the ba
 
 ## Tips
 
-### How to load and use different schedulers.
+### How to use different schedulers.
 
-The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
-To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
+The stable diffusion pipeline uses the [`PNDMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline, such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], etc.
+To use a different scheduler, you can make use of the [`DiffusionPipeline.set_scheduler`] function to change the `scheduler` of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
 
 ```python
-from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
+from diffusers import StableDiffusionPipeline
 
-euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
-pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
+pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipeline.set_scheduler("euler_discrete")
 ```
diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx
index 7ed527bedf3f..f2cc4925b382 100644
--- a/docs/source/api/schedulers.mdx
+++ b/docs/source/api/schedulers.mdx
@@ -48,6 +48,9 @@ The core API for any new scheduler must follow a limited structure.
 
 The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
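+
+For illustration, here is a minimal sketch of swapping a pipeline's scheduler by its type, either as a plain string or as a [`schedulers.SchedulerType`] member (assuming the `runwayml/stable-diffusion-v1-5` checkpoint is accessible):
+
+```python
+from diffusers import DiffusionPipeline
+from diffusers.schedulers import SchedulerType
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+# pass the scheduler type as a plain string ...
+pipe.set_scheduler("euler_discrete")
+# ... or as a `SchedulerType` member, which IDEs can tab-complete
+pipe.set_scheduler(SchedulerType.euler_discrete)
+```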
+### SchedulerType
+[[autodoc]] schedulers.SchedulerType
+
 ### SchedulerMixin
 [[autodoc]] SchedulerMixin
 
diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx
index 463780a0727f..825e954700cb 100644
--- a/docs/source/quicktour.mdx
+++ b/docs/source/quicktour.mdx
@@ -115,17 +115,14 @@ Running the pipeline is then identical to the code above as it's the same model
 
 Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
 pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
-use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
-you could use it as follows:
+use a different scheduler. *E.g.* if you would instead like to use the [`DPMSolverMultistepScheduler`] scheduler,
+you can just set the scheduler to `"dpm_multistep"`:
 
 ```python
->>> from diffusers import LMSDiscreteScheduler
-
->>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
-
->>> generator = StableDiffusionPipeline.from_pretrained(
-...     "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
-... )
+>>> generator = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
+>>> generator.set_scheduler("dpm_multistep")
 ```
 
 [Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
diff --git a/docs/source/using-diffusers/loading.mdx b/docs/source/using-diffusers/loading.mdx
index 2cb980ea618a..39a5e87c44d5 100644
--- a/docs/source/using-diffusers/loading.mdx
+++ b/docs/source/using-diffusers/loading.mdx
@@ -379,6 +379,9 @@ dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
 pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
 ```
 
+**Note**: If you are often changing schedulers within the same script, it is recommended to make use
+of [`DiffusionPipeline.set_scheduler`] instead.
+
 ## API
 
 [[autodoc]] modeling_utils.ModelMixin
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index fc6ac9b5b97e..29ee0e8aa7a7 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -63,7 +63,6 @@ def register_to_config(self, **kwargs):
         kwargs["_class_name"] = self.__class__.__name__
         kwargs["_diffusers_version"] = __version__
 
-        # Special case for `kwargs` used in deprecation warning added to schedulers
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
         kwargs.pop("kwargs", None)
@@ -104,7 +103,9 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
         logger.info(f"Configuration saved in {output_config_file}")
 
     @classmethod
-    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+    def from_config(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike, dict], return_unused_kwargs=False, **kwargs
+    ):
         r"""
-        Instantiate a Python class from a pre-defined JSON-file.
+        Instantiate a Python class from a pre-defined JSON file or an already-loaded config dictionary.
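+
+        Since a plain config dict is accepted, a scheduler can, for example, be re-instantiated from another
+        scheduler's config (a minimal sketch; `pipe` is assumed to be an already-loaded pipeline):
+
+        ```python
+        from diffusers import DDIMScheduler
+
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        ```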
@@ -163,8 +164,12 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret
         """
-        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
-        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+        if not isinstance(pretrained_model_name_or_path, dict):
+            config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+        else:
+            config_dict = pretrained_model_name_or_path
+
+        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config_dict, **kwargs)
 
         # Allow dtype to be specified on initialization
         if "dtype" in unused_kwargs:
@@ -172,6 +177,10 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret
 
         # Return model and optionally state and/or unused_kwargs
         model = cls(**init_dict)
+
+        # make sure to also save config parameters that might be used for compatible classes
+        model.register_to_config(**hidden_dict)
+
         return_tuple = (model,)
 
         # Flax schedulers have a state, so return it.
@@ -291,6 +300,9 @@ def _get_init_keys(cls):
 
     @classmethod
     def extract_init_dict(cls, config_dict, **kwargs):
+        # 0. Copy original config dict
+        original_dict = {k: v for k, v in config_dict.items()}
+
         # 1. Retrieve expected config attributes from __init__ signature
         expected_keys = cls._get_init_keys(cls)
         expected_keys.remove("self")
@@ -364,7 +376,10 @@ def extract_init_dict(cls, config_dict, **kwargs):
         # 6. Define unused keyword arguments
         unused_kwargs = {**config_dict, **kwargs}
 
-        return init_dict, unused_kwargs
+        # 7. Define "hidden" config parameters that were saved for compatible classes
+        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
+
+        return init_dict, unused_kwargs, hidden_config_dict
 
     @classmethod
     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -446,6 +461,8 @@ def register_to_config(init):
     def inner_init(self, *args, **kwargs):
         # Ignore private kwargs in the init.
         init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+        signature = inspect.signature(init)
+        init(self, *args, **init_kwargs)
 
         if not isinstance(self, ConfigMixin):
             raise RuntimeError(
@@ -456,7 +473,6 @@ def inner_init(self, *args, **kwargs):
         ignore = getattr(self, "ignore_for_config", [])
         # Get positional arguments aligned with kwargs
         new_kwargs = {}
-        signature = inspect.signature(init)
         parameters = {
             name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
         }
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index a194f3eb34d9..4d5b07a4d1a6 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -34,7 +34,8 @@
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import http_user_agent
 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
-from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from .schedulers import CLASS_TO_SCHEDULER_TYPE_MAPPING, SCHEDULER_TYPE_TO_CLASS_MAPPING, SchedulerType
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
 from .utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
@@ -207,7 +208,7 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
         if torch_device is None:
             return self
 
-        module_names, _ = self.extract_init_dict(dict(self.config))
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
@@ -228,7 +229,7 @@ def device(self) -> torch.device:
         Returns:
             `torch.device`: The torch device on which the pipeline is located.
         """
-        module_names, _ = self.extract_init_dict(dict(self.config))
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
@@ -513,7 +514,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
 
-        init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
+        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
 
         if len(unused_kwargs) > 0:
             logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
@@ -709,5 +710,80 @@ def progress_bar(self, iterable):
         return tqdm(iterable, **self._progress_bar_config)
 
+    def set_scheduler(self, scheduler_type: Union[str, SchedulerType, Dict[str, str], Dict[str, SchedulerType]]):
+        r"""
+        Set the scheduler(s) of the pipeline to the given scheduler type(s).
+
+        Parameters:
+            scheduler_type (`str`, [`SchedulerType`], or `Dict[str, str]`):
+                Can be either a string representing the type the scheduler should be set to, or a mapping from
+                component names to scheduler types in case the pipeline has multiple schedulers. Make sure to set
+                the schedulers to one of the officially supported scheduler types of [`schedulers.SchedulerType`].
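+
+                If the pipeline holds more than one scheduler component, pass a dict mapping component names to
+                scheduler types instead. A minimal sketch (the component name `scheduler` is assumed, as in the
+                Stable Diffusion pipelines):
+
+                ```py
+                pipe.set_scheduler({"scheduler": "dpm_multistep"})
+                ```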
+
+        Examples:
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> pipe.set_scheduler("euler_discrete")
+        ```
+        """
+        schedulers = {k: type(v) for k, v in self.components.items() if isinstance(v, SchedulerMixin)}
+
+        if isinstance(scheduler_type, str) and len(schedulers) > 1:
+            raise ValueError(
+                f"The pipeline {self} contains the schedulers {schedulers}. Please make sure to provide a dictionary"
+                f" that maps the component names {schedulers.keys()} to scheduler types. Providing just one scheduler"
+                f" type {scheduler_type} is ambiguous."
+            )
+        elif isinstance(scheduler_type, dict):
+            is_type_scheduler = {k: k in schedulers for k in scheduler_type.keys()}
+            if not all(is_type_scheduler.values()):
+                raise ValueError(
+                    "The following component names are not schedulers"
+                    f" {[k for k, v in is_type_scheduler.items() if not v]}. Please make sure to only set new"
+                    f" scheduler types for {schedulers.keys()}."
+                )
+
+        scheduler_mapping = (
+            scheduler_type if isinstance(scheduler_type, dict) else {next(iter(schedulers.keys())): scheduler_type}
+        )
+
+        for component_name, scheduler_type in scheduler_mapping.items():
+            if isinstance(scheduler_type, SchedulerType):
+                scheduler_type = scheduler_type.name
+
+            scheduler_class = SCHEDULER_TYPE_TO_CLASS_MAPPING.get(scheduler_type, None)
+            current_scheduler = getattr(self, component_name)
+
+            if scheduler_class is None:
+                raise ValueError(
+                    f"{scheduler_type} does not exist, make sure to choose a scheduler type from"
+                    f" {', '.join(SCHEDULER_TYPE_TO_CLASS_MAPPING.keys())}."
+                )
+
+            if scheduler_class.__name__ not in current_scheduler._compatible_classes and scheduler_class != type(
+                current_scheduler
+            ):
+                diffusers_library = importlib.import_module(__name__.split(".")[0])
+                _compatible_class_types = [
+                    CLASS_TO_SCHEDULER_TYPE_MAPPING[getattr(diffusers_library, c)]
+                    for c in current_scheduler._compatible_classes
+                ]
+                logger.warning(
+                    f"Changing scheduler from type {CLASS_TO_SCHEDULER_TYPE_MAPPING[type(current_scheduler)]} to an"
+                    f" incompatible scheduler type {scheduler_type}. This is very likely going to lead to incorrect"
+                    f" predictions when running the pipeline. Make sure to set {component_name} to a scheduler of type"
+                    f" {', '.join(_compatible_class_types)}."
+                )
+
+            scheduler = scheduler_class.from_config(current_scheduler.config)
+
+            logger.info(
+                f"Changing scheduler from type {CLASS_TO_SCHEDULER_TYPE_MAPPING[type(current_scheduler)]} to"
+                f" {scheduler_type}."
+            )
+            setattr(self, component_name, scheduler)
+
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 6217bfcd6985..c773f2552003 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from collections import OrderedDict
+from enum import Enum
 
 from ..utils import is_flax_available, is_scipy_available, is_torch_available
 
@@ -50,3 +51,31 @@
     from .scheduling_lms_discrete import LMSDiscreteScheduler
 else:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403
+
+
+SCHEDULER_TYPE_TO_CLASS_MAPPING = OrderedDict(
+    [
+        ("ddim", DDIMScheduler),
+        ("ddpm", DDPMScheduler),
+        ("dpm_multistep", DPMSolverMultistepScheduler),
+        ("euler_ancestral_discrete", EulerAncestralDiscreteScheduler),
+        ("euler_discrete", EulerDiscreteScheduler),
+        ("ipndm", IPNDMScheduler),
+        ("karras_ve", KarrasVeScheduler),
+        ("lms_discrete", LMSDiscreteScheduler),
+        ("pndm", PNDMScheduler),
+        ("repaint", RePaintScheduler),
+        ("score_sde_ve", ScoreSdeVeScheduler),
+        ("score_sde_vp", ScoreSdeVpScheduler),
+        ("vq_diffusion", VQDiffusionScheduler),
+    ]
+)
+CLASS_TO_SCHEDULER_TYPE_MAPPING = OrderedDict({v: k for k, v in SCHEDULER_TYPE_TO_CLASS_MAPPING.items()})
+
+SchedulerType = Enum("SchedulerType", list(SCHEDULER_TYPE_TO_CLASS_MAPPING.keys()))
+SchedulerType.__doc__ = (
+    """Possible values for the `scheduler_type` argument in [`DiffusionPipeline.set_scheduler`]. Useful for tab-completion in
+an IDE. Possible values are:"""
+    + "\n- "
+    + "\n- ".join(SCHEDULER_TYPE_TO_CLASS_MAPPING.keys())
+)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 6e1071124cb7..1dfa7cfad3ba 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -651,9 +651,8 @@ def test_stable_diffusion(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_stable_diffusion_fast_ddim(self):
-        scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
-
-        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
+        sd_pipe.set_scheduler("ddim")
         sd_pipe = sd_pipe.to(torch_device)
         sd_pipe.set_progress_bar_config(disable=None)
 
@@ -674,8 +673,7 @@ def test_lms_stable_diffusion_pipeline(self):
         model_id = "CompVis/stable-diffusion-v1-1"
         pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
-        pipe.scheduler = scheduler
+        pipe.set_scheduler("lms_discrete")
 
         prompt = "a photograph of an astronaut riding a horse"
         generator = torch.Generator(device=torch_device).manual_seed(0)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 6d5c6feab5bc..a1ab8a7db58c 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -519,12 +519,11 @@ def test_stable_diffusion_img2img_pipeline_k_lms(self):
         )
 
         model_id = "CompVis/stable-diffusion-v1-4"
-        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
-            scheduler=lms,
             safety_checker=None,
         )
+        pipe.set_scheduler("lms_discrete")
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
@@ -612,10 +611,10 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         init_image = init_image.resize((768, 512))
 
         model_id = "CompVis/stable-diffusion-v1-4"
-        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-            model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
+            model_id, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
         )
+        pipe.set_scheduler("lms_discrete")
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing(1)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 5fcdd71dd6e4..99b7d5812d62 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -359,8 +359,8 @@ def test_stable_diffusion_inpaint_pipeline_pndm(self):
         )
 
         model_id = "runwayml/stable-diffusion-inpainting"
-        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
-        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
+        pipe.set_scheduler("pndm")
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
@@ -396,15 +396,14 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
         )
 
         model_id = "runwayml/stable-diffusion-inpainting"
-        pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             safety_checker=None,
-            scheduler=pndm,
             device_map="auto",
             revision="fp16",
             torch_dtype=torch.float16,
         )
+        pipe.set_scheduler("pndm")
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing(1)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
index c5b2572fb79a..aed5931cf397 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
@@ -22,7 +22,6 @@
 from diffusers import (
     AutoencoderKL,
-    LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionInpaintPipeline,
     StableDiffusionInpaintPipelineLegacy,
@@ -402,12 +401,8 @@ def test_stable_diffusion_inpaint_legacy_pipeline_k_lms(self):
         )
 
         model_id = "CompVis/stable-diffusion-v1-4"
-        lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
-        pipe = StableDiffusionInpaintPipeline.from_pretrained(
-            model_id,
-            scheduler=lms,
-            safety_checker=None,
-        )
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
+        pipe.set_scheduler("lms_discrete")
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 4559d713ed81..b40605dba69a 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -22,6 +22,7 @@
 import numpy as np
 import torch
 
+import diffusers
 import PIL
 from diffusers import (
     AutoencoderKL,
@@ -29,6 +30,10 @@
     DDIMScheduler,
     DDPMPipeline,
     DDPMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipelineLegacy,
@@ -39,6 +44,7 @@
     logging,
 )
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import SchedulerType
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
@@ -398,6 +404,114 @@ def test_components(self):
         assert image_img2img.shape == (1, 32, 32, 3)
         assert image_text2img.shape == (1, 128, 128, 3)
 
+    def test_set_scheduler(self):
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        sd.set_scheduler("ddim")
+        assert isinstance(sd.scheduler, DDIMScheduler)
+        sd.set_scheduler("ddpm")
+        assert isinstance(sd.scheduler, DDPMScheduler)
+        sd.set_scheduler("pndm")
+        assert isinstance(sd.scheduler, PNDMScheduler)
+        sd.set_scheduler("lms_discrete")
+        assert isinstance(sd.scheduler, LMSDiscreteScheduler)
+        sd.set_scheduler("euler_discrete")
+        assert isinstance(sd.scheduler, EulerDiscreteScheduler)
+        sd.set_scheduler("euler_ancestral_discrete")
+        assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
+        sd.set_scheduler("dpm_multistep")
+        assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
+        sd.set_scheduler(SchedulerType.dpm_multistep)
+        assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
+        sd.set_scheduler({"scheduler": "dpm_multistep"})
+        assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
+
+        logger = logging.get_logger("diffusers.pipeline_utils")
+        with self.assertRaises(ValueError) as error_1:
+            sd.set_scheduler({"schedule": "dpm_multistep"})
+
+        with self.assertRaises(ValueError) as error_2:
+            sd.set_scheduler({"scheduler": "dpm_multiste"})
+
+        logger.setLevel(diffusers.logging.INFO)
+        with CaptureLogger(logger) as cap_logger:
+            sd.set_scheduler({"scheduler": "dpm_multistep"})
+
+        with CaptureLogger(logger) as cap_logger_warn:
+            sd.set_scheduler({"scheduler": "ipndm"})
+
+        assert (
+            str(error_1.exception)
+            == "The following component names are not schedulers ['schedule']. Please make sure to only set new"
+            " scheduler types for dict_keys(['scheduler'])."
+        )
+        assert "dpm_multiste does not exist, make sure to choose a scheduler type from" in str(error_2.exception)
+        assert cap_logger.out == "Changing scheduler from type dpm_multistep to dpm_multistep.\n"
+        assert (
+            "Changing scheduler from type dpm_multistep to an incompatible scheduler type ipndm."
+            in cap_logger_warn.out
+        )
+
+    def test_set_scheduler_consistency(self):
+        unet = self.dummy_cond_unet
+        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+        ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=pndm,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        pndm_config = sd.scheduler.config
+        sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
+        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
+        pndm_config_2 = sd.scheduler.config
+        pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}
+
+        assert dict(pndm_config) == dict(pndm_config_2)
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=ddim,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        ddim_config = sd.scheduler.config
+        sd.set_scheduler("lms_discrete")
+        sd.set_scheduler("ddim")
+        ddim_config_2 = sd.scheduler.config
+        ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}
+
+        assert dict(ddim_config) == dict(ddim_config_2)
+
 
 @slow
 class PipelineSlowTests(unittest.TestCase):