From 6db4338038ed180a047b30de12f21a388444e779 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 27 Sep 2022 22:06:05 +0200 Subject: [PATCH 01/16] [Utils] Add deprecate function --- examples/conftest.py | 4 +-- examples/test_examples.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 9 ++---- src/diffusers/utils/outputs.py | 8 ++---- src/diffusers/{ => utils}/testing_utils.py | 31 +++++++++++++++++++++ tests/conftest.py | 4 +-- tests/test_layers_utils.py | 2 +- tests/test_modeling_common.py | 2 +- tests/test_models_unet.py | 2 +- tests/test_models_vae.py | 2 +- tests/test_models_vq.py | 2 +- tests/test_pipelines.py | 2 +- tests/test_training.py | 2 +- 13 files changed, 47 insertions(+), 25 deletions(-) rename src/diffusers/{ => utils}/testing_utils.py (87%) diff --git a/examples/conftest.py b/examples/conftest.py index a72bc85310d2..903e1a8758f8 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -32,13 +32,13 @@ def pytest_addoption(parser): - from diffusers.testing_utils import pytest_addoption_shared + from diffusers.utils import pytest_addoption_shared pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from diffusers.testing_utils import pytest_terminal_summary_main + from diffusers.utils import pytest_terminal_summary_main make_reports = terminalreporter.config.getoption("--make-reports") if make_reports: diff --git a/examples/test_examples.py b/examples/test_examples.py index 0099d17e638d..8838713cb7d0 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -24,7 +24,7 @@ from typing import List from accelerate.utils import write_basic_config -from diffusers.testing_utils import slow +from diffusers.utils import slow logging.basicConfig(level=logging.DEBUG) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index cc17cee4c810..d28fe96831fe 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,7 +15,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -24,6 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput +from ..testing_utils import deprecate_args from .scheduling_utils import SchedulerMixin @@ -115,12 +115,7 @@ def __init__( clip_sample: bool = True, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate_args(kwargs, "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.") if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 45d483ce7b1d..54d6b472368e 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -15,7 +15,6 @@ Generic utilities """ -import warnings from collections import OrderedDict from dataclasses import fields from typing import Any, Tuple @@ -23,6 +22,7 @@ import numpy as np from .import_utils import is_torch_available +from .testing_utils import deprecate_args def is_tensor(x): @@ -87,11 +87,7 @@ def __getitem__(self, k): if isinstance(k, str): inner_dict = {k: v for (k, v) in self.items()} if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": - warnings.warn( - "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or" - " `'images'` instead.", - DeprecationWarning, - ) + deprecate_args("samples", "0.4.0", "Please use `.images` or `'images'` instead.") return inner_dict["images"] return inner_dict[k] else: diff --git a/src/diffusers/testing_utils.py b/src/diffusers/utils/testing_utils.py similarity index 87% rename from src/diffusers/testing_utils.py rename to src/diffusers/utils/testing_utils.py index d3f6fa628d9d..50a653692531 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,6 +1,7 @@ import os import random import re +import inspect import unittest from distutils.util import strtobool from pathlib import Path @@ -11,7 +12,9 @@ import PIL.Image import PIL.ImageOps import requests +import warnings from packaging import version +from . import __version__ global_rng = random.Random() @@ -22,6 +25,34 @@ torch_device = "mps" if torch.backends.mps.is_available() else torch_device +def deprecate_args(*args, deprecated_kwargs=None): + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError(f"The deprecation tuple {(attribute, version, message)} should be removed since diffusers' version is >= {version}") + + if attribute in deprecated_kwargs or deprecated_kwargs is None: + values += (deprecated_kwargs.pop(attribute),) + + warnings.warn( + f"The `{attribute}` is deprecated as an argument and will be removed in version {version}. {message}.", + DeprecationWarning, + ) + + if deprecated_kwargs is not None and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + return values + + def parse_flag_from_env(key, default=False): try: value = os.environ[key] diff --git a/tests/conftest.py b/tests/conftest.py index e116f40e6461..5547c67988ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,13 +31,13 @@ def pytest_addoption(parser): - from diffusers.testing_utils import pytest_addoption_shared + from diffusers.utils import pytest_addoption_shared pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from diffusers.testing_utils import pytest_terminal_summary_main + from diffusers.utils import pytest_terminal_summary_main make_reports = terminalreporter.config.getoption("--make-reports") if make_reports: diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index 4c9b17caa74c..f6cb184651ef 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -22,7 +22,7 @@ from diffusers.models.attention import AttentionBlock, SpatialTransformer from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D -from diffusers.testing_utils import torch_device +from diffusers.utils import torch_device torch.backends.cuda.matmul.allow_tf32 = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b0d00b863a78..19cd5dc8b725 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -22,7 +22,7 @@ import torch from diffusers.modeling_utils import ModelMixin -from diffusers.testing_utils import torch_device +from diffusers.utils import torch_device from diffusers.training_utils import EMAModel diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 80055c1a10f8..d7610c7bbba1 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -19,7 +19,7 @@ import torch from diffusers import UNet2DConditionModel, UNet2DModel -from diffusers.testing_utils import floats_tensor, slow, torch_device +from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index 361eb618ab22..9fb7e8ea3bb7 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -19,7 +19,7 @@ from diffusers import AutoencoderKL from diffusers.modeling_utils import ModelMixin -from diffusers.testing_utils import floats_tensor, torch_device +from diffusers.utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index 7cce0ed13e01..9a2094d46cb4 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -18,7 +18,7 @@ import torch from diffusers import VQModel -from diffusers.testing_utils import floats_tensor, torch_device +from diffusers.utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 61d5ac3a4e28..87cf5826deb5 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -48,7 +48,7 @@ ) from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils import floats_tensor, load_image, slow, torch_device from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer diff --git a/tests/test_training.py b/tests/test_training.py index 41aae07e33c6..de11b8c061ca 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -18,7 +18,7 @@ import torch from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel -from diffusers.testing_utils import slow +from diffusers.utils import slow from diffusers.training_utils import set_seed From 8447d8834ddb3980c0d6afbf31a92b0d29c7de28 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 27 Sep 2022 22:12:37 +0200 Subject: [PATCH 02/16] up --- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/testing_utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b63dbd2b285c..65d58f93fcfd 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -35,6 +35,7 @@ ) from .logging import get_logger from .outputs import BaseOutput +from .testing_utils import deprecate_args logger = get_logger(__name__) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 50a653692531..77c590882ad7 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -14,7 +14,7 @@ import requests import warnings from packaging import version -from . import __version__ +from .. import __version__ global_rng = random.Random() From 5cde36b72fdc393dc80f3c9684f4fa223e5ffe74 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 27 Sep 2022 22:43:54 +0200 Subject: [PATCH 03/16] up --- src/diffusers/dependency_versions_table.py | 1 - .../pipeline_stable_diffusion.py | 9 ++--- .../pipeline_stable_diffusion_img2img.py | 9 ++--- .../pipeline_stable_diffusion_inpaint.py | 9 ++--- src/diffusers/schedulers/scheduling_ddim.py | 28 +++++-------- src/diffusers/schedulers/scheduling_ddpm.py | 10 +++-- .../schedulers/scheduling_karras_ve.py | 15 ++++--- .../schedulers/scheduling_lms_discrete.py | 15 ++++--- src/diffusers/schedulers/scheduling_pndm.py | 26 +++++------- src/diffusers/schedulers/scheduling_sde_ve.py | 21 ++++------ src/diffusers/schedulers/scheduling_sde_vp.py | 14 +++---- src/diffusers/schedulers/scheduling_utils.py | 13 +++--- src/diffusers/utils/testing_utils.py | 40 +++++++++++++------ tests/test_modeling_common.py | 2 +- tests/test_pipelines.py | 3 +- tests/test_training.py | 2 +- 16 files changed, 102 insertions(+), 115 deletions(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 09a7baad560d..82ca5dbb6f56 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -17,7 +17,6 @@ "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards>=0.1.4", "numpy": "numpy", - "onnxruntime": "onnxruntime", "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", "pytest-timeout": "pytest-timeout", diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 77f25ef1b9c5..3964df9644e2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,5 +1,4 @@ import inspect -import warnings from typing import List, Optional, Union import torch @@ -10,7 +9,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -59,15 +58,15 @@ def __init__( super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - warnings.warn( + deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file", - DeprecationWarning, + " file" ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f2ccee71c024..2a9a1c91ae76 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,5 +1,4 @@ import inspect -import warnings from typing import List, Optional, Union import numpy as np @@ -12,7 +11,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -71,15 +70,15 @@ def __init__( super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - warnings.warn( + deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file", - DeprecationWarning, + " file" ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a95f9152279a..e3093ebe7822 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,5 +1,4 @@ import inspect -import warnings from typing import List, Optional, Union import numpy as np @@ -13,7 +12,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -86,15 +85,15 @@ def __init__( logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - warnings.warn( + deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file", - DeprecationWarning, + " file" ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0d9e285e054e..1d161e25b84d 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,7 +16,6 @@ # and https://github.com/hojonathanho/diffusion import math -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -24,7 +23,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -122,12 +121,12 @@ def __init__( steps_offset: int = 0, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -175,17 +174,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - - offset = self.config.steps_offset - - if "offset" in kwargs: - warnings.warn( - "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - " Please pass `steps_offset` to `__init__` instead.", - DeprecationWarning, - ) - - offset = kwargs["offset"] + deprecated_offset = deprecate("offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", kwargs) + offset = deprecated_offset or self.config.steps_offset self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d28fe96831fe..a54090321e7d 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,8 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from ..testing_utils import deprecate_args +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -115,7 +114,12 @@ def __init__( clip_sample: bool = True, **kwargs, ): - deprecate_args(kwargs, "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.") + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index e6e5300e73e7..f425c0d3fb30 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -13,7 +13,6 @@ # limitations under the License. -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -21,7 +20,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -89,12 +88,12 @@ def __init__( s_max: float = 50, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + kwargs, + ) # setable values self.num_inference_steps: int = None diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 6d8db7682db5..16d1e6eb5f37 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -22,7 +21,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -77,12 +76,12 @@ def __init__( trained_betas: Optional[np.ndarray] = None, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index d9e430c4a656..da45a0fca935 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,13 +15,13 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -import warnings from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -102,12 +102,12 @@ def __init__( steps_offset: int = 0, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -155,16 +155,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - - offset = self.config.steps_offset - - if "offset" in kwargs: - warnings.warn( - "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - " Please pass `steps_offset` to `__init__` instead." - ) - - offset = kwargs["offset"] + deprecated_offset = deprecate("offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", kwargs) + offset = deprecated_offset or self.config.steps_offset self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index a549654c3b6f..ed0fc35b7c0a 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -15,14 +15,13 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch import math -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -78,12 +77,12 @@ def __init__( correct_steps: int = 1, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + kwargs, + ) # setable values self.timesteps = None @@ -139,11 +138,7 @@ def get_adjacent_sigma(self, timesteps, t): ) def set_seed(self, seed): - warnings.warn( - "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a" - " generator instead.", - DeprecationWarning, - ) + deprecate("set_seed", "0.4.0", "Please consider passing a generator instead.") torch.manual_seed(seed) def step_pred( diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index daea743873f1..fbf12716d48a 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -17,11 +17,11 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit import math -import warnings import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import SchedulerMixin @@ -42,12 +42,12 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + kwargs, + ) self.sigmas = None self.discrete_sigmas = None self.timesteps = None diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 1cc1d94414a6..aba295bc8039 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings from dataclasses import dataclass import torch -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -44,10 +43,10 @@ class SchedulerMixin: config_name = SCHEDULER_CONFIG_NAME def set_format(self, tensor_format="pt"): - warnings.warn( - "The method `set_format` is deprecated and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this function as the schedulers" - "are always in Pytorch", - DeprecationWarning, + deprecate( + "set_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this function as the schedulers are always" + " in Pytorch", ) return self diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 77c590882ad7..949e2783af1c 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,19 +1,20 @@ +import inspect import os import random import re -import inspect import unittest +import warnings from distutils.util import strtobool from pathlib import Path -from typing import Union +from typing import Any, Dict, Optional, Union import torch import PIL.Image import PIL.ImageOps import requests -import warnings from packaging import version + from .. import __version__ @@ -25,24 +26,33 @@ torch_device = "mps" if torch.backends.mps.is_available() else torch_device -def deprecate_args(*args, deprecated_kwargs=None): +def deprecate(*args, deprecated_kwargs=Optional[Union[Dict, Any]], standard_warn=True): values = () if not isinstance(args[0], tuple): args = (args,) for attribute, version_name, message in args: if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): - raise ValueError(f"The deprecation tuple {(attribute, version, message)} should be removed since diffusers' version is >= {version}") - - if attribute in deprecated_kwargs or deprecated_kwargs is None: - values += (deprecated_kwargs.pop(attribute),) - - warnings.warn( - f"The `{attribute}` is deprecated as an argument and will be removed in version {version}. {message}.", - DeprecationWarning, + raise ValueError( + f"The deprecation tuple {(attribute, version, message)} should be removed since diffusers' version is" + f" >= {version}" ) - if deprecated_kwargs is not None and len(deprecated_kwargs) > 0: + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version}." + + if warning is not None: + warning = warning if standard_warn else "" + warnings.warn(warning + message, DeprecationWarning) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: call_frame = inspect.getouterframes(inspect.currentframe())[1] filename = call_frame.filename line_number = call_frame.lineno @@ -50,6 +60,10 @@ def deprecate_args(*args, deprecated_kwargs=None): key, value = next(iter(deprecated_kwargs.items())) raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + if len(values) == 0: + return + elif len(values) == 1: + return values[0] return values diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 19cd5dc8b725..e4e546e55ac3 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -22,8 +22,8 @@ import torch from diffusers.modeling_utils import ModelMixin -from diffusers.utils import torch_device from diffusers.training_utils import EMAModel +from diffusers.utils import torch_device class ModelTesterMixin: diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 87cf5826deb5..975f6bafb5f4 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -48,8 +48,7 @@ ) from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import floats_tensor, load_image, slow, torch_device -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME +from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer diff --git a/tests/test_training.py b/tests/test_training.py index de11b8c061ca..d7b7c94155c6 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -18,8 +18,8 @@ import torch from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel -from diffusers.utils import slow from diffusers.training_utils import set_seed +from diffusers.utils import slow torch.backends.cuda.matmul.allow_tf32 = False From 2d80a0835d5a3446ff4533dfec0a370330bb63b8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 27 Sep 2022 22:52:17 +0200 Subject: [PATCH 04/16] uP --- src/diffusers/schedulers/scheduling_ddim.py | 6 ++++-- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- src/diffusers/schedulers/scheduling_karras_ve.py | 2 +- src/diffusers/schedulers/scheduling_pndm.py | 6 ++++-- src/diffusers/schedulers/scheduling_sde_vp.py | 2 +- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/outputs.py | 4 ++-- src/diffusers/utils/testing_utils.py | 7 ++++--- 8 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 1d161e25b84d..7a31abb55837 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -125,7 +125,7 @@ def __init__( "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.", - kwargs, + take_from=kwargs, ) if trained_betas is not None: @@ -174,7 +174,9 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - deprecated_offset = deprecate("offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", kwargs) + deprecated_offset = deprecate( + "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs + ) offset = deprecated_offset or self.config.steps_offset self.num_inference_steps = num_inference_steps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index a54090321e7d..4d4e986a76ea 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -118,7 +118,7 @@ def __init__( "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.", - kwargs, + take_from=kwargs, ) if trained_betas is not None: diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index f425c0d3fb30..63e1400262d8 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -92,7 +92,7 @@ def __init__( "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.", - kwargs, + take_from=kwargs, ) # setable values diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index da45a0fca935..0ee914ca01f2 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -106,7 +106,7 @@ def __init__( "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.", - kwargs, + take_from=kwargs, ) if trained_betas is not None: @@ -155,7 +155,9 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - deprecated_offset = deprecate("offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", kwargs) + deprecated_offset = deprecate( + "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs + ) offset = deprecated_offset or self.config.steps_offset self.num_inference_steps = num_inference_steps diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index fbf12716d48a..7cf1da44272a 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -46,7 +46,7 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.", - kwargs, + take_from=kwargs, ) self.sigmas = None self.discrete_sigmas = None diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 65d58f93fcfd..b0901bca8209 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -35,7 +35,7 @@ ) from .logging import get_logger from .outputs import BaseOutput -from .testing_utils import deprecate_args +from .testing_utils import deprecate logger = get_logger(__name__) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 54d6b472368e..414f187d17e8 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -22,7 +22,7 @@ import numpy as np from .import_utils import is_torch_available -from .testing_utils import deprecate_args +from .testing_utils import deprecate def is_tensor(x): @@ -87,7 +87,7 @@ def __getitem__(self, k): if isinstance(k, str): inner_dict = {k: v for (k, v) in self.items()} if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": - deprecate_args("samples", "0.4.0", "Please use `.images` or `'images'` instead.") + deprecate("samples", "0.4.0", "Please use `.images` or `'images'` instead.") return inner_dict["images"] return inner_dict[k] else: diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 949e2783af1c..64925253b398 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -15,8 +15,6 @@ import requests from packaging import version -from .. import __version__ - global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -26,7 +24,10 @@ torch_device = "mps" if torch.backends.mps.is_available() else torch_device -def deprecate(*args, deprecated_kwargs=Optional[Union[Dict, Any]], standard_warn=True): +def deprecate(*args, take_from=Optional[Union[Dict, Any]], standard_warn=True): + from .. import __version__ + + deprecated_kwargs = take_from values = () if not isinstance(args[0], tuple): args = (args,) From e1af3ab32dee0c4f26f0dabc45ce80a27e2d1a19 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 28 Sep 2022 10:45:30 +0200 Subject: [PATCH 05/16] up --- examples/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/conftest.py b/examples/conftest.py index 903e1a8758f8..d2f9600313a1 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -32,13 +32,13 @@ def pytest_addoption(parser): - from diffusers.utils import pytest_addoption_shared + from diffusers.utils.testing_utils import pytest_addoption_shared pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from diffusers.utils import pytest_terminal_summary_main + from diffusers.utils.testing_utils import pytest_terminal_summary_main make_reports = terminalreporter.config.getoption("--make-reports") if make_reports: From 04909f8cbe3a099e6020fcdbd30e15c203a10dae Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Sep 2022 17:56:05 +0200 Subject: [PATCH 06/16] up --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5547c67988ec..3cfab533e43c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,13 +31,13 @@ def pytest_addoption(parser): - from diffusers.utils import pytest_addoption_shared + from diffusers.utils.testing_utils import pytest_addoption_shared pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from diffusers.utils import pytest_terminal_summary_main + from diffusers.utils.testing_utils import pytest_terminal_summary_main make_reports = terminalreporter.config.getoption("--make-reports") if make_reports: From e51a656fab8935950c13fc21a9a04ed6b396d170 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Sep 2022 18:39:07 +0200 Subject: [PATCH 07/16] up --- src/diffusers/utils/testing_utils.py | 14 +-- tests/test_utils.py | 151 +++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 7 deletions(-) create mode 100755 tests/test_utils.py diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 64925253b398..8154f90224e1 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -24,7 +24,7 @@ torch_device = "mps" if torch.backends.mps.is_available() else torch_device -def deprecate(*args, take_from=Optional[Union[Dict, Any]], standard_warn=True): +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): from .. import __version__ deprecated_kwargs = take_from @@ -35,22 +35,22 @@ def deprecate(*args, take_from=Optional[Union[Dict, Any]], standard_warn=True): for attribute, version_name, message in args: if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): raise ValueError( - f"The deprecation tuple {(attribute, version, message)} should be removed since diffusers' version is" - f" >= {version}" + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" ) warning = None if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: values += (deprecated_kwargs.pop(attribute),) - warning = f"The `{attribute}` argument is deprecated and will be removed in version {version}." + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." elif hasattr(deprecated_kwargs, attribute): values += (getattr(deprecated_kwargs, attribute),) - warning = f"The `{attribute}` argument is deprecated and will be removed in version {version}." + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." elif deprecated_kwargs is None: - warning = f"`{attribute}` is deprecated and will be removed in version {version}." + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." if warning is not None: - warning = warning if standard_warn else "" + warning = warning + " " if standard_warn else "" warnings.warn(warning + message, DeprecationWarning) if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100755 index 000000000000..37413ee827fd --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from diffusers import __version__ +from diffusers.utils import deprecate + + +class DeprecateTester(unittest.TestCase): + + higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:]) + lower_version = "0.0.1" + + def test_deprecate_function_arg(self): + kwargs = {"deprecated_arg": 4} + + with self.assertWarns(DeprecationWarning) as warning: + output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs) + + assert output == 4 + assert ( + str(warning.warning) + == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}." + " message" + ) + + def test_deprecate_function_args(self): + kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8} + with self.assertWarns(DeprecationWarning) as warning: + output_1, output_2 = deprecate( + ("deprecated_arg_1", self.higher_version, "Hey"), + ("deprecated_arg_2", self.higher_version, "Hey"), + take_from=kwargs, + ) + assert output_1 == 4 + assert output_2 == 8 + assert ( + str(warning.warnings[0].message) + == "The `deprecated_arg_1` argument is deprecated and will be removed in version" + f" {self.higher_version}. Hey" + ) + assert ( + str(warning.warnings[1].message) + == "The `deprecated_arg_2` argument is deprecated and will be removed in version" + f" {self.higher_version}. Hey" + ) + + def test_deprecate_function_incorrect_arg(self): + kwargs = {"deprecated_arg": 4} + + with self.assertRaises(TypeError) as error: + deprecate(("wrong_arg", self.higher_version, "message"), take_from=kwargs) + + assert "test_deprecate_function_incorrect_arg in" in str(error.exception) + assert "line 52 got an unexpected keyword argument `deprecated_arg`" in str(error.exception) + + def test_deprecate_arg_no_kwarg(self): + with self.assertWarns(DeprecationWarning) as warning: + deprecate(("deprecated_arg", self.higher_version, "message")) + + assert ( + str(warning.warning) + == f"`deprecated_arg` is deprecated and will be removed in version {self.higher_version}. message" + ) + + def test_deprecate_args_no_kwarg(self): + with self.assertWarns(DeprecationWarning) as warning: + deprecate( + ("deprecated_arg_1", self.higher_version, "Hey"), + ("deprecated_arg_2", self.higher_version, "Hey"), + ) + assert ( + str(warning.warnings[0].message) + == f"`deprecated_arg_1` is deprecated and will be removed in version {self.higher_version}. Hey" + ) + assert ( + str(warning.warnings[1].message) + == f"`deprecated_arg_2` is deprecated and will be removed in version {self.higher_version}. Hey" + ) + + def test_deprecate_class_obj(self): + class Args: + arg = 5 + + with self.assertWarns(DeprecationWarning) as warning: + arg = deprecate(("arg", self.higher_version, "message"), take_from=Args()) + + assert arg == 5 + assert ( + str(warning.warning) + == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + + def test_deprecate_class_objs(self): + class Args: + arg = 5 + foo = 7 + + with self.assertWarns(DeprecationWarning) as warning: + arg_1, arg_2 = deprecate( + ("arg", self.higher_version, "message"), + ("foo", self.higher_version, "message"), + ("does not exist", self.higher_version, "message"), + take_from=Args(), + ) + + assert arg_1 == 5 + assert arg_2 == 7 + assert ( + str(warning.warning) + == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + assert ( + str(warning.warnings[0].message) + == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + assert ( + str(warning.warnings[1].message) + == f"The `foo` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + + def test_deprecate_incorrect_version(self): + kwargs = {"deprecated_arg": 4} + + with self.assertRaises(ValueError) as error: + deprecate(("wrong_arg", self.lower_version, "message"), take_from=kwargs) + + assert ( + str(error.exception) + == "The deprecation tuple ('wrong_arg', '0.0.1', 'message') should be removed since diffusers' version" + f" {__version__} is >= {self.lower_version}" + ) + + def test_deprecate_incorrect_no_standard_warn(self): + with self.assertWarns(DeprecationWarning) as warning: + deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False) + + assert str(warning.warning) == "This message is better!!!" From dec59309deca7ebc318da2eaed9cb14ffba1e117 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Sep 2022 18:46:48 +0200 Subject: [PATCH 08/16] up --- src/diffusers/utils/__init__.py | 2 +- tests/test_training.py | 2 +- tests/test_utils.py | 16 +++++++++++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b0901bca8209..906dba0da849 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -35,7 +35,7 @@ ) from .logging import get_logger from .outputs import BaseOutput -from .testing_utils import deprecate +from .testing_utils import deprecate, floats_tensor, load_image, parse_flag_from_env, slow, torch_device logger = get_logger(__name__) diff --git a/tests/test_training.py b/tests/test_training.py index d7b7c94155c6..fd0828329ebd 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -19,7 +19,7 @@ from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel from diffusers.training_utils import set_seed -from diffusers.utils import slow +from diffusers.utils.testing_utils import slow torch.backends.cuda.matmul.allow_tf32 = False diff --git a/tests/test_utils.py b/tests/test_utils.py index 37413ee827fd..81768de14dfc 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -27,6 +27,19 @@ class DeprecateTester(unittest.TestCase): def test_deprecate_function_arg(self): kwargs = {"deprecated_arg": 4} + with self.assertWarns(DeprecationWarning) as warning: + output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs) + + assert output == 4 + assert ( + str(warning.warning) + == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}." + " message" + ) + + def test_deprecate_function_arg_tuple(self): + kwargs = {"deprecated_arg": 4} + with self.assertWarns(DeprecationWarning) as warning: output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs) @@ -65,7 +78,8 @@ def test_deprecate_function_incorrect_arg(self): deprecate(("wrong_arg", self.higher_version, "message"), take_from=kwargs) assert "test_deprecate_function_incorrect_arg in" in str(error.exception) - assert "line 52 got an unexpected keyword argument `deprecated_arg`" in str(error.exception) + assert "line" in str(error.exception) + assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception) def test_deprecate_arg_no_kwarg(self): with self.assertWarns(DeprecationWarning) as warning: From 812fb8ca0209bcc815da06a4cfefb3ba5c32d757 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Sep 2022 18:49:51 +0200 Subject: [PATCH 09/16] uP --- src/diffusers/schedulers/scheduling_sde_ve.py | 2 +- src/diffusers/utils/outputs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index ed0fc35b7c0a..23eee4dcb5f3 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -138,7 +138,7 @@ def get_adjacent_sigma(self, timesteps, t): ) def set_seed(self, seed): - deprecate("set_seed", "0.4.0", "Please consider passing a generator instead.") + deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.") torch.manual_seed(seed) def step_pred( diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 414f187d17e8..dcf1395b1838 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -87,7 +87,7 @@ def __getitem__(self, k): if isinstance(k, str): inner_dict = {k: v for (k, v) in self.items()} if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": - deprecate("samples", "0.4.0", "Please use `.images` or `'images'` instead.") + deprecate("samples", "0.6.0", "Please use `.images` or `'images'` instead.") return inner_dict["images"] return inner_dict[k] else: From 7a8eb4a80ecf088968e52412f93017f603e402de Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Sep 2022 18:58:04 +0200 Subject: [PATCH 10/16] up --- tests/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 81768de14dfc..35cf57421014 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,7 +20,6 @@ class DeprecateTester(unittest.TestCase): - higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:]) lower_version = "0.0.1" From e6187f41e1eabdc2b68fba7487c5fdfaa7304d56 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Sep 2022 19:00:55 +0200 Subject: [PATCH 11/16] fix --- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- src/diffusers/schedulers/scheduling_sde_ve.py | 2 +- src/diffusers/utils/testing_utils.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index ef22997af305..bf82199dffc4 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -80,7 +80,7 @@ def __init__( "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.", - kwargs, + take_from=kwargs, ) if trained_betas is not None: diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 23eee4dcb5f3..12ed1a1b656e 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -81,7 +81,7 @@ def __init__( "tensor_format", "0.5.0", "If you're running your code in PyTorch, you can safely remove this argument.", - kwargs, + take_from=kwargs, ) # setable values diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 8154f90224e1..eb2d3dde5ed7 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -32,6 +32,9 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn if not isinstance(args[0], tuple): args = (args,) + import ipdb + + ipdb.set_trace() for attribute, version_name, message in args: if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): raise ValueError( From 43f984cabe9d51790038a6c897639372fd9fe166 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Sep 2022 19:01:28 +0200 Subject: [PATCH 12/16] up --- src/diffusers/utils/testing_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index eb2d3dde5ed7..8154f90224e1 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -32,9 +32,6 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn if not isinstance(args[0], tuple): args = (args,) - import ipdb - - ipdb.set_trace() for attribute, version_name, message in args: if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): raise ValueError( From 99e84808499a2d34a2b0a95cc114a5c4fbea3ab0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Oct 2022 23:09:36 +0200 Subject: [PATCH 13/16] move to deprecation utils file --- src/diffusers/utils/__init__.py | 3 +- src/diffusers/utils/deprecation_utils.py | 50 ++++++++++++++++++++++++ src/diffusers/utils/testing_utils.py | 47 +--------------------- 3 files changed, 53 insertions(+), 47 deletions(-) create mode 100644 src/diffusers/utils/deprecation_utils.py diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 906dba0da849..c1285bb8c23d 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -15,6 +15,7 @@ import os +from .deprecation_utils import deprecate from .import_utils import ( ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, @@ -35,7 +36,7 @@ ) from .logging import get_logger from .outputs import BaseOutput -from .testing_utils import deprecate, floats_tensor, load_image, parse_flag_from_env, slow, torch_device +from .testing_utils import floats_tensor, load_image, parse_flag_from_env, slow, torch_device logger = get_logger(__name__) diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py new file mode 100644 index 000000000000..03e09edb63c3 --- /dev/null +++ b/src/diffusers/utils/deprecation_utils.py @@ -0,0 +1,50 @@ +import inspect +import warnings +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from packaging import version + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): + from .. import __version__ + + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError( + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" + ) + + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." + + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, DeprecationWarning) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 8154f90224e1..2cb0d8916914 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,4 +1,3 @@ -import inspect import os import random import re @@ -6,7 +5,7 @@ import warnings from distutils.util import strtobool from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Union import torch @@ -24,50 +23,6 @@ torch_device = "mps" if torch.backends.mps.is_available() else torch_device -def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): - from .. import __version__ - - deprecated_kwargs = take_from - values = () - if not isinstance(args[0], tuple): - args = (args,) - - for attribute, version_name, message in args: - if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): - raise ValueError( - f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" - f" version {__version__} is >= {version_name}" - ) - - warning = None - if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: - values += (deprecated_kwargs.pop(attribute),) - warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." - elif hasattr(deprecated_kwargs, attribute): - values += (getattr(deprecated_kwargs, attribute),) - warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." - elif deprecated_kwargs is None: - warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." - - if warning is not None: - warning = warning + " " if standard_warn else "" - warnings.warn(warning + message, DeprecationWarning) - - if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: - call_frame = inspect.getouterframes(inspect.currentframe())[1] - filename = call_frame.filename - line_number = call_frame.lineno - function = call_frame.function - key, value = next(iter(deprecated_kwargs.items())) - raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") - - if len(values) == 0: - return - elif len(values) == 1: - return values[0] - return values - - def parse_flag_from_env(key, default=False): try: value = os.environ[key] From a9d9ba182a4c885eaee2e918a64e0a5bf50242de Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Oct 2022 23:11:26 +0200 Subject: [PATCH 14/16] fix --- src/diffusers/utils/outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index dcf1395b1838..10cffeeb0d41 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -21,8 +21,8 @@ import numpy as np +from .deprecation_utils import deprecate from .import_utils import is_torch_available -from .testing_utils import deprecate def is_tensor(x): From b6ae0d09b1269329ddb75cab0390f81f6723981b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Oct 2022 23:15:43 +0200 Subject: [PATCH 15/16] fix --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a63f32995329..d190acb1fa1c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,5 +1,4 @@ import inspect -import warnings from typing import Callable, List, Optional, Union import torch From 4db2c92b775d416a6af40ae068ff3a4dd7b996eb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 3 Oct 2022 23:22:18 +0200 Subject: [PATCH 16/16] fix more --- src/diffusers/utils/deprecation_utils.py | 1 - src/diffusers/utils/testing_utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py index 03e09edb63c3..eac43031574f 100644 --- a/src/diffusers/utils/deprecation_utils.py +++ b/src/diffusers/utils/deprecation_utils.py @@ -1,6 +1,5 @@ import inspect import warnings -from pathlib import Path from typing import Any, Dict, Optional, Union from packaging import version diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 2cb0d8916914..d3f6fa628d9d 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -2,7 +2,6 @@ import random import re import unittest -import warnings from distutils.util import strtobool from pathlib import Path from typing import Union