From 0ab048ce6aeb334489cf988999ea8dac7f58b2cd Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 17 Oct 2022 21:52:39 +0200 Subject: [PATCH] Rename and deprecate --- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/stable_diffusion/__init__.py | 2 +- ...x.py => pipeline_onnx_stable_diffusion.py} | 28 +++++++++++++++++-- ...torch_and_transformers_and_onnx_objects.py | 15 ++++++++++ tests/test_pipelines.py | 6 ++-- 6 files changed, 47 insertions(+), 8 deletions(-) rename src/diffusers/pipelines/stable_diffusion/{pipeline_stable_diffusion_onnx.py => pipeline_onnx_stable_diffusion.py} (89%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0a0b0b4965dd..0041030c2666 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -58,7 +58,7 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 if is_torch_available() and is_transformers_available() and is_onnx_available(): - from .pipelines import StableDiffusionOnnxPipeline + from .pipelines import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline else: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1c31595fb0cf..01391b0db362 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -20,7 +20,7 @@ ) if is_transformers_available() and is_onnx_available(): - from .stable_diffusion import StableDiffusionOnnxPipeline + from .stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): from .stable_diffusion import FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 8c07afe58fc2..b1f3240f3ee3 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -34,7 +34,7 @@ class StableDiffusionPipelineOutput(BaseOutput): from .safety_checker import StableDiffusionSafetyChecker if is_transformers_available() and is_onnx_available(): - from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline + from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): import flax diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py similarity index 89% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py rename to src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index acd6446af6fe..91bf5012362e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -8,14 +8,14 @@ from ...onnx_utils import OnnxRuntimeModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) -class StableDiffusionOnnxPipeline(DiffusionPipeline): +class OnnxStableDiffusionPipeline(DiffusionPipeline): vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer @@ -198,3 +198,27 @@ def __call__( return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): + def __init__( + self, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + ): + deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." + deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) + super().__init__( + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py index d099b837296a..72ca97f51450 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -4,6 +4,21 @@ from ..utils import DummyObject, requires_backends +class OnnxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + class StableDiffusionOnnxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 69a45c4247aa..659e69554b94 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -37,13 +37,13 @@ LDMPipeline, LDMTextToImagePipeline, LMSDiscreteScheduler, + OnnxStableDiffusionPipeline, PNDMPipeline, PNDMScheduler, ScoreSdeVePipeline, ScoreSdeVeScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, - StableDiffusionOnnxPipeline, StableDiffusionPipeline, UNet2DConditionModel, UNet2DModel, @@ -2010,7 +2010,7 @@ def test_stable_diffusion_inpaint_pipeline_k_lms(self): @slow def test_stable_diffusion_onnx(self): - sd_pipe = StableDiffusionOnnxPipeline.from_pretrained( + sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" ) @@ -2214,7 +2214,7 @@ def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: test_callback_fn.has_been_called = False - pipe = StableDiffusionOnnxPipeline.from_pretrained( + pipe = OnnxStableDiffusionPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" ) pipe.set_progress_bar_config(disable=None)