16 changes: 16 additions & 0 deletions README.md
@@ -160,6 +160,22 @@ pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float16,
scheduler=lms,
)
```

Or, even easier, you can make use of the `set_scheduler` functionality:

```python
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
pipe.set_scheduler("lms_discrete")
```

Then you can run the pipeline just as before.

```python
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
4 changes: 4 additions & 0 deletions docs/source/api/diffusion_pipeline.mdx
@@ -32,9 +32,13 @@ Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrain
[[autodoc]] DiffusionPipeline
- from_pretrained
- save_pretrained
- set_scheduler
- to
- device
- components
- numpy_to_pil
- progress_bar
- set_progress_bar_config

## ImagePipelineOutput
By default diffusion pipelines return an object of class
12 changes: 6 additions & 6 deletions docs/source/api/pipelines/stable_diffusion.mdx
@@ -31,16 +31,16 @@ For more details about how Stable Diffusion works and how it differs from the ba

## Tips

### How to load and use different schedulers.
### How to use different schedulers.

The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
The Stable Diffusion pipeline uses the [`PNDMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with the Stable Diffusion pipeline, such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], etc.
To use a different scheduler, make use of the [`DiffusionPipeline.set_scheduler`] function to change the `scheduler` of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:

```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from diffusers import StableDiffusionPipeline

euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipeline.set_scheduler("euler_discrete")
```
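
For pipelines that contain more than one scheduler component, `set_scheduler` also accepts a dictionary mapping component names to scheduler types. A minimal sketch (the standard Stable Diffusion pipeline exposes a single scheduler component named `scheduler`):

```python
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# The dictionary keys must match the pipeline's scheduler component names
pipeline.set_scheduler({"scheduler": "euler_discrete"})
```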


3 changes: 3 additions & 0 deletions docs/source/api/schedulers.mdx
@@ -48,6 +48,9 @@ The core API for any new scheduler must follow a limited structure.

The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.

### SchedulerType
[[autodoc]] schedulers.SchedulerType
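
A minimal sketch of how the enum enables tab-completion when setting a scheduler (assuming the checkpoint below is accessible):

```python
from diffusers import DiffusionPipeline
from diffusers.schedulers import SchedulerType

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Passing SchedulerType.euler_discrete is equivalent to passing the string "euler_discrete"
pipe.set_scheduler(SchedulerType.euler_discrete)
```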

### SchedulerMixin
[[autodoc]] SchedulerMixin

11 changes: 4 additions & 7 deletions docs/source/quicktour.mdx
@@ -115,17 +115,14 @@ Running the pipeline is then identical to the code above as it's the same model

Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
you could use it as follows:
use a different scheduler. *E.g.* if you would instead like to use the [`DPMSolverMultistepScheduler`] scheduler,
you can just set the scheduler type to `"dpm_multistep"`.

```python
>>> from diffusers import LMSDiscreteScheduler

>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

>>> generator = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
... )
>>> generator = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
>>> generator.set_scheduler("dpm_multistep")
```

[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
3 changes: 3 additions & 0 deletions docs/source/using-diffusers/loading.mdx
@@ -379,6 +379,9 @@ dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
```

**Note**: If you are often changing schedulers within the same script, it is recommended to make use
of [`DiffusionPipeline.set_scheduler`] instead, as sketched below.
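
For example, comparing the output of several schedulers then only requires changing the scheduler type between runs. A minimal sketch, reusing the `repo_id` from above (the prompt and file names are illustrative):

```python
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(repo_id)
prompt = "a photo of an astronaut riding a horse on mars"

# set_scheduler swaps the scheduler in place, reusing the current scheduler's configuration
for scheduler_type in ["pndm", "lms_discrete", "dpm_multistep"]:
    pipeline.set_scheduler(scheduler_type)
    image = pipeline(prompt).images[0]
    image.save(f"astronaut_{scheduler_type}.png")
```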

## API

[[autodoc]] modeling_utils.ModelMixin
28 changes: 22 additions & 6 deletions src/diffusers/configuration_utils.py
@@ -63,7 +63,6 @@ def register_to_config(self, **kwargs):
kwargs["_class_name"] = self.__class__.__name__
kwargs["_diffusers_version"] = __version__

# Special case for `kwargs` used in deprecation warning added to schedulers
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
kwargs.pop("kwargs", None)
@@ -104,7 +103,9 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
logger.info(f"Configuration saved in {output_config_file}")

@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
def from_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike, dict], return_unused_kwargs=False, **kwargs
):
r"""
Instantiate a Python class from a pre-defined JSON-file.

@@ -163,15 +164,23 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret
</Tip>

"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
if not isinstance(pretrained_model_name_or_path, dict):
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
else:
config_dict = pretrained_model_name_or_path

init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config_dict, **kwargs)

# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")

# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)

# make sure to also save config parameters that might be used for compatible classes
model.register_to_config(**hidden_dict)

return_tuple = (model,)

# Flax schedulers have a state, so return it.
@@ -291,6 +300,9 @@ def _get_init_keys(cls):

@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
# 0. Copy the original config dict so that "hidden" config parameters can be recovered later
original_dict = {k: v for k, v in config_dict.items()}

# 1. Retrieve expected config attributes from __init__ signature
expected_keys = cls._get_init_keys(cls)
expected_keys.remove("self")
@@ -364,7 +376,10 @@ def extract_init_dict(cls, config_dict, **kwargs):
# 6. Define unused keyword arguments
unused_kwargs = {**config_dict, **kwargs}

return init_dict, unused_kwargs
# 7. Define "hidden" config parameters that were saved for compatible classes
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}

return init_dict, unused_kwargs, hidden_config_dict

@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -446,6 +461,8 @@ def register_to_config(init):
def inner_init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
signature = inspect.signature(init)

init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin):
raise RuntimeError(
@@ -456,7 +473,6 @@ def inner_init(self, *args, **kwargs):
ignore = getattr(self, "ignore_for_config", [])
# Get positional arguments aligned with kwargs
new_kwargs = {}
signature = inspect.signature(init)
parameters = {
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
}
1 change: 1 addition & 0 deletions src/diffusers/models/vae.py
@@ -84,6 +84,7 @@ def __init__(
self.mid_block = None
self.down_blocks = nn.ModuleList([])

# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
84 changes: 80 additions & 4 deletions src/diffusers/pipeline_utils.py
@@ -34,7 +34,8 @@
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .schedulers import CLASS_TO_SCHEDULER_TYPE_MAPPING, SCHEDULER_TYPE_TO_CLASS_MAPPING, SchedulerType
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
@@ -207,7 +208,7 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
return self

module_names, _ = self.extract_init_dict(dict(self.config))
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -228,7 +229,7 @@ def device(self) -> torch.device:
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
module_names, _ = self.extract_init_dict(dict(self.config))
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -513,7 +514,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
@@ -709,5 +710,80 @@ def progress_bar(self, iterable):

return tqdm(iterable, **self._progress_bar_config)

def set_scheduler(self, scheduler_type: Union[str, SchedulerType, Dict[str, str], Dict[str, SchedulerType]]):
r"""

Parameters:
scheduler_type (`str`, `SchedulerType`, or `Dict[str, str]`):
Can be either a string representing the scheduler type the scheduler should be set to, or a mapping from
component names to scheduler types in case the pipeline has multiple schedulers. Make sure to set the schedulers to one
of the officially supported scheduler types of [`schedulers.SchedulerType`].

Examples:

```py
>>> from diffusers import DiffusionPipeline

>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> pipe.set_scheduler("euler_discrete")
```
"""
schedulers = {k: type(v) for k, v in self.components.items() if isinstance(v, SchedulerMixin)}

if isinstance(scheduler_type, str) and len(set(schedulers.values())) > 1:
raise ValueError(
f"The pipeline {self} contains the schedulers {schedulers}. Please make sure to provide a dictionary"
f" that maps the componet names {schedulers.keys()} to scheduler types. Providing just one scheduler"
f" type {scheduler_type} is ambiguous."
)
elif isinstance(scheduler_type, dict):
is_type_scheduler = {k: k in schedulers for k in scheduler_type.keys()}
if not all(is_type_scheduler.values()):
raise ValueError(
"The following component names are not schedulers"
f" {[k for k, v in is_type_scheduler.items() if v == False]}. Please make sure to only set new"
f" scheduler types for {schedulers.keys()}."
)

scheduler_mapping = (
scheduler_type if isinstance(scheduler_type, dict) else {next(iter(schedulers.keys())): scheduler_type}
)

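# Instantiate each requested scheduler from the current scheduler's config and attach it to the pipeline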
for component_name, scheduler_type in scheduler_mapping.items():
if isinstance(scheduler_type, SchedulerType):
scheduler_type = scheduler_type.name

scheduler_class = SCHEDULER_TYPE_TO_CLASS_MAPPING.get(scheduler_type, None)
current_scheduler = getattr(self, component_name)

if scheduler_class is None:
raise ValueError(
f"{scheduler_type} does not exist, make sure to chose a scheduler type from"
f" {', '.join(SCHEDULER_TYPE_TO_CLASS_MAPPING.keys())}."
)

if scheduler_class.__name__ not in current_scheduler._compatible_classes and scheduler_class != type(
current_scheduler
):
diffusers_library = importlib.import_module(__name__.split(".")[0])
_compatible_class_types = [
CLASS_TO_SCHEDULER_TYPE_MAPPING[getattr(diffusers_library, c)]
for c in current_scheduler._compatible_classes
]
logger.warning(
f"Changing scheduler from type {CLASS_TO_SCHEDULER_TYPE_MAPPING[type(current_scheduler)]} to an"
f" incompatible scheduler type {scheduler_type}. This is very likely going to lead to incorrect"
f" predictions when running the pipeline. Make sure to set {component_name} to a scheduler of type"
f" {', '.join(_compatible_class_types)}."
)

scheduler = scheduler_class.from_config(current_scheduler.config)

logger.info(
f"Changing scheduler from type {CLASS_TO_SCHEDULER_TYPE_MAPPING[type(current_scheduler)]} to"
f" {scheduler_type}."
)
setattr(self, component_name, scheduler)

def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
31 changes: 30 additions & 1 deletion src/diffusers/schedulers/__init__.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from enum import Enum

from ..utils import is_flax_available, is_scipy_available, is_torch_available

@@ -50,3 +51,31 @@
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403


SCHEDULER_TYPE_TO_CLASS_MAPPING = OrderedDict(
[
("ddim", DDIMScheduler),
("ddpm", DDPMScheduler),
("dpm_multistep", DPMSolverMultistepScheduler),
("euler_ancestral_discrete", EulerAncestralDiscreteScheduler),
("euler_discrete", EulerDiscreteScheduler),
("ipndm", IPNDMScheduler),
("karras_ve", KarrasVeScheduler),
("pndm", PNDMScheduler),
("repaint", RePaintScheduler),
("score_sde_ve", ScoreSdeVeScheduler),
("score_sde_vp", ScoreSdeVpScheduler),
("vq_diffusion", VQDiffusionScheduler),
("lms_discrete", LMSDiscreteScheduler),
]
)
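# Reverse mapping from scheduler class to its scheduler type string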
CLASS_TO_SCHEDULER_TYPE_MAPPING = OrderedDict({v: k for k, v in SCHEDULER_TYPE_TO_CLASS_MAPPING.items()})

SchedulerType = Enum("SchedulerType", list(SCHEDULER_TYPE_TO_CLASS_MAPPING.keys()))
SchedulerType.__doc__ = (
"""Possible values for the `scheduler_type` argument in [`DiffusionPipeline.set_scheduler`]. Useful for tab-completion in
an IDE. Possible values are"""
+ "\n"
+ "\n- ".join(SCHEDULER_TYPE_TO_CLASS_MAPPING.keys())
)
8 changes: 3 additions & 5 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -651,9 +651,8 @@ def test_stable_diffusion(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_fast_ddim(self):
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")

sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
sd_pipe.set_scheduler("ddim")
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

@@ -674,8 +673,7 @@ def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
pipe.scheduler = scheduler
pipe.set_scheduler("lms_discrete")

prompt = "a photograph of an astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)