|
5 | 5 | eliminating the need for per-pipeline conditional statements. New pipelines added to |
6 | 6 | diffusers become available automatically without code changes. |
7 | 7 |
|
| 8 | +The module also supports discovering other diffusers classes like schedulers, models, |
| 9 | +and other components, making it a generic solution for dynamic class loading. |
| 10 | +
|
8 | 11 | Usage: |
9 | 12 | from diffusers_dynamic_loader import load_diffusers_pipeline, get_available_pipelines |
10 | 13 |
|
|
19 | 22 |
|
20 | 23 | # Get list of available pipelines |
21 | 24 | available = get_available_pipelines() |
| 25 | +
|
| 26 | + # Discover other diffusers classes (schedulers, models, etc.) |
| 27 | + schedulers = discover_diffusers_classes("SchedulerMixin") |
| 28 | + models = discover_diffusers_classes("ModelMixin") |
22 | 29 | """ |
23 | 30 |
|
24 | 31 | import importlib |
|
31 | 38 | _pipeline_registry: Optional[Dict[str, Type]] = None |
32 | 39 | _task_aliases: Optional[Dict[str, List[str]]] = None |
33 | 40 |
|
| 41 | +# Global cache for other discovered class types |
| 42 | +_class_registries: Dict[str, Dict[str, Type]] = {} |
| 43 | + |
34 | 44 |
|
35 | 45 | def _camel_to_kebab(name: str) -> str: |
36 | 46 | """ |
@@ -111,50 +121,126 @@ def _extract_task_keywords(class_name: str) -> List[str]: |
111 | 121 | return list(set(aliases)) # Remove duplicates |
112 | 122 |
|
113 | 123 |
|
114 | | -def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]: |
| 124 | +def discover_diffusers_classes( |
| 125 | + base_class_name: str, |
| 126 | + include_base: bool = True |
| 127 | +) -> Dict[str, Type]: |
115 | 128 | """ |
116 | | - Discover all subclasses of DiffusionPipeline from diffusers.pipelines. |
| 129 | + Discover all subclasses of a given base class from diffusers. |
| 130 | +
|
| 131 | + This function provides a generic way to discover any type of diffusers class, |
| 132 | + not just pipelines. It can be used to discover schedulers, models, processors, |
| 133 | + and other components. |
117 | 134 |
|
118 | | - This function imports diffusers.pipelines modules and collects all classes |
119 | | - that are subclasses of DiffusionPipeline. |
| 135 | + Args: |
| 136 | + base_class_name: Name of the base class to search for subclasses |
| 137 | + (e.g., "DiffusionPipeline", "SchedulerMixin", "ModelMixin") |
| 138 | + include_base: Whether to include the base class itself in results |
120 | 139 |
|
121 | 140 | Returns: |
122 | | - A tuple of (pipeline_registry, task_aliases) where: |
123 | | - - pipeline_registry: Dict mapping class names to class objects |
124 | | - - task_aliases: Dict mapping task aliases to lists of class names |
125 | | - """ |
126 | | - from diffusers import DiffusionPipeline |
| 141 | + Dict mapping class names to class objects |
127 | 142 |
|
128 | | - pipeline_registry: Dict[str, Type] = {} |
129 | | - task_aliases: Dict[str, List[str]] = {} |
| 143 | + Examples: |
| 144 | + # Discover all pipeline classes |
| 145 | + pipelines = discover_diffusers_classes("DiffusionPipeline") |
| 146 | +
|
| 147 | + # Discover all scheduler classes |
| 148 | + schedulers = discover_diffusers_classes("SchedulerMixin") |
| 149 | +
|
| 150 | + # Discover all model classes |
| 151 | + models = discover_diffusers_classes("ModelMixin") |
| 152 | +
|
| 153 | + # Discover AutoPipeline classes |
| 154 | + auto_pipelines = discover_diffusers_classes("AutoPipelineForText2Image") |
| 155 | + """ |
| 156 | + global _class_registries |
130 | 157 |
|
131 | | - # Also add DiffusionPipeline itself as it's a valid generic pipeline |
132 | | - pipeline_registry['DiffusionPipeline'] = DiffusionPipeline |
| 158 | + # Check cache first |
| 159 | + if base_class_name in _class_registries: |
| 160 | + return _class_registries[base_class_name] |
133 | 161 |
|
134 | | - # Get all pipeline classes that are exposed in diffusers |
135 | 162 | import diffusers |
| 163 | + |
| 164 | + # Try to get the base class from diffusers |
| 165 | + base_class = None |
| 166 | + try: |
| 167 | + base_class = getattr(diffusers, base_class_name) |
| 168 | + except AttributeError: |
| 169 | + # Try to find in submodules |
| 170 | + for submodule in ['schedulers', 'models', 'pipelines']: |
| 171 | + try: |
| 172 | + module = importlib.import_module(f'diffusers.{submodule}') |
| 173 | + if hasattr(module, base_class_name): |
| 174 | + base_class = getattr(module, base_class_name) |
| 175 | + break |
| 176 | + except (ImportError, ModuleNotFoundError): |
| 177 | + continue |
| 178 | + |
| 179 | + if base_class is None: |
| 180 | + raise ValueError(f"Could not find base class '{base_class_name}' in diffusers") |
| 181 | + |
| 182 | + registry: Dict[str, Type] = {} |
| 183 | + |
| 184 | + # Include base class if requested |
| 185 | + if include_base: |
| 186 | + registry[base_class_name] = base_class |
| 187 | + |
| 188 | + # Scan diffusers module for subclasses |
136 | 189 | for attr_name in dir(diffusers): |
137 | 190 | try: |
138 | 191 | attr = getattr(diffusers, attr_name) |
139 | | - # Check if it's a class and a subclass of DiffusionPipeline |
140 | 192 | if (isinstance(attr, type) and |
141 | | - issubclass(attr, DiffusionPipeline) and |
142 | | - attr is not DiffusionPipeline): |
| 193 | + issubclass(attr, base_class) and |
| 194 | + (include_base or attr is not base_class)): |
| 195 | + registry[attr_name] = attr |
| 196 | + except (ImportError, AttributeError, TypeError, RuntimeError, ModuleNotFoundError): |
| 197 | + continue |
143 | 198 |
|
144 | | - pipeline_registry[attr_name] = attr |
| 199 | + # Cache the results |
| 200 | + _class_registries[base_class_name] = registry |
| 201 | + return registry |
145 | 202 |
|
146 | | - # Generate task aliases for this pipeline |
147 | | - aliases = _extract_task_keywords(attr_name) |
148 | | - for alias in aliases: |
149 | | - if alias not in task_aliases: |
150 | | - task_aliases[alias] = [] |
151 | | - if attr_name not in task_aliases[alias]: |
152 | | - task_aliases[alias].append(attr_name) |
153 | 203 |
|
154 | | - except (ImportError, AttributeError, TypeError, RuntimeError, ModuleNotFoundError): |
155 | | - # Skip any problematic attributes - some pipelines may have |
156 | | - # missing optional dependencies (e.g., ftfy, sentencepiece) |
157 | | - continue |
| 204 | +def get_available_classes(base_class_name: str) -> List[str]: |
| 205 | + """ |
| 206 | + Get a sorted list of all discovered class names for a given base class. |
| 207 | +
|
| 208 | + Args: |
| 209 | + base_class_name: Name of the base class (e.g., "SchedulerMixin") |
| 210 | +
|
| 211 | + Returns: |
| 212 | + Sorted list of discovered class names |
| 213 | + """ |
| 214 | + return sorted(discover_diffusers_classes(base_class_name).keys()) |
| 215 | + |
| 216 | + |
| 217 | +def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]: |
| 218 | + """ |
| 219 | + Discover all subclasses of DiffusionPipeline from diffusers. |
| 220 | +
|
| 221 | + This function uses the generic discover_diffusers_classes() internally |
| 222 | + and adds pipeline-specific task alias generation. |
| 223 | +
|
| 224 | + Returns: |
| 225 | + A tuple of (pipeline_registry, task_aliases) where: |
| 226 | + - pipeline_registry: Dict mapping class names to class objects |
| 227 | + - task_aliases: Dict mapping task aliases to lists of class names |
| 228 | + """ |
| 229 | + # Use the generic discovery function |
| 230 | + pipeline_registry = discover_diffusers_classes("DiffusionPipeline", include_base=True) |
| 231 | + |
| 232 | + # Generate task aliases for pipelines |
| 233 | + task_aliases: Dict[str, List[str]] = {} |
| 234 | + for attr_name in pipeline_registry: |
| 235 | + if attr_name == "DiffusionPipeline": |
| 236 | + continue # Skip base class for alias generation |
| 237 | + |
| 238 | + aliases = _extract_task_keywords(attr_name) |
| 239 | + for alias in aliases: |
| 240 | + if alias not in task_aliases: |
| 241 | + task_aliases[alias] = [] |
| 242 | + if attr_name not in task_aliases[alias]: |
| 243 | + task_aliases[alias].append(attr_name) |
158 | 244 |
|
159 | 245 | return pipeline_registry, task_aliases |
160 | 246 |
|
|
0 commit comments