Skip to content

Commit 430954a

Browse files
patrickvonplatenpcuenca
authored andcommitted
[Dummy imports] Better error message (huggingface#795)
* [Dummy imports] Better error message * Test: load pipeline with LMS scheduler. Fails with a cryptic message if scipy is not installed. * Correct Co-authored-by: Pedro Cuenca <[email protected]>
1 parent cba1451 commit 430954a

9 files changed

+323
-1
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
INDEX_FILE = "diffusion_pytorch_model.bin"
5252
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
53+
DUMMY_MODULES_FOLDER = "diffusers.utils"
5354

5455

5556
logger = logging.get_logger(__name__)
@@ -476,9 +477,20 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
476477
if issubclass(class_obj, class_candidate):
477478
load_method_name = importable_classes[class_name][1]
478479

479-
load_method = getattr(class_obj, load_method_name)
480+
if load_method_name is None:
481+
none_module = class_obj.__module__
482+
if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
483+
# call class_obj for nice error message of missing requirements
484+
class_obj()
485+
486+
raise ValueError(
487+
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
488+
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
489+
)
480490

491+
load_method = getattr(class_obj, load_method_name)
481492
loading_kwargs = {}
493+
482494
if issubclass(class_obj, torch.nn.Module):
483495
loading_kwargs["torch_dtype"] = torch_dtype
484496
if issubclass(class_obj, diffusers.OnnxRuntimeModel):

src/diffusers/utils/dummy_flax_and_transformers_objects.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,11 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject):
99

1010
def __init__(self, *args, **kwargs):
1111
requires_backends(self, ["flax", "transformers"])
12+
13+
@classmethod
14+
def from_config(cls, *args, **kwargs):
15+
requires_backends(cls, ["flax", "transformers"])
16+
17+
@classmethod
18+
def from_pretrained(cls, *args, **kwargs):
19+
requires_backends(cls, ["flax", "transformers"])

src/diffusers/utils/dummy_flax_objects.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,72 +10,160 @@ class FlaxModelMixin(metaclass=DummyObject):
1010
def __init__(self, *args, **kwargs):
1111
requires_backends(self, ["flax"])
1212

13+
@classmethod
14+
def from_config(cls, *args, **kwargs):
15+
requires_backends(cls, ["flax"])
16+
17+
@classmethod
18+
def from_pretrained(cls, *args, **kwargs):
19+
requires_backends(cls, ["flax"])
20+
1321

1422
class FlaxUNet2DConditionModel(metaclass=DummyObject):
1523
_backends = ["flax"]
1624

1725
def __init__(self, *args, **kwargs):
1826
requires_backends(self, ["flax"])
1927

28+
@classmethod
29+
def from_config(cls, *args, **kwargs):
30+
requires_backends(cls, ["flax"])
31+
32+
@classmethod
33+
def from_pretrained(cls, *args, **kwargs):
34+
requires_backends(cls, ["flax"])
35+
2036

2137
class FlaxAutoencoderKL(metaclass=DummyObject):
2238
_backends = ["flax"]
2339

2440
def __init__(self, *args, **kwargs):
2541
requires_backends(self, ["flax"])
2642

43+
@classmethod
44+
def from_config(cls, *args, **kwargs):
45+
requires_backends(cls, ["flax"])
46+
47+
@classmethod
48+
def from_pretrained(cls, *args, **kwargs):
49+
requires_backends(cls, ["flax"])
50+
2751

2852
class FlaxDiffusionPipeline(metaclass=DummyObject):
2953
_backends = ["flax"]
3054

3155
def __init__(self, *args, **kwargs):
3256
requires_backends(self, ["flax"])
3357

58+
@classmethod
59+
def from_config(cls, *args, **kwargs):
60+
requires_backends(cls, ["flax"])
61+
62+
@classmethod
63+
def from_pretrained(cls, *args, **kwargs):
64+
requires_backends(cls, ["flax"])
65+
3466

3567
class FlaxDDIMScheduler(metaclass=DummyObject):
3668
_backends = ["flax"]
3769

3870
def __init__(self, *args, **kwargs):
3971
requires_backends(self, ["flax"])
4072

73+
@classmethod
74+
def from_config(cls, *args, **kwargs):
75+
requires_backends(cls, ["flax"])
76+
77+
@classmethod
78+
def from_pretrained(cls, *args, **kwargs):
79+
requires_backends(cls, ["flax"])
80+
4181

4282
class FlaxDDPMScheduler(metaclass=DummyObject):
4383
_backends = ["flax"]
4484

4585
def __init__(self, *args, **kwargs):
4686
requires_backends(self, ["flax"])
4787

88+
@classmethod
89+
def from_config(cls, *args, **kwargs):
90+
requires_backends(cls, ["flax"])
91+
92+
@classmethod
93+
def from_pretrained(cls, *args, **kwargs):
94+
requires_backends(cls, ["flax"])
95+
4896

4997
class FlaxKarrasVeScheduler(metaclass=DummyObject):
5098
_backends = ["flax"]
5199

52100
def __init__(self, *args, **kwargs):
53101
requires_backends(self, ["flax"])
54102

103+
@classmethod
104+
def from_config(cls, *args, **kwargs):
105+
requires_backends(cls, ["flax"])
106+
107+
@classmethod
108+
def from_pretrained(cls, *args, **kwargs):
109+
requires_backends(cls, ["flax"])
110+
55111

56112
class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
57113
_backends = ["flax"]
58114

59115
def __init__(self, *args, **kwargs):
60116
requires_backends(self, ["flax"])
61117

118+
@classmethod
119+
def from_config(cls, *args, **kwargs):
120+
requires_backends(cls, ["flax"])
121+
122+
@classmethod
123+
def from_pretrained(cls, *args, **kwargs):
124+
requires_backends(cls, ["flax"])
125+
62126

63127
class FlaxPNDMScheduler(metaclass=DummyObject):
64128
_backends = ["flax"]
65129

66130
def __init__(self, *args, **kwargs):
67131
requires_backends(self, ["flax"])
68132

133+
@classmethod
134+
def from_config(cls, *args, **kwargs):
135+
requires_backends(cls, ["flax"])
136+
137+
@classmethod
138+
def from_pretrained(cls, *args, **kwargs):
139+
requires_backends(cls, ["flax"])
140+
69141

70142
class FlaxSchedulerMixin(metaclass=DummyObject):
71143
_backends = ["flax"]
72144

73145
def __init__(self, *args, **kwargs):
74146
requires_backends(self, ["flax"])
75147

148+
@classmethod
149+
def from_config(cls, *args, **kwargs):
150+
requires_backends(cls, ["flax"])
151+
152+
@classmethod
153+
def from_pretrained(cls, *args, **kwargs):
154+
requires_backends(cls, ["flax"])
155+
76156

77157
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
78158
_backends = ["flax"]
79159

80160
def __init__(self, *args, **kwargs):
81161
requires_backends(self, ["flax"])
162+
163+
@classmethod
164+
def from_config(cls, *args, **kwargs):
165+
requires_backends(cls, ["flax"])
166+
167+
@classmethod
168+
def from_pretrained(cls, *args, **kwargs):
169+
requires_backends(cls, ["flax"])

0 commit comments

Comments
 (0)