
Commit e68428d

Support enforcing json schema in supported AI model APIs (#1133)
Trigger
- Gemini 2.0 Flash doesn't always follow the JSON schema in the research prompt

Details
- Use JSON schema to enforce the generated online queries format
- Use JSON schema to enforce the research mode tool pick format
- Support constraining Gemini model output to a specified response schema
- Support constraining OpenAI model output to a specified response schema
- Only enforce JSON output in supported AI model APIs
- Simplify OpenAI reasoning model specific arguments to the OpenAI API
2 parents 9b6d626 + a5627ef commit e68428d
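
The core idea of the change, sketched below as a minimal, self-contained example (assuming Pydantic v2; `ResearchQueries` is a hypothetical stand-in for the schemas added in this commit): a Pydantic model is passed as `response_schema`, and the JSON schema derived from it is what a schema-capable model API can be asked to enforce.

```python
from typing import List

from pydantic import BaseModel


# Hypothetical schema, standing in for models like OnlineQueries or PlanningResponse below.
class ResearchQueries(BaseModel):
    queries: List[str]


# The JSON schema derived from the model is what a schema-capable API can enforce.
print(ResearchQueries.model_json_schema())
```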

File tree

7 files changed: +110 -45 lines changed

- src/khoj/processor/conversation/google/gemini_chat.py
- src/khoj/processor/conversation/google/utils.py
- src/khoj/processor/conversation/openai/gpt.py
- src/khoj/processor/conversation/openai/utils.py
- src/khoj/processor/conversation/utils.py
- src/khoj/routers/helpers.py
- src/khoj/routers/research.py

src/khoj/processor/conversation/google/gemini_chat.py (+2)

@@ -121,6 +121,7 @@ def gemini_send_message_to_model(
     api_key,
     model,
     response_type="text",
+    response_schema=None,
     temperature=0.6,
     model_kwargs=None,
     tracer={},
@@ -135,6 +136,7 @@ def gemini_send_message_to_model(
     # This caused unwanted behavior and terminates response early for gemini 1.5 series. Monitor for flakiness with 2.0 series.
     if response_type == "json_object" and model in ["gemini-2.0-flash"]:
         model_kwargs["response_mime_type"] = "application/json"
+        model_kwargs["response_schema"] = response_schema

     # Get Response from Gemini
     return gemini_completion_with_backoff(
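
A minimal sketch of the branch added above, restated as a hypothetical standalone helper (`build_gemini_model_kwargs` is not part of the commit) to show when the schema constraint is attached: only for JSON responses from models known to honor it, currently just gemini-2.0-flash.

```python
def build_gemini_model_kwargs(model: str, response_type: str, response_schema=None) -> dict:
    """Mirror the kwarg plumbing added to gemini_send_message_to_model above."""
    model_kwargs = {}
    if response_type == "json_object" and model in ["gemini-2.0-flash"]:
        model_kwargs["response_mime_type"] = "application/json"
        model_kwargs["response_schema"] = response_schema
    return model_kwargs


# Older Gemini models keep plain text output; 2.0 Flash gets the schema constraint.
print(build_gemini_model_kwargs("gemini-1.5-flash", "json_object", {"type": "object"}))  # {}
print(build_gemini_model_kwargs("gemini-2.0-flash", "json_object", {"type": "object"}))  # mime type + schema
```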

src/khoj/processor/conversation/google/utils.py (+1)

@@ -66,6 +66,7 @@ def gemini_completion_with_backoff(
         max_output_tokens=MAX_OUTPUT_TOKENS_GEMINI,
         safety_settings=SAFETY_SETTINGS,
         response_mime_type=model_kwargs.get("response_mime_type", "text/plain") if model_kwargs else "text/plain",
+        response_schema=model_kwargs.get("response_schema", None) if model_kwargs else None,
     )

     formatted_messages = [gtypes.Content(role=message.role, parts=message.content) for message in messages]

src/khoj/processor/conversation/openai/gpt.py (+18 -2)

@@ -10,8 +10,10 @@
 from khoj.processor.conversation.openai.utils import (
     chat_completion_with_backoff,
     completion_with_backoff,
+    get_openai_api_json_support,
 )
 from khoj.processor.conversation.utils import (
+    JsonSupport,
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
@@ -119,20 +121,34 @@ def extract_questions(


 def send_message_to_model(
-    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+    messages,
+    api_key,
+    model,
+    response_type="text",
+    response_schema=None,
+    api_base_url=None,
+    temperature=0,
+    tracer: dict = {},
 ):
     """
     Send message to model
     """

+    model_kwargs = {}
+    json_support = get_openai_api_json_support(model, api_base_url)
+    if response_schema and json_support == JsonSupport.SCHEMA:
+        model_kwargs["response_format"] = response_schema
+    elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
+        model_kwargs["response_format"] = {"type": response_type}
+
     # Get Response from GPT
     return completion_with_backoff(
         messages=messages,
         model_name=model,
         openai_api_key=api_key,
         temperature=temperature,
         api_base_url=api_base_url,
-        model_kwargs={"response_format": {"type": response_type}},
+        model_kwargs=model_kwargs,
         tracer=tracer,
     )
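
The selection logic added to `send_message_to_model` can be read in isolation as below; this is a sketch with local stand-ins for `JsonSupport` and an example schema, not the project modules, so it runs on its own (assuming Pydantic v2; `pick_response_format` is a hypothetical helper name).

```python
from enum import Enum
from typing import List

from pydantic import BaseModel


class JsonSupport(int, Enum):
    NONE = 0
    OBJECT = 1
    SCHEMA = 2


class OnlineQueries(BaseModel):
    queries: List[str]


def pick_response_format(response_type: str, response_schema, json_support: JsonSupport) -> dict:
    """Hypothetical helper restating the model_kwargs branch added above."""
    model_kwargs = {}
    if response_schema and json_support == JsonSupport.SCHEMA:
        # Schema-capable APIs get the Pydantic model itself as the response format.
        model_kwargs["response_format"] = response_schema
    elif response_type == "json_object" and json_support == JsonSupport.OBJECT:
        # APIs that only support JSON mode fall back to the generic json_object type.
        model_kwargs["response_format"] = {"type": response_type}
    return model_kwargs


print(pick_response_format("json_object", OnlineQueries, JsonSupport.SCHEMA))
print(pick_response_format("json_object", OnlineQueries, JsonSupport.OBJECT))
print(pick_response_format("json_object", OnlineQueries, JsonSupport.NONE))  # {} -> plain text request
```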

src/khoj/processor/conversation/openai/utils.py (+31 -41)

@@ -2,6 +2,7 @@
 import os
 from threading import Thread
 from typing import Dict, List
+from urllib.parse import urlparse

 import openai
 from openai.types.chat.chat_completion import ChatCompletion
@@ -16,6 +17,7 @@
 )

 from khoj.processor.conversation.utils import (
+    JsonSupport,
     ThreadedGenerator,
     commit_conversation_trace,
 )
@@ -60,45 +62,29 @@ def completion_with_backoff(

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

-    # Update request parameters for compatability with o1 model series
-    # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    stream = True
-    model_kwargs["stream_options"] = {"include_usage": True}
-    if model_name == "o1":
-        temperature = 1
-        stream = False
-        model_kwargs.pop("stream_options", None)
-    elif model_name.startswith("o1"):
-        temperature = 1
-        model_kwargs.pop("response_format", None)
-    elif model_name.startswith("o3-"):
+    # Tune reasoning models arguments
+    if model_name.startswith("o1") or model_name.startswith("o3"):
         temperature = 1
+        model_kwargs["reasoning_effort"] = "medium"

+    model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

-    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
+    aggregated_response = ""
+    with client.beta.chat.completions.stream(
         messages=formatted_messages,  # type: ignore
-        model=model_name,  # type: ignore
-        stream=stream,
+        model=model_name,
         temperature=temperature,
         timeout=20,
         **model_kwargs,
-    )
-
-    aggregated_response = ""
-    if not stream:
-        chunk = chat
-        aggregated_response = chunk.choices[0].message.content
-    else:
+    ) as chat:
         for chunk in chat:
-            if len(chunk.choices) == 0:
+            if chunk.type == "error":
+                logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
                 continue
-            delta_chunk = chunk.choices[0].delta  # type: ignore
-            if isinstance(delta_chunk, str):
-                aggregated_response += delta_chunk
-            elif delta_chunk.content:
-                aggregated_response += delta_chunk.content
+            elif chunk.type == "content.delta":
+                aggregated_response += chunk.delta

     # Calculate cost of chat
     input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
@@ -172,20 +158,13 @@ def llm_thread(

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]

-        # Update request parameters for compatability with o1 model series
-        # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-        stream = True
-        model_kwargs["stream_options"] = {"include_usage": True}
-        if model_name == "o1":
+        # Tune reasoning models arguments
+        if model_name.startswith("o1"):
             temperature = 1
-            stream = False
-            model_kwargs.pop("stream_options", None)
-        elif model_name.startswith("o1-"):
+        elif model_name.startswith("o3"):
             temperature = 1
-            model_kwargs.pop("response_format", None)
-        elif model_name.startswith("o3-"):
-            temperature = 1
-            # Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices
+            # Get the first system message and add the string `Formatting re-enabled` to it.
+            # See https://platform.openai.com/docs/guides/reasoning-best-practices
            if len(formatted_messages) > 0:
                 system_messages = [
                     (i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
@@ -195,7 +174,6 @@ def llm_thread(
                 formatted_messages[first_system_message_index][
                     "content"
                 ] = f"{first_system_message} Formatting re-enabled"
-
         elif model_name.startswith("deepseek-reasoner"):
             # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
             # The first message should always be a user message (except system message).
@@ -210,6 +188,8 @@ def llm_thread(

             formatted_messages = updated_messages

+        stream = True
+        model_kwargs["stream_options"] = {"include_usage": True}
         if os.getenv("KHOJ_LLM_SEED"):
             model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))

@@ -258,3 +238,13 @@ def llm_thread(
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
         g.close()
+
+
+def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
+    if model_name.startswith("deepseek-reasoner"):
+        return JsonSupport.NONE
+    if api_base_url:
+        host = urlparse(api_base_url).hostname
+        if host and host.endswith(".ai.azure.com"):
+            return JsonSupport.OBJECT
+    return JsonSupport.SCHEMA
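
The new `get_openai_api_json_support` heuristic is easiest to see with concrete inputs; the snippet below copies it with a local `JsonSupport` so it runs standalone, and the Azure-style URL is only an illustrative example, not one taken from the commit.

```python
from enum import Enum
from urllib.parse import urlparse


class JsonSupport(int, Enum):
    NONE = 0
    OBJECT = 1
    SCHEMA = 2


def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
    # deepseek-reasoner cannot be constrained to JSON at all.
    if model_name.startswith("deepseek-reasoner"):
        return JsonSupport.NONE
    # Azure AI endpoints only get generic JSON object mode.
    if api_base_url:
        host = urlparse(api_base_url).hostname
        if host and host.endswith(".ai.azure.com"):
            return JsonSupport.OBJECT
    # Everything else is assumed to accept a full response schema.
    return JsonSupport.SCHEMA


assert get_openai_api_json_support("deepseek-reasoner") == JsonSupport.NONE
assert get_openai_api_json_support("gpt-4o", "https://example.services.ai.azure.com/models") == JsonSupport.OBJECT
assert get_openai_api_json_support("gpt-4o") == JsonSupport.SCHEMA
```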

src/khoj/processor/conversation/utils.py (+6)

@@ -878,3 +878,9 @@ def safe_serialize(content: Any) -> str:
         return str(content)

     return "\n".join([f"{json.dumps(safe_serialize(message.content))[:max_length]}..." for message in messages])
+
+
+class JsonSupport(int, Enum):
+    NONE = 0
+    OBJECT = 1
+    SCHEMA = 2
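
A short sketch of why the enum subclasses `int`: presumably so support levels can be ordered and compared by capability. That reading is an assumption, not something the commit states, and `can_emit_json` is a made-up helper.

```python
from enum import Enum


class JsonSupport(int, Enum):
    NONE = 0
    OBJECT = 1
    SCHEMA = 2


def can_emit_json(level: JsonSupport) -> bool:
    # Assumed interpretation: any level at or above OBJECT can at least return valid JSON.
    return level >= JsonSupport.OBJECT


print(can_emit_json(JsonSupport.SCHEMA))  # True
print(can_emit_json(JsonSupport.NONE))    # False
```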

src/khoj/routers/helpers.py (+10)

@@ -540,11 +540,15 @@ async def generate_online_subqueries(

     agent_chat_model = agent.chat_model if agent else None

+    class OnlineQueries(BaseModel):
+        queries: List[str]
+
     with timer("Chat actor: Generate online search subqueries", logger):
         response = await send_message_to_model_wrapper(
             online_queries_prompt,
             query_images=query_images,
             response_type="json_object",
+            response_schema=OnlineQueries,
             user=user,
             query_files=query_files,
             agent_chat_model=agent_chat_model,
@@ -1129,6 +1133,7 @@ async def send_message_to_model_wrapper(
     query: str,
     system_message: str = "",
     response_type: str = "text",
+    response_schema: BaseModel = None,
     deepthought: bool = False,
     user: KhojUser = None,
     query_images: List[str] = None,
@@ -1209,6 +1214,7 @@ async def send_message_to_model_wrapper(
             api_key=api_key,
             model=chat_model_name,
             response_type=response_type,
+            response_schema=response_schema,
             api_base_url=api_base_url,
             tracer=tracer,
         )
@@ -1255,6 +1261,7 @@ async def send_message_to_model_wrapper(
             api_key=api_key,
             model=chat_model_name,
             response_type=response_type,
+            response_schema=response_schema,
             tracer=tracer,
         )
     else:
@@ -1265,6 +1272,7 @@ def send_message_to_model_wrapper_sync(
     message: str,
     system_message: str = "",
     response_type: str = "text",
+    response_schema: BaseModel = None,
     user: KhojUser = None,
     query_images: List[str] = None,
     query_files: str = "",
@@ -1326,6 +1334,7 @@ def send_message_to_model_wrapper_sync(
             api_base_url=api_base_url,
             model=chat_model_name,
             response_type=response_type,
+            response_schema=response_schema,
             tracer=tracer,
         )

@@ -1370,6 +1379,7 @@ def send_message_to_model_wrapper_sync(
             api_key=api_key,
             model=chat_model_name,
             response_type=response_type,
+            response_schema=response_schema,
             tracer=tracer,
         )
     else:
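
The `OnlineQueries` schema defined above is intentionally tiny. The sketch below shows the reply shape it enforces and how a caller could consume it, assuming Pydantic v2 and a made-up model reply; the wrappers simply thread the class through as `response_schema`.

```python
from typing import List

from pydantic import BaseModel


class OnlineQueries(BaseModel):
    queries: List[str]


# Made-up reply; with schema enforcement it must be a JSON object with a list of strings under "queries".
reply = '{"queries": ["best note taking apps 2025", "khoj personal ai assistant"]}'
subqueries = OnlineQueries.model_validate_json(reply).queries
print(subqueries)
```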

src/khoj/routers/research.py (+42 -2)

@@ -1,10 +1,12 @@
 import logging
 import os
 from datetime import datetime
-from typing import Callable, Dict, List, Optional
+from enum import Enum
+from typing import Callable, Dict, List, Optional, Type

 import yaml
 from fastapi import Request
+from pydantic import BaseModel, Field

 from khoj.database.adapters import EntryAdapters
 from khoj.database.models import Agent, KhojUser
@@ -36,6 +38,40 @@
 logger = logging.getLogger(__name__)


+class PlanningResponse(BaseModel):
+    """
+    Schema for the response from planning agent when deciding the next tool to pick.
+    The tool field is dynamically validated based on available tools.
+    """
+
+    scratchpad: str = Field(..., description="Reasoning about which tool to use next")
+    query: str = Field(..., description="Detailed query for the selected tool")
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @classmethod
+    def create_model_with_enum(cls: Type["PlanningResponse"], tool_options: dict) -> Type["PlanningResponse"]:
+        """
+        Factory method that creates a customized PlanningResponse model
+        with a properly typed tool field based on available tools.
+
+        Args:
+            tool_options: Dictionary mapping tool names to values
+
+        Returns:
+            A customized PlanningResponse class
+        """
+        # Create dynamic enum from tool options
+        tool_enum = Enum("ToolEnum", tool_options)  # type: ignore
+
+        # Create and return a customized response model with the enum
+        class PlanningResponseWithTool(PlanningResponse):
+            tool: tool_enum = Field(..., description="Name of the tool to use")
+
+        return PlanningResponseWithTool
+
+
 async def apick_next_tool(
     query: str,
     conversation_history: dict,
@@ -61,10 +97,13 @@ async def apick_next_tool(
         # Skip showing Notes tool as an option if user has no entries
         if tool == ConversationCommand.Notes and not user_has_entries:
             continue
-        tool_options[tool.value] = description
         if len(agent_tools) == 0 or tool.value in agent_tools:
+            tool_options[tool.name] = tool.value
             tool_options_str += f'- "{tool.value}": "{description}"\n'

+    # Create planning response model with dynamically populated tool enum class
+    planning_response_model = PlanningResponse.create_model_with_enum(tool_options)
+
     # Construct chat history with user and iteration history with researcher agent for context
     chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
     previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
@@ -96,6 +135,7 @@ async def apick_next_tool(
         query=query,
         context=function_planning_prompt,
         response_type="json_object",
+        response_schema=planning_response_model,
         deepthought=True,
         user=user,
         query_images=query_images,
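
A sketch of how `PlanningResponse.create_model_with_enum` behaves, with a made-up `tool_options` mapping (the real one is built from ConversationCommand values in `apick_next_tool`) and Pydantic v2 assumed; the replies are invented for illustration.

```python
from enum import Enum
from typing import Type

from pydantic import BaseModel, Field


class PlanningResponse(BaseModel):
    scratchpad: str = Field(..., description="Reasoning about which tool to use next")
    query: str = Field(..., description="Detailed query for the selected tool")

    @classmethod
    def create_model_with_enum(cls, tool_options: dict) -> Type["PlanningResponse"]:
        # Build an enum whose members are exactly the tools available to this agent.
        tool_enum = Enum("ToolEnum", tool_options)  # type: ignore

        class PlanningResponseWithTool(PlanningResponse):
            tool: tool_enum = Field(..., description="Name of the tool to use")

        return PlanningResponseWithTool


# Made-up tool options; the planner can now only answer with one of these values.
tool_options = {"Notes": "notes", "Online": "online"}
planning_response_model = PlanningResponse.create_model_with_enum(tool_options)

reply = '{"scratchpad": "Recent info needed, so search the web.", "query": "latest khoj release notes", "tool": "online"}'
plan = planning_response_model.model_validate_json(reply)
print(plan.tool.value, "->", plan.query)

# A tool outside the enum is rejected by validation.
bad_reply = '{"scratchpad": "Try the browser.", "query": "khoj docs", "tool": "browser"}'
try:
    planning_response_model.model_validate_json(bad_reply)
except Exception as err:
    print("rejected:", type(err).__name__)
```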
