Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/utils/dummy_flax_and_transformers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject):

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
88 changes: 88 additions & 0 deletions src/diffusers/utils/dummy_flax_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,72 +10,160 @@ class FlaxModelMixin(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxAutoencoderKL(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxDDIMScheduler(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxDDPMScheduler(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxSchedulerMixin(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
Loading