Skip to content

$ref loading defect corrected. #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion stage0_mongodb_api/managers/schema_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Set, Optional
from typing import Dict, List, Set, Optional, Any
import os
import re
import yaml
Expand Down Expand Up @@ -42,6 +42,10 @@ def __init__(self, collection_configs: Optional[Dict[str, Dict]] = None):
# If collection_configs wasn't provided, load them
if not self.collection_configs:
self._load_collection_configs()

# Resolve $ref values in dictionaries (after all dictionaries are loaded)
ref_errors = self._resolve_refs()
self.load_errors.extend(ref_errors)

def _load_types(self) -> List[Dict]:
"""Load type definitions.
Expand Down Expand Up @@ -165,6 +169,91 @@ def _load_dictionaries(self) -> List[Dict]:
})
return errors

def _resolve_refs(self) -> List[Dict]:
"""Resolve all $ref values in loaded dictionaries.

This method recursively traverses all dictionary definitions and replaces
$ref objects with the actual referenced dictionary content.

Returns:
List of errors encountered during resolution
"""
ref_errors = []

# Create a temporary copy of dictionaries for resolution
resolved = {}

for dict_name, dict_def in self.dictionaries.items():
resolved_def, errors = self._resolve_refs_in_object(dict_def, dict_name, set())
resolved[dict_name] = resolved_def
ref_errors.extend(errors)

self.dictionaries = resolved

return ref_errors

def _resolve_refs_in_object(self, obj: Any, dict_name: str, visited: Set[str]) -> tuple[Any, List[Dict]]:
"""Recursively resolve $ref values in an object.

Args:
obj: The object to resolve $ref values in
dict_name: The name of the dictionary being resolved
visited: Set of already visited paths (for cycle detection)

Returns:
Tuple of (resolved_object, list_of_errors)
"""
errors = []
if isinstance(obj, dict):
# Check if this is a $ref object
if "$ref" in obj:
ref_name = obj["$ref"]
if ref_name in visited:
errors.append({
"error": "circular_reference",
"error_id": "SCH-013",
"dict_name": dict_name,
"ref_name": ref_name,
"message": f"Circular reference detected: {ref_name}"
})
return obj, errors
elif ref_name not in self.dictionaries:
errors.append({
"error": "ref_not_found",
"error_id": "SCH-014",
"dict_name": dict_name,
"ref_name": ref_name,
"message": f"Referenced dictionary not found: {ref_name}"
})
return obj, errors
else:
# Resolve the reference - replace the entire object with the referenced content
visited.add(ref_name)
resolved, ref_errors = self._resolve_refs_in_object(self.dictionaries[ref_name], dict_name, visited)
visited.remove(ref_name)
errors.extend(ref_errors)
return resolved, errors

# Otherwise, recursively resolve all values in the dictionary
resolved = {}
for key, value in obj.items():
resolved_value, value_errors = self._resolve_refs_in_object(value, dict_name, visited)
resolved[key] = resolved_value
errors.extend(value_errors)
return resolved, errors

elif isinstance(obj, list):
# Recursively resolve all items in the list
resolved_items = []
for item in obj:
resolved_item, item_errors = self._resolve_refs_in_object(item, dict_name, visited)
resolved_items.append(resolved_item)
errors.extend(item_errors)
return resolved_items, errors
else:
# Primitive value, return as-is
return obj, errors

def _load_collection_configs(self) -> None:
"""Load collection configurations from the input folder.

Expand Down
11 changes: 1 addition & 10 deletions stage0_mongodb_api/managers/schema_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,12 @@ def render_schema(version_name: str, format: SchemaFormat, context: SchemaContex
@staticmethod
def _render(schema: Dict, format: SchemaFormat, enumerator_version: int, context: SchemaContext) -> Dict:
""" Recursively render a schema definition."""
# Handle $ref first - replace with referenced dictionary
if "$ref" in schema:
return SchemaRenderer._render(
context["dictionaries"][schema["$ref"]],
format,
enumerator_version,
context
)

# Handle primitive types
if "schema" in schema or "json_type" in schema:
return SchemaRenderer._render_primitive(schema, format)

# Handle complex types
logger.info(f"Rendering schema: {schema}")
logger.debug(f"Rendering schema: {schema}")
type_name = schema["type"]
if type_name == SchemaType.OBJECT.value:
return SchemaRenderer._render_object(schema, format, enumerator_version, context)
Expand Down
42 changes: 8 additions & 34 deletions stage0_mongodb_api/managers/schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,6 @@ def _validate_complex_type_properties(prop_name: str, prop_def: Dict, context: V
}])
return type_errors

# Validate $ref types
if "$ref" in prop_def:
ref_errors = SchemaValidator._validate_ref_type(prop_name, prop_def["$ref"], context)
if ref_errors:
type_errors.extend(ref_errors)
return type_errors

# Validate required fields
type_errors.extend(SchemaValidator._validate_required_fields(prop_name, prop_def))
if type_errors:
Expand Down Expand Up @@ -350,19 +343,6 @@ def _validate_required_fields(prop_name: str, prop_def: Dict) -> List[Dict]:
})
return errors

@staticmethod
def _validate_ref_type(prop_name: str, ref_type: str, context: ValidationContext) -> List[Dict]:
"""Validate a $ref type reference."""
if ref_type not in context["dictionaries"]:
return [{
"error": "invalid_ref_type",
"error_id": "VLD-501",
"type": prop_name,
"ref": ref_type,
"message": f"Referenced type {ref_type} not found in dictionaries"
}]
return []

@staticmethod
def _validate_custom_type(prop_name: str, type_name: str, context: ValidationContext) -> List[Dict]:
"""Validate a custom type reference."""
Expand Down Expand Up @@ -484,20 +464,14 @@ def _validate_one_of_type(prop_name: str, one_of_def: Dict, context: ValidationC

# Validate each schema in the one_of definition
for schema_name, schema_def in one_of_def["schemas"].items():
# If schema is a $ref, validate the reference
if isinstance(schema_def, dict) and "$ref" in schema_def:
ref_errors = SchemaValidator._validate_ref_type(f"{prop_name}.{schema_name}", schema_def["$ref"], context)
if ref_errors:
errors.extend(ref_errors)
else:
# Otherwise validate as a complex type
errors.extend(SchemaValidator._validate_complex_type(
f"{prop_name}.{schema_name}",
schema_def,
context,
enumerator_version,
visited
))
# Validate as a complex type (all $ref objects will have been resolved during loading)
errors.extend(SchemaValidator._validate_complex_type(
f"{prop_name}.{schema_name}",
schema_def,
context,
enumerator_version,
visited
))

return errors

9 changes: 0 additions & 9 deletions tests/managers/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,6 @@ def test_non_parsable(self):
manager = ConfigManager()
self.assertEqual(len(manager.load_errors), 1, f"Unexpected load errors {manager.load_errors}")

def test_validation_errors(self):
"""Test loading with validation errors"""
test_case_dir = os.path.join(self.test_cases_dir, "validation_errors")
self.config.INPUT_FOLDER = test_case_dir
manager = ConfigManager()
errors = manager.validate_configs()
self.assertEqual(len(manager.load_errors), 0, f"Unexpected load errors {manager.load_errors}")
self.assertEqual(len(errors), 6, f"Unexpected number of validation errors {errors}")

def test_load_test_data_bulk_write_error(self):
"""Test that _load_test_data properly handles bulk write errors."""
from stage0_py_utils.mongo_utils.mongo_io import TestDataLoadError
Expand Down
54 changes: 54 additions & 0 deletions tests/managers/test_ref_load_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
import os
from unittest.mock import MagicMock, patch
from stage0_mongodb_api.managers.schema_manager import SchemaManager
from stage0_py_utils import Config


class TestRefLoadErrors(unittest.TestCase):
"""Test cases for $ref load errors during schema loading."""

def setUp(self):
"""Set up test environment."""
self.config = Config.get_instance()
self.test_cases_dir = os.path.join(os.path.dirname(__file__), "..", "test_cases")

@patch('stage0_py_utils.MongoIO.get_instance')
def test_ref_load_errors(self, mock_get_instance):
"""Test that $ref load errors are properly caught and reported."""
# Arrange
mock_get_instance.return_value = MagicMock()
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "ref_load_errors")

# Act
schema_manager = SchemaManager()

# Assert
# Check that we have load errors
self.assertGreater(len(schema_manager.load_errors), 0,
"Should have load errors for $ref issues")

# Check for specific error types
error_codes = [error.get('error_id') for error in schema_manager.load_errors]

# Should have SCH-013 (circular reference) and SCH-014 (missing reference)
self.assertIn('SCH-013', error_codes,
"Should have circular reference error (SCH-013)")
self.assertIn('SCH-014', error_codes,
"Should have missing reference error (SCH-014)")

# Verify error details
circular_error = next((e for e in schema_manager.load_errors if e.get('error_id') == 'SCH-013'), None)
missing_error = next((e for e in schema_manager.load_errors if e.get('error_id') == 'SCH-014'), None)

self.assertIsNotNone(circular_error, "Should have circular reference error")
self.assertEqual(circular_error['error'], 'circular_reference')
self.assertEqual(circular_error['ref_name'], 'circular_ref.1.0.0')

self.assertIsNotNone(missing_error, "Should have missing reference error")
self.assertEqual(missing_error['error'], 'ref_not_found')
self.assertEqual(missing_error['ref_name'], 'does_not_exist.1.0.0')


if __name__ == '__main__':
unittest.main()
18 changes: 18 additions & 0 deletions tests/managers/test_schema_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,23 @@ def test_load_errors(self, mock_get_instance):
self.assertEqual(missing_error_ids, set(), f"Missing error IDs: {missing_error_ids}")
self.assertEqual(extra_error_ids, set(), f"Extra error IDs: {extra_error_ids}")

@patch('stage0_py_utils.MongoIO.get_instance')
def test_ref_resolution_errors(self, mock_get_instance):
"""Test loading with $ref resolution errors."""
# Arrange
mock_get_instance.return_value = MagicMock()
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "validation_errors")

# Act
schema_manager = SchemaManager()

# Assert
expected_error_ids = {"SCH-014"} # Only missing reference, no circular reference in this test case
actual_error_ids = {error.get('error_id') for error in schema_manager.load_errors if 'error_id' in error}
missing_error_ids = expected_error_ids - actual_error_ids
extra_error_ids = actual_error_ids - expected_error_ids
self.assertEqual(missing_error_ids, set(), f"Missing error IDs: {missing_error_ids}")
self.assertEqual(extra_error_ids, set(), f"Extra error IDs: {extra_error_ids}")

if __name__ == '__main__':
unittest.main()
19 changes: 19 additions & 0 deletions tests/managers/test_schema_renders.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,25 @@ def test_render_simple(self, mock_get_instance):
self.assertEqual(rendered_bson, expected_bson, f"BSON schema mismatch, rendered: {rendered_bson}")
self.assertEqual(rendered_json, expected_json, f"JSON schema mismatch, rendered: {rendered_json}")

@patch('stage0_py_utils.MongoIO.get_instance')
def test_render_nested_refs(self, mock_get_instance):
"""Test rendering of nested $refs."""
# Arrange
mock_get_instance.return_value = MagicMock()
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "complex_refs")
schema_manager = SchemaManager()
version_name = "workshop.1.0.0.1"

# Act
rendered_bson = schema_manager.render_one(version_name, SchemaFormat.BSON)
rendered_json = schema_manager.render_one(version_name, SchemaFormat.JSON)

# Assert
expected_bson = self._load_bson(version_name)
expected_json = self._load_json(version_name)
self.assertEqual(rendered_bson, expected_bson, f"BSON schema mismatch, rendered: {rendered_bson}")
self.assertEqual(rendered_json, expected_json, f"JSON schema mismatch, rendered: {rendered_json}")

@patch('stage0_py_utils.MongoIO.get_instance')
def test_render_organization(self, mock_get_instance):
"""Test rendering with complex custom types."""
Expand Down
16 changes: 15 additions & 1 deletion tests/managers/test_schema_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ def test_validate_large_sample(self, mock_get_instance):
self.assertEqual(schema_manager.load_errors, [])
self.assertEqual(errors, [])

@patch('stage0_py_utils.MongoIO.get_instance')
def test_validate_complex_refs(self, mock_get_instance):
"""Test validation of complex nested $refs."""
# Arrange
self.config.INPUT_FOLDER = os.path.join(self.test_cases_dir, "complex_refs")
mock_get_instance.return_value = MagicMock()
schema_manager = SchemaManager()

# Act
errors = schema_manager.validate_schema()

# Assert
self.assertEqual(schema_manager.load_errors, [])
self.assertEqual(errors, [])

@patch('stage0_py_utils.MongoIO.get_instance')
def test_validation_errors(self, mock_get_instance):
"""Test validation with all validation errors."""
Expand All @@ -82,7 +97,6 @@ def test_validation_errors(self, mock_get_instance):
"VLD-201", "VLD-202", "VLD-203", "VLD-204", # Primitive type validation errors
"VLD-301", # Complex type basic validation
"VLD-401", # Required fields validation
"VLD-501", # Reference type validation
"VLD-601", # Custom type validation
"VLD-701", # Object type validation
"VLD-801", # Array type validation
Expand Down
5 changes: 5 additions & 0 deletions tests/test_cases/complex_refs/collections/workshop.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
title: Workshop Collection
description: A record of a workshop
name: user
versions:
- version: "1.0.0.1"
Loading
Loading