|
1 | 1 | """ |
2 | | -A test script to test the gRPC service |
| 2 | +A test script to test the gRPC service and dynamic loader |
3 | 3 | """ |
4 | 4 | import unittest |
5 | 5 | import subprocess |
6 | 6 | import time |
7 | | -import backend_pb2 |
8 | | -import backend_pb2_grpc |
| 7 | +from unittest.mock import patch, MagicMock |
9 | 8 |
|
10 | | -import grpc |
| 9 | +# Import dynamic loader for testing (these don't need gRPC) |
| 10 | +import diffusers_dynamic_loader as loader |
| 11 | +from diffusers import DiffusionPipeline, StableDiffusionPipeline |
11 | 12 |
|
| 13 | +# Try to import gRPC modules - may not be available during unit testing |
| 14 | +try: |
| 15 | + import grpc |
| 16 | + import backend_pb2 |
| 17 | + import backend_pb2_grpc |
| 18 | + GRPC_AVAILABLE = True |
| 19 | +except ImportError: |
| 20 | + GRPC_AVAILABLE = False |
12 | 21 |
|
| 22 | + |
| 23 | +@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available") |
13 | 24 | class TestBackendServicer(unittest.TestCase): |
14 | 25 | """ |
15 | 26 | TestBackendServicer is the class that tests the gRPC service |
@@ -82,3 +93,185 @@ def test(self): |
82 | 93 | self.fail("Image gen service failed") |
83 | 94 | finally: |
84 | 95 | self.tearDown() |
| 96 | + |
| 97 | + |
| 98 | +class TestDiffusersDynamicLoader(unittest.TestCase): |
| 99 | + """Test cases for the diffusers dynamic loader functionality.""" |
| 100 | + |
| 101 | + @classmethod |
| 102 | + def setUpClass(cls): |
| 103 | + """Set up test fixtures - clear caches to ensure fresh discovery.""" |
| 104 | + # Reset the caches to ensure fresh discovery |
| 105 | + loader._pipeline_registry = None |
| 106 | + loader._task_aliases = None |
| 107 | + |
| 108 | + def test_camel_to_kebab_conversion(self): |
| 109 | + """Test CamelCase to kebab-case conversion.""" |
| 110 | + test_cases = [ |
| 111 | + ("StableDiffusionPipeline", "stable-diffusion-pipeline"), |
| 112 | + ("StableDiffusionXLPipeline", "stable-diffusion-xl-pipeline"), |
| 113 | + ("FluxPipeline", "flux-pipeline"), |
| 114 | + ("DiffusionPipeline", "diffusion-pipeline"), |
| 115 | + ] |
| 116 | + for input_val, expected in test_cases: |
| 117 | + with self.subTest(input=input_val): |
| 118 | + result = loader._camel_to_kebab(input_val) |
| 119 | + self.assertEqual(result, expected) |
| 120 | + |
| 121 | + def test_extract_task_keywords(self): |
| 122 | + """Test task keyword extraction from class names.""" |
| 123 | + # Test text-to-image detection |
| 124 | + aliases = loader._extract_task_keywords("StableDiffusionPipeline") |
| 125 | + self.assertIn("stable-diffusion", aliases) |
| 126 | + |
| 127 | + # Test img2img detection |
| 128 | + aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline") |
| 129 | + self.assertIn("image-to-image", aliases) |
| 130 | + self.assertIn("img2img", aliases) |
| 131 | + |
| 132 | + # Test inpainting detection |
| 133 | + aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline") |
| 134 | + self.assertIn("inpainting", aliases) |
| 135 | + self.assertIn("inpaint", aliases) |
| 136 | + |
| 137 | + # Test depth2img detection |
| 138 | + aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline") |
| 139 | + self.assertIn("depth-to-image", aliases) |
| 140 | + |
| 141 | + def test_discover_pipelines_finds_known_classes(self): |
| 142 | + """Test that pipeline discovery finds at least one known pipeline class.""" |
| 143 | + registry = loader.get_pipeline_registry() |
| 144 | + |
| 145 | + # Check that the registry is not empty |
| 146 | + self.assertGreater(len(registry), 0, "Pipeline registry should not be empty") |
| 147 | + |
| 148 | + # Check for known pipeline classes |
| 149 | + known_pipelines = [ |
| 150 | + "StableDiffusionPipeline", |
| 151 | + "DiffusionPipeline", |
| 152 | + ] |
| 153 | + |
| 154 | + for pipeline_name in known_pipelines: |
| 155 | + with self.subTest(pipeline=pipeline_name): |
| 156 | + self.assertIn( |
| 157 | + pipeline_name, |
| 158 | + registry, |
| 159 | + f"Expected to find {pipeline_name} in registry" |
| 160 | + ) |
| 161 | + |
| 162 | + def test_discover_pipelines_caches_results(self): |
| 163 | + """Test that pipeline discovery results are cached.""" |
| 164 | + # Get registry twice |
| 165 | + registry1 = loader.get_pipeline_registry() |
| 166 | + registry2 = loader.get_pipeline_registry() |
| 167 | + |
| 168 | + # Should be the same object (cached) |
| 169 | + self.assertIs(registry1, registry2, "Registry should be cached") |
| 170 | + |
| 171 | + def test_get_available_pipelines(self): |
| 172 | + """Test getting list of available pipelines.""" |
| 173 | + available = loader.get_available_pipelines() |
| 174 | + |
| 175 | + # Should return a list |
| 176 | + self.assertIsInstance(available, list) |
| 177 | + |
| 178 | + # Should contain known pipelines |
| 179 | + self.assertIn("StableDiffusionPipeline", available) |
| 180 | + self.assertIn("DiffusionPipeline", available) |
| 181 | + |
| 182 | + # Should be sorted |
| 183 | + self.assertEqual(available, sorted(available)) |
| 184 | + |
| 185 | + def test_get_available_tasks(self): |
| 186 | + """Test getting list of available task aliases.""" |
| 187 | + tasks = loader.get_available_tasks() |
| 188 | + |
| 189 | + # Should return a list |
| 190 | + self.assertIsInstance(tasks, list) |
| 191 | + |
| 192 | + # Should be sorted |
| 193 | + self.assertEqual(tasks, sorted(tasks)) |
| 194 | + |
| 195 | + def test_resolve_pipeline_class_by_name(self): |
| 196 | + """Test resolving pipeline class by exact name.""" |
| 197 | + cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") |
| 198 | + self.assertEqual(cls, StableDiffusionPipeline) |
| 199 | + |
| 200 | + def test_resolve_pipeline_class_by_name_case_insensitive(self): |
| 201 | + """Test that class name resolution is case-insensitive.""" |
| 202 | + cls1 = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") |
| 203 | + cls2 = loader.resolve_pipeline_class(class_name="stablediffusionpipeline") |
| 204 | + self.assertEqual(cls1, cls2) |
| 205 | + |
| 206 | + def test_resolve_pipeline_class_by_task(self): |
| 207 | + """Test resolving pipeline class by task alias.""" |
| 208 | + # Get the registry to find available tasks |
| 209 | + aliases = loader.get_task_aliases() |
| 210 | + |
| 211 | + # Test with a common task that should be available |
| 212 | + if "stable-diffusion" in aliases: |
| 213 | + cls = loader.resolve_pipeline_class(task="stable-diffusion") |
| 214 | + self.assertIsNotNone(cls) |
| 215 | + |
| 216 | + def test_resolve_pipeline_class_unknown_name_raises(self): |
| 217 | + """Test that resolving unknown class name raises ValueError with helpful message.""" |
| 218 | + with self.assertRaises(ValueError) as ctx: |
| 219 | + loader.resolve_pipeline_class(class_name="NonExistentPipeline") |
| 220 | + |
| 221 | + # Check that error message includes available pipelines |
| 222 | + error_msg = str(ctx.exception) |
| 223 | + self.assertIn("Unknown pipeline class", error_msg) |
| 224 | + self.assertIn("Available pipelines", error_msg) |
| 225 | + |
| 226 | + def test_resolve_pipeline_class_unknown_task_raises(self): |
| 227 | + """Test that resolving unknown task raises ValueError with helpful message.""" |
| 228 | + with self.assertRaises(ValueError) as ctx: |
| 229 | + loader.resolve_pipeline_class(task="nonexistent-task-xyz") |
| 230 | + |
| 231 | + # Check that error message includes available tasks |
| 232 | + error_msg = str(ctx.exception) |
| 233 | + self.assertIn("Unknown task", error_msg) |
| 234 | + self.assertIn("Available tasks", error_msg) |
| 235 | + |
| 236 | + def test_resolve_pipeline_class_no_params_raises(self): |
| 237 | + """Test that calling with no parameters raises helpful ValueError.""" |
| 238 | + with self.assertRaises(ValueError) as ctx: |
| 239 | + loader.resolve_pipeline_class() |
| 240 | + |
| 241 | + error_msg = str(ctx.exception) |
| 242 | + self.assertIn("Must provide at least one of", error_msg) |
| 243 | + |
| 244 | + def test_get_pipeline_info(self): |
| 245 | + """Test getting pipeline information.""" |
| 246 | + info = loader.get_pipeline_info("StableDiffusionPipeline") |
| 247 | + |
| 248 | + self.assertEqual(info['name'], "StableDiffusionPipeline") |
| 249 | + self.assertIsInstance(info['aliases'], list) |
| 250 | + self.assertIsInstance(info['supports_single_file'], bool) |
| 251 | + |
| 252 | + def test_get_pipeline_info_unknown_raises(self): |
| 253 | + """Test that getting info for unknown pipeline raises ValueError.""" |
| 254 | + with self.assertRaises(ValueError) as ctx: |
| 255 | + loader.get_pipeline_info("NonExistentPipeline") |
| 256 | + |
| 257 | + self.assertIn("Unknown pipeline", str(ctx.exception)) |
| 258 | + |
| 259 | + |
| 260 | +class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase): |
| 261 | + """Test cases using mocks to test edge cases.""" |
| 262 | + |
| 263 | + def test_load_pipeline_requires_model_id(self): |
| 264 | + """Test that load_diffusers_pipeline requires model_id.""" |
| 265 | + with self.assertRaises(ValueError) as ctx: |
| 266 | + loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline") |
| 267 | + |
| 268 | + self.assertIn("model_id is required", str(ctx.exception)) |
| 269 | + |
| 270 | + def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self): |
| 271 | + """Test that resolving with only model_id falls back to DiffusionPipeline.""" |
| 272 | + # When model_id is provided, if hub lookup is not successful, |
| 273 | + # should fall back to DiffusionPipeline. |
| 274 | + # This tests the fallback behavior - the actual hub lookup may succeed |
| 275 | + # or fail depending on network, but the fallback path should work. |
| 276 | + cls = loader.resolve_pipeline_class(model_id="some/nonexistent/model") |
| 277 | + self.assertEqual(cls, DiffusionPipeline) |
0 commit comments