Skip to content

Commit 222ba61

Browse files
Merge pull request #4 from agile-learning-institute/ref_load_defect
$ref loading defect corrected.
2 parents 779b483 + 8312214 commit 222ba61

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1717
-71
lines changed

stage0_mongodb_api/managers/schema_manager.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Set, Optional
1+
from typing import Dict, List, Set, Optional, Any
22
import os
33
import re
44
import yaml
@@ -42,6 +42,10 @@ def __init__(self, collection_configs: Optional[Dict[str, Dict]] = None):
4242
# If collection_configs wasn't provided, load them
4343
if not self.collection_configs:
4444
self._load_collection_configs()
45+
46+
# Resolve $ref values in dictionaries (after all dictionaries are loaded)
47+
ref_errors = self._resolve_refs()
48+
self.load_errors.extend(ref_errors)
4549

4650
def _load_types(self) -> List[Dict]:
4751
"""Load type definitions.
@@ -165,6 +169,91 @@ def _load_dictionaries(self) -> List[Dict]:
165169
})
166170
return errors
167171

172+
def _resolve_refs(self) -> List[Dict]:
173+
"""Resolve all $ref values in loaded dictionaries.
174+
175+
This method recursively traverses all dictionary definitions and replaces
176+
$ref objects with the actual referenced dictionary content.
177+
178+
Returns:
179+
List of errors encountered during resolution
180+
"""
181+
ref_errors = []
182+
183+
# Create a temporary copy of dictionaries for resolution
184+
resolved = {}
185+
186+
for dict_name, dict_def in self.dictionaries.items():
187+
resolved_def, errors = self._resolve_refs_in_object(dict_def, dict_name, set())
188+
resolved[dict_name] = resolved_def
189+
ref_errors.extend(errors)
190+
191+
self.dictionaries = resolved
192+
193+
return ref_errors
194+
195+
def _resolve_refs_in_object(self, obj: Any, dict_name: str, visited: Set[str]) -> tuple[Any, List[Dict]]:
196+
"""Recursively resolve $ref values in an object.
197+
198+
Args:
199+
obj: The object to resolve $ref values in
200+
dict_name: The name of the dictionary being resolved
201+
visited: Set of already visited paths (for cycle detection)
202+
203+
Returns:
204+
Tuple of (resolved_object, list_of_errors)
205+
"""
206+
errors = []
207+
if isinstance(obj, dict):
208+
# Check if this is a $ref object
209+
if "$ref" in obj:
210+
ref_name = obj["$ref"]
211+
if ref_name in visited:
212+
errors.append({
213+
"error": "circular_reference",
214+
"error_id": "SCH-013",
215+
"dict_name": dict_name,
216+
"ref_name": ref_name,
217+
"message": f"Circular reference detected: {ref_name}"
218+
})
219+
return obj, errors
220+
elif ref_name not in self.dictionaries:
221+
errors.append({
222+
"error": "ref_not_found",
223+
"error_id": "SCH-014",
224+
"dict_name": dict_name,
225+
"ref_name": ref_name,
226+
"message": f"Referenced dictionary not found: {ref_name}"
227+
})
228+
return obj, errors
229+
else:
230+
# Resolve the reference - replace the entire object with the referenced content
231+
visited.add(ref_name)
232+
resolved, ref_errors = self._resolve_refs_in_object(self.dictionaries[ref_name], dict_name, visited)
233+
visited.remove(ref_name)
234+
errors.extend(ref_errors)
235+
return resolved, errors
236+
237+
# Otherwise, recursively resolve all values in the dictionary
238+
resolved = {}
239+
for key, value in obj.items():
240+
resolved_value, value_errors = self._resolve_refs_in_object(value, dict_name, visited)
241+
resolved[key] = resolved_value
242+
errors.extend(value_errors)
243+
return resolved, errors
244+
245+
elif isinstance(obj, list):
246+
# Recursively resolve all items in the list
247+
resolved_items = []
248+
for item in obj:
249+
resolved_item, item_errors = self._resolve_refs_in_object(item, dict_name, visited)
250+
resolved_items.append(resolved_item)
251+
errors.extend(item_errors)
252+
return resolved_items, errors
253+
else:
254+
# Primitive value, return as-is
255+
return obj, errors
256+
168257
def _load_collection_configs(self) -> None:
169258
"""Load collection configurations from the input folder.
170259

stage0_mongodb_api/managers/schema_renderer.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,12 @@ def render_schema(version_name: str, format: SchemaFormat, context: SchemaContex
1717
@staticmethod
1818
def _render(schema: Dict, format: SchemaFormat, enumerator_version: int, context: SchemaContext) -> Dict:
1919
""" Recursively render a schema definition."""
20-
# Handle $ref first - replace with referenced dictionary
21-
if "$ref" in schema:
22-
return SchemaRenderer._render(
23-
context["dictionaries"][schema["$ref"]],
24-
format,
25-
enumerator_version,
26-
context
27-
)
28-
2920
# Handle primitive types
3021
if "schema" in schema or "json_type" in schema:
3122
return SchemaRenderer._render_primitive(schema, format)
3223

3324
# Handle complex types
34-
logger.info(f"Rendering schema: {schema}")
25+
logger.debug(f"Rendering schema: {schema}")
3526
type_name = schema["type"]
3627
if type_name == SchemaType.OBJECT.value:
3728
return SchemaRenderer._render_object(schema, format, enumerator_version, context)

stage0_mongodb_api/managers/schema_validator.py

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,6 @@ def _validate_complex_type_properties(prop_name: str, prop_def: Dict, context: V
310310
}])
311311
return type_errors
312312

313-
# Validate $ref types
314-
if "$ref" in prop_def:
315-
ref_errors = SchemaValidator._validate_ref_type(prop_name, prop_def["$ref"], context)
316-
if ref_errors:
317-
type_errors.extend(ref_errors)
318-
return type_errors
319-
320313
# Validate required fields
321314
type_errors.extend(SchemaValidator._validate_required_fields(prop_name, prop_def))
322315
if type_errors:
@@ -350,19 +343,6 @@ def _validate_required_fields(prop_name: str, prop_def: Dict) -> List[Dict]:
350343
})
351344
return errors
352345

353-
@staticmethod
354-
def _validate_ref_type(prop_name: str, ref_type: str, context: ValidationContext) -> List[Dict]:
355-
"""Validate a $ref type reference."""
356-
if ref_type not in context["dictionaries"]:
357-
return [{
358-
"error": "invalid_ref_type",
359-
"error_id": "VLD-501",
360-
"type": prop_name,
361-
"ref": ref_type,
362-
"message": f"Referenced type {ref_type} not found in dictionaries"
363-
}]
364-
return []
365-
366346
@staticmethod
367347
def _validate_custom_type(prop_name: str, type_name: str, context: ValidationContext) -> List[Dict]:
368348
"""Validate a custom type reference."""
@@ -484,20 +464,14 @@ def _validate_one_of_type(prop_name: str, one_of_def: Dict, context: ValidationC
484464

485465
# Validate each schema in the one_of definition
486466
for schema_name, schema_def in one_of_def["schemas"].items():
487-
# If schema is a $ref, validate the reference
488-
if isinstance(schema_def, dict) and "$ref" in schema_def:
489-
ref_errors = SchemaValidator._validate_ref_type(f"{prop_name}.{schema_name}", schema_def["$ref"], context)
490-
if ref_errors:
491-
errors.extend(ref_errors)
492-
else:
493-
# Otherwise validate as a complex type
494-
errors.extend(SchemaValidator._validate_complex_type(
495-
f"{prop_name}.{schema_name}",
496-
schema_def,
497-
context,
498-
enumerator_version,
499-
visited
500-
))
467+
# Validate as a complex type (all $ref objects will have been resolved during loading)
468+
errors.extend(SchemaValidator._validate_complex_type(
469+
f"{prop_name}.{schema_name}",
470+
schema_def,
471+
context,
472+
enumerator_version,
473+
visited
474+
))
501475

502476
return errors
503477

tests/managers/test_config_manager.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,6 @@ def test_non_parsable(self):
6060
manager = ConfigManager()
6161
self.assertEqual(len(manager.load_errors), 1, f"Unexpected load errors {manager.load_errors}")
6262

63-
def test_validation_errors(self):
64-
"""Test loading with validation errors"""
65-
test_case_dir = os.path.join(self.test_cases_dir, "validation_errors")
66-
self.config.INPUT_FOLDER = test_case_dir
67-
manager = ConfigManager()
68-
errors = manager.validate_configs()
69-
self.assertEqual(len(manager.load_errors), 0, f"Unexpected load errors {manager.load_errors}")
70-
self.assertEqual(len(errors), 6, f"Unexpected number of validation errors {errors}")
71-
7263
def test_load_test_data_bulk_write_error(self):
7364
"""Test that _load_test_data properly handles bulk write errors."""
7465
from stage0_py_utils.mongo_utils.mongo_io import TestDataLoadError
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
import os
3+
from unittest.mock import MagicMock, patch
4+
from stage0_mongodb_api.managers.schema_manager import SchemaManager
5+
from stage0_py_utils import Config
6+
7+
8+
class TestRefLoadErrors(unittest.TestCase):
9+
"""Test cases for $ref load errors during schema loading."""
10+
11+
def setUp(self):
12+
"""Set up test environment."""
13+
self.config = Config.get_instance()
14+
self.test_cases_dir = os.path.join(os.path.dirname(__file__), "..", "test_cases")
15+
16+
@patch('stage0_py_utils.MongoIO.get_instance')
17+
def test_ref_load_errors(self, mock_get_instance):
18+
"""Test that $ref load errors are properly caught and reported."""
19+
# Arrange
20+
mock_get_instance.return_value = MagicMock()
21+
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "ref_load_errors")
22+
23+
# Act
24+
schema_manager = SchemaManager()
25+
26+
# Assert
27+
# Check that we have load errors
28+
self.assertGreater(len(schema_manager.load_errors), 0,
29+
"Should have load errors for $ref issues")
30+
31+
# Check for specific error types
32+
error_codes = [error.get('error_id') for error in schema_manager.load_errors]
33+
34+
# Should have SCH-013 (circular reference) and SCH-014 (missing reference)
35+
self.assertIn('SCH-013', error_codes,
36+
"Should have circular reference error (SCH-013)")
37+
self.assertIn('SCH-014', error_codes,
38+
"Should have missing reference error (SCH-014)")
39+
40+
# Verify error details
41+
circular_error = next((e for e in schema_manager.load_errors if e.get('error_id') == 'SCH-013'), None)
42+
missing_error = next((e for e in schema_manager.load_errors if e.get('error_id') == 'SCH-014'), None)
43+
44+
self.assertIsNotNone(circular_error, "Should have circular reference error")
45+
self.assertEqual(circular_error['error'], 'circular_reference')
46+
self.assertEqual(circular_error['ref_name'], 'circular_ref.1.0.0')
47+
48+
self.assertIsNotNone(missing_error, "Should have missing reference error")
49+
self.assertEqual(missing_error['error'], 'ref_not_found')
50+
self.assertEqual(missing_error['ref_name'], 'does_not_exist.1.0.0')
51+
52+
53+
if __name__ == '__main__':
54+
unittest.main()

tests/managers/test_schema_loading.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,23 @@ def test_load_errors(self, mock_get_instance):
114114
self.assertEqual(missing_error_ids, set(), f"Missing error IDs: {missing_error_ids}")
115115
self.assertEqual(extra_error_ids, set(), f"Extra error IDs: {extra_error_ids}")
116116

117+
@patch('stage0_py_utils.MongoIO.get_instance')
118+
def test_ref_resolution_errors(self, mock_get_instance):
119+
"""Test loading with $ref resolution errors."""
120+
# Arrange
121+
mock_get_instance.return_value = MagicMock()
122+
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "validation_errors")
123+
124+
# Act
125+
schema_manager = SchemaManager()
126+
127+
# Assert
128+
expected_error_ids = {"SCH-014"} # Only missing reference, no circular reference in this test case
129+
actual_error_ids = {error.get('error_id') for error in schema_manager.load_errors if 'error_id' in error}
130+
missing_error_ids = expected_error_ids - actual_error_ids
131+
extra_error_ids = actual_error_ids - expected_error_ids
132+
self.assertEqual(missing_error_ids, set(), f"Missing error IDs: {missing_error_ids}")
133+
self.assertEqual(extra_error_ids, set(), f"Extra error IDs: {extra_error_ids}")
134+
117135
if __name__ == '__main__':
118136
unittest.main()

tests/managers/test_schema_renders.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,25 @@ def test_render_simple(self, mock_get_instance):
3535
self.assertEqual(rendered_bson, expected_bson, f"BSON schema mismatch, rendered: {rendered_bson}")
3636
self.assertEqual(rendered_json, expected_json, f"JSON schema mismatch, rendered: {rendered_json}")
3737

38+
@patch('stage0_py_utils.MongoIO.get_instance')
39+
def test_render_nested_refs(self, mock_get_instance):
40+
"""Test rendering of nested $refs."""
41+
# Arrange
42+
mock_get_instance.return_value = MagicMock()
43+
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "complex_refs")
44+
schema_manager = SchemaManager()
45+
version_name = "workshop.1.0.0.1"
46+
47+
# Act
48+
rendered_bson = schema_manager.render_one(version_name, SchemaFormat.BSON)
49+
rendered_json = schema_manager.render_one(version_name, SchemaFormat.JSON)
50+
51+
# Assert
52+
expected_bson = self._load_bson(version_name)
53+
expected_json = self._load_json(version_name)
54+
self.assertEqual(rendered_bson, expected_bson, f"BSON schema mismatch, rendered: {rendered_bson}")
55+
self.assertEqual(rendered_json, expected_json, f"JSON schema mismatch, rendered: {rendered_json}")
56+
3857
@patch('stage0_py_utils.MongoIO.get_instance')
3958
def test_render_organization(self, mock_get_instance):
4059
"""Test rendering with complex custom types."""

tests/managers/test_schema_validation.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,21 @@ def test_validate_large_sample(self, mock_get_instance):
6060
self.assertEqual(schema_manager.load_errors, [])
6161
self.assertEqual(errors, [])
6262

63+
@patch('stage0_py_utils.MongoIO.get_instance')
64+
def test_validate_complex_refs(self, mock_get_instance):
65+
"""Test validation of complex nested $refs."""
66+
# Arrange
67+
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "complex_refs")
68+
mock_get_instance.return_value = MagicMock()
69+
schema_manager = SchemaManager()
70+
71+
# Act
72+
errors = schema_manager.validate_schema()
73+
74+
# Assert
75+
self.assertEqual(schema_manager.load_errors, [])
76+
self.assertEqual(errors, [])
77+
6378
@patch('stage0_py_utils.MongoIO.get_instance')
6479
def test_validation_errors(self, mock_get_instance):
6580
"""Test validation with all validation errors."""
@@ -82,7 +97,6 @@ def test_validation_errors(self, mock_get_instance):
8297
"VLD-201", "VLD-202", "VLD-203", "VLD-204", # Primitive type validation errors
8398
"VLD-301", # Complex type basic validation
8499
"VLD-401", # Required fields validation
85-
"VLD-501", # Reference type validation
86100
"VLD-601", # Custom type validation
87101
"VLD-701", # Object type validation
88102
"VLD-801", # Array type validation
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
title: Workshop Collection
2+
description: A record of a workshop
3+
name: user
4+
versions:
5+
- version: "1.0.0.1"

0 commit comments

Comments
 (0)