Skip to content

Commit e386b48

Browse files
Inline $ref and $defs in JSON schema in with_structured_output (#76)
Added resolve_schema_refs to OCIUtils to inline $ref and $defs in JSON schema, as OCI Generative AI does not support these features. Updated ChatOCIGenAI to use this utility before passing schemas to the provider. Added a unit test to verify correct handling of nested $ref in structured output schemas.
1 parent d10447c commit e386b48

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,31 @@ def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
118118
id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:],
119119
)
120120

121+
@staticmethod
def resolve_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``schema`` with local ``#/$defs/...`` references inlined.

    OCI Generative AI doesn't support ``$ref`` and ``$defs``, so every
    reference into the schema's top-level ``$defs`` table is replaced by the
    definition it points to, and the ``$defs`` table itself is omitted from
    the result.

    Args:
        schema: A JSON schema mapping, optionally carrying a ``$defs`` table.

    Returns:
        A new mapping without the top-level ``$defs`` key. The input mapping
        is not modified. References that cannot be inlined are left in place
        unchanged:

        * references that do not start with ``#/$defs/``;
        * references to a name that is absent from ``$defs``;
        * references that would recurse into a definition that is already
          being expanded (self-referential or mutually recursive
          definitions), since those cannot be inlined finitely.
    """
    defs_prefix = "#/$defs/"
    defs = schema.get("$defs", {})  # OCI Generative AI doesn't support $defs
    if not isinstance(defs, dict):
        defs = {}

    # Names of the definitions currently being expanded; used to stop
    # unbounded recursion on cyclic definitions.
    in_progress = set()

    def resolve(node: Any) -> Any:
        if isinstance(node, list):
            return [resolve(item) for item in node]
        if not isinstance(node, dict):
            return node

        ref = node.get("$ref")
        if not isinstance(ref, str):
            # Either no "$ref" at all, or "$ref" is not a reference string
            # (e.g. a property that happens to be named "$ref"): treat the
            # node as an ordinary mapping.
            return {key: resolve(value) for key, value in node.items()}

        if not ref.startswith(defs_prefix):
            return node  # Not a local $defs reference, return unchanged

        name = ref[len(defs_prefix):]
        if name not in defs or name in in_progress:
            return node  # Unknown or cyclic reference, return unchanged

        in_progress.add(name)
        try:
            target = resolve(defs[name])
        finally:
            in_progress.discard(name)

        if not isinstance(target, dict):
            return target

        # Keywords that sit next to "$ref" (description, default, ...) are
        # kept and take precedence over the referenced definition's own.
        siblings = {
            key: resolve(value) for key, value in node.items() if key != "$ref"
        }
        return {**target, **siblings}

    # The $defs table is only a lookup source; leave it out of the walk so it
    # is neither expanded needlessly nor present in the result.
    body = {key: value for key, value in schema.items() if key != "$defs"}
    resolved = resolve(body)
    if resolved is body:
        # An unresolvable top-level $ref comes back as the same object;
        # copy it so the function always hands out a fresh mapping.
        resolved = dict(body)
    return resolved
145+
121146

122147
class Provider(ABC):
123148
"""Abstract base class for OCI Generative AI providers."""
@@ -1371,6 +1396,9 @@ def with_structured_output(
13711396
else schema # type: ignore[assignment]
13721397
)
13731398

1399+
# Resolve $ref references as OCI doesn't support $ref and $defs
1400+
json_schema_dict = OCIUtils.resolve_schema_refs(json_schema_dict)
1401+
13741402
response_json_schema = self._provider.oci_response_json_schema(
13751403
name=json_schema_dict.get("title", "response"),
13761404
description=json_schema_dict.get("description", ""),

libs/oci/tests/unit_tests/chat_models/test_response_format.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ def test_with_structured_output_json_schema():
111111
oci_gen_ai_client = MagicMock()
112112
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
113113

114-
# This should not raise TypeError anymore
115114
from pydantic import BaseModel
116115

117116
class TestSchema(BaseModel):
@@ -126,6 +125,36 @@ class TestSchema(BaseModel):
126125
assert structured_llm is not None
127126

128127

128+
@pytest.mark.requires("oci")
def test_with_structured_output_json_schema_nested_refs():
    """Check that json_schema structured output accepts a model with nested refs."""
    from enum import Enum
    from typing import List

    from pydantic import BaseModel

    class Color(Enum):
        RED = "RED"
        BLUE = "BLUE"
        GREEN = "GREEN"

    class Item(BaseModel):
        name: str
        # Enum-typed field: intended to produce a $ref to Color in the schema.
        color: Color

    class Response(BaseModel):
        message: str
        # List of nested models: intended to produce a $ref inside an array.
        items: List[Item]

    mock_client = MagicMock()
    chat_model = ChatOCIGenAI(
        model_id="meta.llama-3.3-70b-instruct", client=mock_client
    )

    structured_llm = chat_model.with_structured_output(
        schema=Response, method="json_schema"
    )

    # Building the structured LLM must succeed and yield an object.
    assert structured_llm is not None
156+
157+
129158
@pytest.mark.requires("oci")
130159
def test_response_format_json_schema_object():
131160
"""Test response_format with JsonSchemaResponseFormat object."""

0 commit comments

Comments
 (0)