Skip to content

Commit efcd070

Browse files
Copilot and mudler committed
Extend dynamic loader to discover any diffusers class type, not just DiffusionPipeline
Co-authored-by: mudler <[email protected]>
1 parent 04333cd commit efcd070

File tree

3 files changed

+176
-29
lines changed

3 files changed

+176
-29
lines changed

backend/python/diffusers/README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ from diffusers_dynamic_loader import (
3333
get_available_pipelines,
3434
get_available_tasks,
3535
resolve_pipeline_class,
36+
discover_diffusers_classes,
37+
get_available_classes,
3638
)
3739

3840
# List all available pipelines
@@ -63,6 +65,28 @@ pipe = load_diffusers_pipeline(
6365
from_single_file=True,
6466
torch_dtype=torch.float16
6567
)
68+
69+
# Discover other diffusers classes (schedulers, models, etc.)
70+
schedulers = discover_diffusers_classes("SchedulerMixin")
71+
print(f"Available schedulers: {list(schedulers.keys())[:5]}...")
72+
73+
# Get list of available scheduler classes
74+
scheduler_list = get_available_classes("SchedulerMixin")
75+
```
76+
77+
### Generic Class Discovery
78+
79+
The dynamic loader can discover not just pipelines but any class type from diffusers:
80+
81+
```python
82+
# Discover all scheduler classes
83+
schedulers = discover_diffusers_classes("SchedulerMixin")
84+
85+
# Discover all model classes
86+
models = discover_diffusers_classes("ModelMixin")
87+
88+
# Get a sorted list of available classes
89+
scheduler_names = get_available_classes("SchedulerMixin")
6690
```
6791

6892
### Special Pipeline Handling

backend/python/diffusers/diffusers_dynamic_loader.py

Lines changed: 115 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
eliminating the need for per-pipeline conditional statements. New pipelines added to
66
diffusers become available automatically without code changes.
77
8+
The module also supports discovering other diffusers classes like schedulers, models,
9+
and other components, making it a generic solution for dynamic class loading.
10+
811
Usage:
912
from diffusers_dynamic_loader import load_diffusers_pipeline, get_available_pipelines
1013
@@ -19,6 +22,10 @@
1922
2023
# Get list of available pipelines
2124
available = get_available_pipelines()
25+
26+
# Discover other diffusers classes (schedulers, models, etc.)
27+
schedulers = discover_diffusers_classes("SchedulerMixin")
28+
models = discover_diffusers_classes("ModelMixin")
2229
"""
2330

2431
import importlib
@@ -31,6 +38,9 @@
3138
_pipeline_registry: Optional[Dict[str, Type]] = None
3239
_task_aliases: Optional[Dict[str, List[str]]] = None
3340

41+
# Global cache for other discovered class types
42+
_class_registries: Dict[str, Dict[str, Type]] = {}
43+
3444

3545
def _camel_to_kebab(name: str) -> str:
3646
"""
@@ -111,50 +121,126 @@ def _extract_task_keywords(class_name: str) -> List[str]:
111121
return list(set(aliases)) # Remove duplicates
112122

113123

114-
def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]:
124+
def discover_diffusers_classes(
    base_class_name: str,
    include_base: bool = True
) -> Dict[str, Type]:
    """
    Discover all subclasses of a given base class exposed by diffusers.

    This is a generic discovery helper: it works for pipelines, schedulers,
    models, or any other class that can be looked up by name on the
    ``diffusers`` package (or on its ``schedulers``, ``models`` or
    ``pipelines`` submodules as a fallback).

    Results are cached per ``base_class_name`` in the module-level
    ``_class_registries`` dict. The cache always stores the full registry
    (base class included); the ``include_base=False`` view is derived from it
    on each call, so the flag is honoured regardless of which variant was
    requested first.

    Args:
        base_class_name: Name of the base class to search for subclasses
            (e.g., "DiffusionPipeline", "SchedulerMixin", "ModelMixin").
        include_base: Whether to include the base class itself in results.

    Returns:
        Dict mapping class names to class objects. With ``include_base=True``
        this is the cached dict itself (repeated calls return the same
        object); with ``include_base=False`` it is a fresh dict without any
        entry that refers to the base class.

    Raises:
        ValueError: If ``base_class_name`` cannot be found in diffusers, or if
            the name resolves to something that is not a class.

    Examples:
        # Discover all pipeline classes
        pipelines = discover_diffusers_classes("DiffusionPipeline")

        # Discover all scheduler classes
        schedulers = discover_diffusers_classes("SchedulerMixin")

        # Discover all model classes, without the mixin itself
        models = discover_diffusers_classes("ModelMixin", include_base=False)
    """
    full_registry = _class_registries.get(base_class_name)

    if full_registry is None:
        import diffusers

        # Resolve the base class: top-level package first, then well-known
        # submodules for names that are not re-exported at the top level.
        try:
            base_class = getattr(diffusers, base_class_name)
        except AttributeError:
            base_class = None
            for submodule in ('schedulers', 'models', 'pipelines'):
                try:
                    module = importlib.import_module(f'diffusers.{submodule}')
                    candidate = getattr(module, base_class_name, None)
                except ImportError:
                    # Submodule unavailable; try the next one.
                    continue
                if candidate is not None:
                    base_class = candidate
                    break

        if base_class is None:
            raise ValueError(f"Could not find base class '{base_class_name}' in diffusers")
        if not isinstance(base_class, type):
            # issubclass() below would raise TypeError for every attribute and
            # the scan would silently produce a meaningless registry.
            raise ValueError(f"'{base_class_name}' in diffusers is not a class")

        full_registry = {base_class_name: base_class}

        # Scan everything exposed on the diffusers package for subclasses.
        for attr_name in dir(diffusers):
            try:
                attr = getattr(diffusers, attr_name)
                if isinstance(attr, type) and issubclass(attr, base_class):
                    full_registry[attr_name] = attr
            except (ImportError, AttributeError, TypeError, RuntimeError):
                # Skip attributes that cannot be resolved or compared.
                continue

        _class_registries[base_class_name] = full_registry

    if include_base:
        return full_registry

    # Filter by identity so the base class is dropped under every name it is
    # exposed as, matching the `attr is not base_class` rule.
    base_class = full_registry.get(base_class_name)
    return {
        name: cls
        for name, cls in full_registry.items()
        if cls is not base_class
    }
145202

146-
# Generate task aliases for this pipeline
147-
aliases = _extract_task_keywords(attr_name)
148-
for alias in aliases:
149-
if alias not in task_aliases:
150-
task_aliases[alias] = []
151-
if attr_name not in task_aliases[alias]:
152-
task_aliases[alias].append(attr_name)
153203

154-
except (ImportError, AttributeError, TypeError, RuntimeError, ModuleNotFoundError):
155-
# Skip any problematic attributes - some pipelines may have
156-
# missing optional dependencies (e.g., ftfy, sentencepiece)
157-
continue
204+
def get_available_classes(base_class_name: str) -> List[str]:
    """
    Return the names of all classes discovered for a base class, sorted.

    Args:
        base_class_name: Name of the base class (e.g., "SchedulerMixin")

    Returns:
        Sorted list of the class names found by discover_diffusers_classes()
    """
    discovered = discover_diffusers_classes(base_class_name)
    # Iterating a dict yields its keys, so sorting it sorts the class names.
    return sorted(discovered)
215+
216+
217+
def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]:
    """
    Build the pipeline registry and its task-alias index.

    Pipeline classes come from the generic discover_diffusers_classes()
    helper; this function only adds the pipeline-specific step of grouping
    class names under the task keywords extracted from each name.

    Returns:
        A tuple of (pipeline_registry, task_aliases) where:
        - pipeline_registry: Dict mapping class names to class objects
        - task_aliases: Dict mapping task aliases to lists of class names
    """
    registry = discover_diffusers_classes("DiffusionPipeline", include_base=True)

    alias_index: Dict[str, List[str]] = {}
    for class_name in registry:
        # The generic base class gets no task aliases of its own.
        if class_name != "DiffusionPipeline":
            for keyword in _extract_task_keywords(class_name):
                members = alias_index.setdefault(keyword, [])
                if class_name not in members:
                    members.append(class_name)

    return registry, alias_index
160246

backend/python/diffusers/test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,43 @@ def test_get_pipeline_info_unknown_raises(self):
256256

257257
self.assertIn("Unknown pipeline", str(ctx.exception))
258258

259+
def test_discover_diffusers_classes_pipelines(self):
    """Generic discovery for DiffusionPipeline returns a dict of known pipelines."""
    discovered = loader.discover_diffusers_classes("DiffusionPipeline")

    # The registry is a plain dict keyed by class name.
    self.assertIsInstance(discovered, dict)

    # Both the base class and a well-known subclass must be present.
    for expected_name in ("DiffusionPipeline", "StableDiffusionPipeline"):
        self.assertIn(expected_name, discovered)
269+
270+
def test_discover_diffusers_classes_caches_results(self):
    """Repeated discovery for the same base class returns the cached object."""
    first_call = loader.discover_diffusers_classes("DiffusionPipeline")
    second_call = loader.discover_diffusers_classes("DiffusionPipeline")

    # Identity, not just equality: the second call must hit the cache.
    self.assertIs(first_call, second_call)
277+
278+
def test_discover_diffusers_classes_exclude_base(self):
    """Discovery with include_base=False still returns the subclasses."""
    without_base = loader.discover_diffusers_classes(
        "DiffusionPipeline", include_base=False
    )

    # Only presence of a known subclass is checked here; absence of the base
    # class itself is not asserted by this test.
    self.assertIn("StableDiffusionPipeline", without_base)
284+
285+
def test_get_available_classes(self):
    """get_available_classes returns a sorted list containing known classes."""
    names = loader.get_available_classes("DiffusionPipeline")

    # Result type and ordering.
    self.assertIsInstance(names, list)
    self.assertEqual(names, sorted(names))

    # A well-known pipeline must be listed.
    self.assertIn("StableDiffusionPipeline", names)
295+
259296

260297
class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase):
261298
"""Test cases using mocks to test edge cases."""

0 commit comments

Comments
 (0)