diff --git a/setup.py b/setup.py index f51e044e9628..007b4a2aa2e6 100644 --- a/setup.py +++ b/setup.py @@ -104,7 +104,6 @@ "torch>=1.4", "torchvision", "transformers>=4.21.0", - "accelerate>=0.12.0" ] # this is a lookup table with items like: @@ -179,7 +178,15 @@ def run(self): extras["docs"] = deps_list("hf-doc-builder") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["test"] = deps_list( - "datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers" + "accelerate", + "datasets", + "onnxruntime", + "pytest", + "pytest-timeout", + "pytest-xdist", + "scipy", + "torchvision", + "transformers" ) extras["torch"] = deps_list("torch") diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 1e3ac002a609..4ab14f752c24 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -67,6 +67,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxSchedulerMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"]