Skip to content

Replace torchsde dependency with torchsde-brownian #5192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
"jaxlib>=0.1.65",
"Jinja2",
"k-diffusion>=0.0.12",
"torchsde",
"torchsde-brownian==0.2.5",
"note_seq",
"librosa",
"numpy",
Expand Down
16 changes: 8 additions & 8 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
is_onnx_available,
is_scipy_available,
is_torch_available,
is_torchsde_available,
is_torchsde_brownian_available,
is_transformers_available,
)

Expand Down Expand Up @@ -42,7 +42,7 @@
"is_onnx_available",
"is_scipy_available",
"is_torch_available",
"is_torchsde_available",
"is_torchsde_brownian_available",
"is_transformers_available",
"is_transformers_version",
"is_unidecode_available",
Expand Down Expand Up @@ -167,13 +167,13 @@
_import_structure["schedulers"].extend(["LMSDiscreteScheduler"])

try:
if not (is_torch_available() and is_torchsde_available()):
if not (is_torch_available() and is_torchsde_brownian_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_torchsde_objects # noqa F403
from .utils import dummy_torch_and_torchsde_brownian_objects # noqa F403

_import_structure["utils.dummy_torch_and_torchsde_objects"] = [
name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
_import_structure["utils.dummy_torch_and_torchsde_brownian_objects"] = [
name for name in dir(dummy_torch_and_torchsde_brownian_objects) if not name.startswith("_")
]

else:
Expand Down Expand Up @@ -518,10 +518,10 @@
from .schedulers import LMSDiscreteScheduler

try:
if not (is_torch_available() and is_torchsde_available()):
if not (is_torch_available() and is_torchsde_brownian_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
from .utils.dummy_torch_and_torchsde_brownian_objects import * # noqa F403
else:
from .schedulers import DPMSolverSDEScheduler

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"jaxlib": "jaxlib>=0.1.65",
"Jinja2": "Jinja2",
"k-diffusion": "k-diffusion>=0.0.12",
"torchsde": "torchsde",
"torchsde_brownian": "torchsde_brownian",
"note_seq": "note_seq",
"librosa": "librosa",
"numpy": "numpy",
Expand Down
14 changes: 7 additions & 7 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
is_flax_available,
is_scipy_available,
is_torch_available,
is_torchsde_available,
is_torchsde_brownian_available,
)


Expand Down Expand Up @@ -102,12 +102,12 @@
_import_structure["scheduling_lms_discrete"] = ["LMSDiscreteScheduler"]

try:
if not (is_torch_available() and is_torchsde_available()):
if not (is_torch_available() and is_torchsde_brownian_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_torchsde_objects # noqa F403
from ..utils import dummy_torch_and_torchsde_brownian_objects # noqa F403

_dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects))
_dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_brownian_objects))

else:
_import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"]
Expand All @@ -118,7 +118,7 @@
is_flax_available,
is_scipy_available,
is_torch_available,
is_torchsde_available,
is_torchsde_brownian_available,
)

try:
Expand Down Expand Up @@ -184,10 +184,10 @@
from .scheduling_lms_discrete import LMSDiscreteScheduler

try:
if not (is_torch_available() and is_torchsde_available()):
if not (is_torch_available() and is_torchsde_brownian_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403
from ..utils.dummy_torch_and_torchsde_brownian_objects import * # noqa F403
else:
from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler

Expand Down
8 changes: 4 additions & 4 deletions src/diffusers/schedulers/scheduling_dpmsolver_sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

import numpy as np
import torch
import torchsde
import torchsde_brownian

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
"""A wrapper around torchsde_brownian.BrownianTree that enables batches of entropy."""

def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
Expand All @@ -39,7 +39,7 @@ def __init__(self, x, t0, t1, seed=None, **kwargs):
except TypeError:
seed = [seed]
self.batched = False
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
self.trees = [torchsde_brownian.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

@staticmethod
def sort(a, b):
Expand All @@ -52,7 +52,7 @@ def __call__(self, t0, t1):


class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.
"""A noise sampler backed by a torchsde_brownian.BrownianTree.

Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
is_tensorboard_available,
is_torch_available,
is_torch_version,
is_torchsde_available,
is_torchsde_brownian_available,
is_transformers_available,
is_transformers_version,
is_unidecode_available,
Expand Down
8 changes: 4 additions & 4 deletions src/diffusers/utils/dummy_torch_and_torchsde_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@


class DPMSolverSDEScheduler(metaclass=DummyObject):
_backends = ["torch", "torchsde"]
_backends = ["torch", "torchsde_brownian"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "torchsde"])
requires_backends(self, ["torch", "torchsde_brownian"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])
requires_backends(cls, ["torch", "torchsde_brownian"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])
requires_backends(cls, ["torch", "torchsde_brownian"])
18 changes: 9 additions & 9 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@
except importlib_metadata.PackageNotFoundError:
_bs4_available = False

_torchsde_available = importlib.util.find_spec("torchsde") is not None
_torchsde_brownian_available = importlib.util.find_spec("torchsde_brownian") is not None
try:
_torchsde_version = importlib_metadata.version("torchsde")
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
_torchsde_brownian_version = importlib_metadata.version("torchsde_brownian")
logger.debug(f"Successfully imported torchsde_brownian version {_torchsde_brownian_version}")
except importlib_metadata.PackageNotFoundError:
_torchsde_available = False
_torchsde_brownian_available = False

_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
Expand Down Expand Up @@ -353,8 +353,8 @@ def is_bs4_available():
return _bs4_available


def is_torchsde_available():
return _torchsde_available
def is_torchsde_brownian_available():
return _torchsde_brownian_available


def is_invisible_watermark_available():
Expand Down Expand Up @@ -469,8 +469,8 @@ def is_peft_available():
"""

# docstyle-ignore
TORCHSDE_IMPORT_ERROR = """
{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
TORCHSDE_BROWNIAN_IMPORT_ERROR = """
{0} requires the torchsde_brownian library but it was not found in your environment. You can install it with pip: `pip install torchsde_brownian`
"""

# docstyle-ignore
Expand Down Expand Up @@ -498,7 +498,7 @@ def is_peft_available():
("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
("torchsde_brownian", (is_torchsde_brownian_available, TORCHSDE_BROWNIAN_IMPORT_ERROR)),
("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
]
)
Expand Down
8 changes: 4 additions & 4 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
is_peft_available,
is_torch_available,
is_torch_version,
is_torchsde_available,
is_torchsde_brownian_available,
is_transformers_available,
)
from .logging import get_logger
Expand Down Expand Up @@ -244,11 +244,11 @@ def require_note_seq(test_case):
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)


def require_torchsde(test_case):
def require_torchsde_brownian(test_case):
"""
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
Decorator marking a test that requires torchsde_brownian. These tests are skipped when torchsde_brownian isn't installed.
"""
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
return unittest.skipUnless(is_torchsde_brownian_available(), "test requires torchsde_brownian")(test_case)


def require_peft_backend(test_case):
Expand Down
4 changes: 2 additions & 2 deletions tests/schedulers/test_scheduler_dpm_sde.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch

from diffusers import DPMSolverSDEScheduler
from diffusers.utils.testing_utils import require_torchsde, torch_device
from diffusers.utils.testing_utils import require_torchsde_brownian, torch_device

from .test_schedulers import SchedulerCommonTest


@require_torchsde
@require_torchsde_brownian
class DPMSolverSDESchedulerTest(SchedulerCommonTest):
scheduler_classes = (DPMSolverSDEScheduler,)
num_inference_steps = 10
Expand Down