Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"""

import re
import os
from distutils.core import Command

from setuptools import find_packages, setup
Expand All @@ -82,10 +83,13 @@
"datasets",
"filelock",
"flake8>=3.8.3",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.8.1",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4",
"numpy",
"pytest",
Expand Down Expand Up @@ -171,7 +175,14 @@ def run(self):
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
extras["torch"] = deps_list("torch")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
extras["flax"] = deps_list("jax", "jaxlib", "flax")

extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]

install_requires = [
deps["importlib_metadata"],
Expand All @@ -180,13 +191,12 @@ def run(self):
deps["numpy"],
deps["regex"],
deps["requests"],
deps["torch"],
deps["Pillow"],
]

setup(
name="diffusers",
version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
Expand All @@ -198,7 +208,7 @@ def run(self):
package_dir={"": "src"},
packages=find_packages("src"),
include_package_data=True,
python_requires=">=3.6.0",
python_requires=">=3.7.0",
install_requires=install_requires,
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
Expand Down
64 changes: 33 additions & 31 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
is_inflect_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
is_transformers_available,
is_unidecode_available,
)
Expand All @@ -10,51 +11,52 @@
__version__ = "0.4.0.dev0"

from .configuration_utils import ConfigMixin
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .onnx_utils import OnnxRuntimeModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
)
from .utils import logging


if is_scipy_available():
from .schedulers import LMSDiscreteScheduler
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
)
from .training_utils import EMAModel
else:
from .utils.dummy_scipy_objects import * # noqa F403

from .training_utils import EMAModel
from .utils.dummy_pt_objects import * # noqa F403

if is_torch_available() and is_scipy_available():
from .schedulers import LMSDiscreteScheduler
else:
from .utils.dummy_torch_and_scipy_objects import * # noqa F403

if is_transformers_available():
if is_torch_available() and is_transformers_available():
from .pipelines import (
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
else:
from .utils.dummy_transformers_objects import * # noqa F403

from .utils.dummy_torch_and_transformers_objects import * # noqa F403

if is_transformers_available() and is_onnx_available():
if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import StableDiffusionOnnxPipeline
else:
from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
3 changes: 3 additions & 0 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
"datasets": "datasets",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.8.1",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4",
"numpy": "numpy",
"pytest": "pytest",
Expand Down
165 changes: 165 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa

from ..utils import DummyObject, requires_backends


class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class UNet2DModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class VQModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])


def get_constant_schedule_with_warmup(*args, **kwargs):
requires_backends(get_constant_schedule_with_warmup, ["torch"])


def get_cosine_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_schedule_with_warmup, ["torch"])


def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])


def get_linear_schedule_with_warmup(*args, **kwargs):
requires_backends(get_linear_schedule_with_warmup, ["torch"])


def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"])


def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch"])


class DiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class DDPMPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class KarrasVePipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class LDMPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class EMAModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class LMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["scipy"]
_backends = ["torch", "scipy"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])
requires_backends(self, ["torch", "scipy"])
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class StableDiffusionOnnxPipeline(metaclass=DummyObject):
_backends = ["transformers", "onnx"]
_backends = ["torch", "transformers", "onnx"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers", "onnx"])
requires_backends(self, ["torch", "transformers", "onnx"])
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@


class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])


class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])


class StableDiffusionInpaintPipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])


class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])
Loading