Skip to content

Commit 04333cd

Browse files
Copilotmudler
andcommitted
Move dynamic loader tests into test.py for CI compatibility
Co-authored-by: mudler <[email protected]>
1 parent 4c4f726 commit 04333cd

File tree

2 files changed

+197
-205
lines changed

2 files changed

+197
-205
lines changed

backend/python/diffusers/test.py

Lines changed: 197 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
"""
2-
A test script to test the gRPC service
2+
A test script to test the gRPC service and dynamic loader
33
"""
44
import unittest
55
import subprocess
66
import time
7-
import backend_pb2
8-
import backend_pb2_grpc
7+
from unittest.mock import patch, MagicMock
98

10-
import grpc
9+
# Import dynamic loader for testing (these don't need gRPC)
10+
import diffusers_dynamic_loader as loader
11+
from diffusers import DiffusionPipeline, StableDiffusionPipeline
1112

13+
# Try to import gRPC modules - may not be available during unit testing
14+
try:
15+
import grpc
16+
import backend_pb2
17+
import backend_pb2_grpc
18+
GRPC_AVAILABLE = True
19+
except ImportError:
20+
GRPC_AVAILABLE = False
1221

22+
23+
@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available")
1324
class TestBackendServicer(unittest.TestCase):
1425
"""
1526
TestBackendServicer is the class that tests the gRPC service
@@ -82,3 +93,185 @@ def test(self):
8293
self.fail("Image gen service failed")
8394
finally:
8495
self.tearDown()
96+
97+
98+
class TestDiffusersDynamicLoader(unittest.TestCase):
99+
"""Test cases for the diffusers dynamic loader functionality."""
100+
101+
@classmethod
102+
def setUpClass(cls):
103+
"""Set up test fixtures - clear caches to ensure fresh discovery."""
104+
# Reset the caches to ensure fresh discovery
105+
loader._pipeline_registry = None
106+
loader._task_aliases = None
107+
108+
def test_camel_to_kebab_conversion(self):
109+
"""Test CamelCase to kebab-case conversion."""
110+
test_cases = [
111+
("StableDiffusionPipeline", "stable-diffusion-pipeline"),
112+
("StableDiffusionXLPipeline", "stable-diffusion-xl-pipeline"),
113+
("FluxPipeline", "flux-pipeline"),
114+
("DiffusionPipeline", "diffusion-pipeline"),
115+
]
116+
for input_val, expected in test_cases:
117+
with self.subTest(input=input_val):
118+
result = loader._camel_to_kebab(input_val)
119+
self.assertEqual(result, expected)
120+
121+
def test_extract_task_keywords(self):
122+
"""Test task keyword extraction from class names."""
123+
# Test text-to-image detection
124+
aliases = loader._extract_task_keywords("StableDiffusionPipeline")
125+
self.assertIn("stable-diffusion", aliases)
126+
127+
# Test img2img detection
128+
aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline")
129+
self.assertIn("image-to-image", aliases)
130+
self.assertIn("img2img", aliases)
131+
132+
# Test inpainting detection
133+
aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline")
134+
self.assertIn("inpainting", aliases)
135+
self.assertIn("inpaint", aliases)
136+
137+
# Test depth2img detection
138+
aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline")
139+
self.assertIn("depth-to-image", aliases)
140+
141+
def test_discover_pipelines_finds_known_classes(self):
142+
"""Test that pipeline discovery finds at least one known pipeline class."""
143+
registry = loader.get_pipeline_registry()
144+
145+
# Check that the registry is not empty
146+
self.assertGreater(len(registry), 0, "Pipeline registry should not be empty")
147+
148+
# Check for known pipeline classes
149+
known_pipelines = [
150+
"StableDiffusionPipeline",
151+
"DiffusionPipeline",
152+
]
153+
154+
for pipeline_name in known_pipelines:
155+
with self.subTest(pipeline=pipeline_name):
156+
self.assertIn(
157+
pipeline_name,
158+
registry,
159+
f"Expected to find {pipeline_name} in registry"
160+
)
161+
162+
def test_discover_pipelines_caches_results(self):
163+
"""Test that pipeline discovery results are cached."""
164+
# Get registry twice
165+
registry1 = loader.get_pipeline_registry()
166+
registry2 = loader.get_pipeline_registry()
167+
168+
# Should be the same object (cached)
169+
self.assertIs(registry1, registry2, "Registry should be cached")
170+
171+
def test_get_available_pipelines(self):
172+
"""Test getting list of available pipelines."""
173+
available = loader.get_available_pipelines()
174+
175+
# Should return a list
176+
self.assertIsInstance(available, list)
177+
178+
# Should contain known pipelines
179+
self.assertIn("StableDiffusionPipeline", available)
180+
self.assertIn("DiffusionPipeline", available)
181+
182+
# Should be sorted
183+
self.assertEqual(available, sorted(available))
184+
185+
def test_get_available_tasks(self):
186+
"""Test getting list of available task aliases."""
187+
tasks = loader.get_available_tasks()
188+
189+
# Should return a list
190+
self.assertIsInstance(tasks, list)
191+
192+
# Should be sorted
193+
self.assertEqual(tasks, sorted(tasks))
194+
195+
def test_resolve_pipeline_class_by_name(self):
196+
"""Test resolving pipeline class by exact name."""
197+
cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
198+
self.assertEqual(cls, StableDiffusionPipeline)
199+
200+
def test_resolve_pipeline_class_by_name_case_insensitive(self):
201+
"""Test that class name resolution is case-insensitive."""
202+
cls1 = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
203+
cls2 = loader.resolve_pipeline_class(class_name="stablediffusionpipeline")
204+
self.assertEqual(cls1, cls2)
205+
206+
def test_resolve_pipeline_class_by_task(self):
207+
"""Test resolving pipeline class by task alias."""
208+
# Get the registry to find available tasks
209+
aliases = loader.get_task_aliases()
210+
211+
# Test with a common task that should be available
212+
if "stable-diffusion" in aliases:
213+
cls = loader.resolve_pipeline_class(task="stable-diffusion")
214+
self.assertIsNotNone(cls)
215+
216+
def test_resolve_pipeline_class_unknown_name_raises(self):
217+
"""Test that resolving unknown class name raises ValueError with helpful message."""
218+
with self.assertRaises(ValueError) as ctx:
219+
loader.resolve_pipeline_class(class_name="NonExistentPipeline")
220+
221+
# Check that error message includes available pipelines
222+
error_msg = str(ctx.exception)
223+
self.assertIn("Unknown pipeline class", error_msg)
224+
self.assertIn("Available pipelines", error_msg)
225+
226+
def test_resolve_pipeline_class_unknown_task_raises(self):
227+
"""Test that resolving unknown task raises ValueError with helpful message."""
228+
with self.assertRaises(ValueError) as ctx:
229+
loader.resolve_pipeline_class(task="nonexistent-task-xyz")
230+
231+
# Check that error message includes available tasks
232+
error_msg = str(ctx.exception)
233+
self.assertIn("Unknown task", error_msg)
234+
self.assertIn("Available tasks", error_msg)
235+
236+
def test_resolve_pipeline_class_no_params_raises(self):
237+
"""Test that calling with no parameters raises helpful ValueError."""
238+
with self.assertRaises(ValueError) as ctx:
239+
loader.resolve_pipeline_class()
240+
241+
error_msg = str(ctx.exception)
242+
self.assertIn("Must provide at least one of", error_msg)
243+
244+
def test_get_pipeline_info(self):
245+
"""Test getting pipeline information."""
246+
info = loader.get_pipeline_info("StableDiffusionPipeline")
247+
248+
self.assertEqual(info['name'], "StableDiffusionPipeline")
249+
self.assertIsInstance(info['aliases'], list)
250+
self.assertIsInstance(info['supports_single_file'], bool)
251+
252+
def test_get_pipeline_info_unknown_raises(self):
253+
"""Test that getting info for unknown pipeline raises ValueError."""
254+
with self.assertRaises(ValueError) as ctx:
255+
loader.get_pipeline_info("NonExistentPipeline")
256+
257+
self.assertIn("Unknown pipeline", str(ctx.exception))
258+
259+
260+
class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase):
261+
"""Test cases using mocks to test edge cases."""
262+
263+
def test_load_pipeline_requires_model_id(self):
264+
"""Test that load_diffusers_pipeline requires model_id."""
265+
with self.assertRaises(ValueError) as ctx:
266+
loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline")
267+
268+
self.assertIn("model_id is required", str(ctx.exception))
269+
270+
def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self):
271+
"""Test that resolving with only model_id falls back to DiffusionPipeline."""
272+
# When model_id is provided, if hub lookup is not successful,
273+
# should fall back to DiffusionPipeline.
274+
# This tests the fallback behavior - the actual hub lookup may succeed
275+
# or fail depending on network, but the fallback path should work.
276+
cls = loader.resolve_pipeline_class(model_id="some/nonexistent/model")
277+
self.assertEqual(cls, DiffusionPipeline)

0 commit comments

Comments
 (0)