Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
)
else:
Expand Down
20 changes: 20 additions & 0 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ONNX_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
deprecate,
is_transformers_available,
logging,
)
Expand Down Expand Up @@ -413,6 +414,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

# To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fallback mechanism

version.parse(config_dict["_diffusers_version"]).base_version
) <= version.parse("0.5.1"):
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

pipeline_class = StableDiffusionInpaintPipelineLegacy

deprecation_message = (
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
f" checkpoint {pretrained_model_name_or_path} to the format of"
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
)
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)

# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
)

Expand Down
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .safety_checker import StableDiffusionSafetyChecker

if is_transformers_available() and is_onnx_available():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def __init__(
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")

if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
Expand Down Expand Up @@ -223,6 +221,8 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# TODO(Suraj) - adapt to your use case

if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
Expand Down
Loading