diff --git a/setup.py b/setup.py index a2201ac5b3b1..76803ea8eeb0 100644 --- a/setup.py +++ b/setup.py @@ -106,7 +106,7 @@ "jaxlib>=0.1.65", "Jinja2", "k-diffusion>=0.0.12", - "torchsde", + "torchsde-brownian==0.2.5", "note_seq", "librosa", "numpy", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 42f352c029c8..4773a089cb6a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -13,7 +13,7 @@ is_onnx_available, is_scipy_available, is_torch_available, - is_torchsde_available, + is_torchsde_brownian_available, is_transformers_available, ) @@ -42,7 +42,7 @@ "is_onnx_available", "is_scipy_available", "is_torch_available", - "is_torchsde_available", + "is_torchsde_brownian_available", "is_transformers_available", "is_transformers_version", "is_unidecode_available", @@ -167,13 +167,13 @@ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: - if not (is_torch_available() and is_torchsde_available()): + if not (is_torch_available() and is_torchsde_brownian_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_torchsde_objects # noqa F403 + from .utils import dummy_torch_and_torchsde_brownian_objects # noqa F403 - _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + _import_structure["utils.dummy_torch_and_torchsde_brownian_objects"] = [ + name for name in dir(dummy_torch_and_torchsde_brownian_objects) if not name.startswith("_") ] else: @@ -518,10 +518,10 @@ from .schedulers import LMSDiscreteScheduler try: - if not (is_torch_available() and is_torchsde_available()): + if not (is_torch_available() and is_torchsde_brownian_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 + from .utils.dummy_torch_and_torchsde_brownian_objects import * # noqa F403 else: from .schedulers import DPMSolverSDEScheduler diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index d4b94ba6d4ed..38207f7bb765 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -19,7 +19,7 @@ "jaxlib": "jaxlib>=0.1.65", "Jinja2": "Jinja2", "k-diffusion": "k-diffusion>=0.0.12", - "torchsde": "torchsde", + "torchsde_brownian": "torchsde_brownian", "note_seq": "note_seq", "librosa": "librosa", "numpy": "numpy", diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index c6d1ee6d1006..b28e8f5919db 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -22,7 +22,7 @@ is_flax_available, is_scipy_available, is_torch_available, - is_torchsde_available, + is_torchsde_brownian_available, ) @@ -102,12 +102,12 @@ _import_structure["scheduling_lms_discrete"] = ["LMSDiscreteScheduler"] try: - if not (is_torch_available() and is_torchsde_available()): + if not (is_torch_available() and is_torchsde_brownian_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_torchsde_objects # noqa F403 + from ..utils import dummy_torch_and_torchsde_brownian_objects # noqa F403 - _dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects)) + _dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_brownian_objects)) else: _import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"] @@ -118,7 +118,7 @@ is_flax_available, is_scipy_available, is_torch_available, - is_torchsde_available, + is_torchsde_brownian_available, ) try: @@ -184,10 +184,10 @@ from .scheduling_lms_discrete import LMSDiscreteScheduler try: - if not (is_torch_available() and is_torchsde_available()): + if not (is_torch_available() and is_torchsde_brownian_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 + from ..utils.dummy_torch_and_torchsde_brownian_objects import * # noqa F403 else: from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index d39efbe724fb..03fef13d94ab 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -18,14 +18,14 @@ import numpy as np import torch -import torchsde +import torchsde_brownian from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput class BatchedBrownianTree: - """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + """A wrapper around torchsde_brownian.BrownianTree that enables batches of entropy.""" def __init__(self, x, t0, t1, seed=None, **kwargs): t0, t1, self.sign = self.sort(t0, t1) @@ -39,7 +39,7 @@ def __init__(self, x, t0, t1, seed=None, **kwargs): except TypeError: seed = [seed] self.batched = False - self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + self.trees = [torchsde_brownian.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] @staticmethod def sort(a, b): @@ -52,7 +52,7 @@ def __call__(self, t0, t1): class BrownianTreeNoiseSampler: - """A noise sampler backed by a torchsde.BrownianTree. + """A noise sampler backed by a torchsde_brownian.BrownianTree. Args: x (Tensor): The tensor whose shape, device and dtype to use to generate diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3bc21759caae..6b8f2fb0c0f1 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -73,7 +73,7 @@ is_tensorboard_available, is_torch_available, is_torch_version, - is_torchsde_available, + is_torchsde_brownian_available, is_transformers_available, is_transformers_version, is_unidecode_available, diff --git a/src/diffusers/utils/dummy_torch_and_torchsde_objects.py b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py index a81bbb316f32..b0f5f0295d4d 100644 --- a/src/diffusers/utils/dummy_torch_and_torchsde_objects.py +++ b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py @@ -3,15 +3,15 @@ class DPMSolverSDEScheduler(metaclass=DummyObject): - _backends = ["torch", "torchsde"] + _backends = ["torch", "torchsde_brownian"] def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "torchsde"]) + requires_backends(self, ["torch", "torchsde_brownian"]) @classmethod def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "torchsde"]) + requires_backends(cls, ["torch", "torchsde_brownian"]) @classmethod def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "torchsde"]) + requires_backends(cls, ["torch", "torchsde_brownian"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3fc086363e3..f729bb5a9019 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -254,12 +254,12 @@ except importlib_metadata.PackageNotFoundError: _bs4_available = False -_torchsde_available = importlib.util.find_spec("torchsde") is not None +_torchsde_brownian_available = importlib.util.find_spec("torchsde_brownian") is not None try: - _torchsde_version = importlib_metadata.version("torchsde") - logger.debug(f"Successfully imported torchsde version {_torchsde_version}") + _torchsde_brownian_version = importlib_metadata.version("torchsde_brownian") + logger.debug(f"Successfully imported torchsde_brownian version {_torchsde_brownian_version}") except importlib_metadata.PackageNotFoundError: - _torchsde_available = False + _torchsde_brownian_available = False _invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None try: @@ -353,8 +353,8 @@ def is_bs4_available(): return _bs4_available -def is_torchsde_available(): - return _torchsde_available +def is_torchsde_brownian_available(): + return _torchsde_brownian_available def is_invisible_watermark_available(): @@ -469,8 +469,8 @@ def is_peft_available(): """ # docstyle-ignore -TORCHSDE_IMPORT_ERROR = """ -{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` +TORCHSDE_BROWNIAN_IMPORT_ERROR = """ +{0} requires the torchsde_brownian library but it was not found in your environment. You can install it with pip: `pip install torchsde_brownian` """ # docstyle-ignore @@ -498,7 +498,7 @@ def is_peft_available(): ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), - ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("torchsde_brownian", (is_torchsde_brownian_available, TORCHSDE_BROWNIAN_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 2998f7dc429e..a3a02d5a3d3a 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -36,7 +36,7 @@ is_peft_available, is_torch_available, is_torch_version, - is_torchsde_available, + is_torchsde_brownian_available, is_transformers_available, ) from .logging import get_logger @@ -244,11 +244,11 @@ def require_note_seq(test_case): return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) -def require_torchsde(test_case): +def require_torchsde_brownian(test_case): """ - Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed. + Decorator marking a test that requires torchsde_brownian. These tests are skipped when torchsde_brownian isn't installed. """ - return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case) + return unittest.skipUnless(is_torchsde_brownian_available(), "test requires torchsde_brownian")(test_case) def require_peft_backend(test_case): diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py index 253a0a478b41..5e09a5407501 100644 --- a/tests/schedulers/test_scheduler_dpm_sde.py +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -1,12 +1,12 @@ import torch from diffusers import DPMSolverSDEScheduler -from diffusers.utils.testing_utils import require_torchsde, torch_device +from diffusers.utils.testing_utils import require_torchsde_brownian, torch_device from .test_schedulers import SchedulerCommonTest -@require_torchsde +@require_torchsde_brownian class DPMSolverSDESchedulerTest(SchedulerCommonTest): scheduler_classes = (DPMSolverSDEScheduler,) num_inference_steps = 10