Skip to content

Commit 33665de

Browse files
committed
Dynamically set default agent chat model to server > user > first chat model
Previously the chat model associated with the default agent was always the first chat model populated on the server. This doesn't match behavior of the rest of the system, where the server chat settings is preferred over the user chat settings over the first chat model. This change brings the default agent's chat model in line with the preference order used in the reset of the system.
1 parent 1eb0920 commit 33665de

File tree

4 files changed

+45
-13
lines changed

4 files changed

+45
-13
lines changed

src/khoj/database/adapters/__init__.py

+24
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,30 @@ def create_default_agent(user: KhojUser):
798798
async def aget_default_agent():
799799
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
800800

801+
@staticmethod
802+
def get_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
803+
"""
804+
Gets the appropriate chat model for an agent.
805+
For the default agent, it dynamically determines the model based on user/server settings.
806+
For other agents, it returns their statically assigned chat model.
807+
Requires the user context to determine the correct default model.
808+
"""
809+
if agent.slug == AgentAdapters.DEFAULT_AGENT_SLUG:
810+
# Dynamically get the default model based on context
811+
return ConversationAdapters.get_default_chat_model(user)
812+
elif agent.chat_model:
813+
# Return the model assigned directly to the specific agent
814+
# Ensure the related object is loaded if necessary (prefetching is recommended)
815+
return agent.chat_model
816+
else:
817+
# Fallback if agent has no unset chat_model. For example if chat_model associated with agent was deleted.
818+
logger.warning(f"Agent {agent.slug} has no chat_model or agent is None, returning overall default.")
819+
return ConversationAdapters.get_default_chat_model(user)
820+
821+
@staticmethod
822+
async def aget_agent_chat_model(agent: Agent, user: Optional[KhojUser]) -> Optional[ChatModel]:
823+
return await sync_to_async(AgentAdapters.get_agent_chat_model)(agent, user)
824+
801825
@staticmethod
802826
@arequire_valid_user
803827
async def aupdate_agent(

src/khoj/routers/api_agents.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ async def all_agents(
6262
for agent in agents:
6363
files = agent.fileobject_set.all()
6464
file_names = [file.file_name for file in files]
65+
agent_chat_model = await AgentAdapters.aget_agent_chat_model(default_agent, user)
6566
agent_packet = {
6667
"slug": agent.slug,
6768
"name": agent.name,
@@ -71,7 +72,7 @@ async def all_agents(
7172
"color": agent.style_color,
7273
"icon": agent.style_icon,
7374
"privacy_level": agent.privacy_level,
74-
"chat_model": agent.chat_model.name,
75+
"chat_model": agent_chat_model.name,
7576
"files": file_names,
7677
"input_tools": agent.input_tools,
7778
"output_modes": agent.output_modes,
@@ -125,6 +126,7 @@ async def get_agent_by_conversation(
125126
agent = await AgentAdapters.aget_default_agent()
126127

127128
has_files = agent.fileobject_set.exists()
129+
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
128130

129131
agents_packet = {
130132
"slug": agent.slug,
@@ -194,6 +196,8 @@ async def get_agent(
194196

195197
files = agent.fileobject_set.all()
196198
file_names = [file.file_name for file in files]
199+
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
200+
197201
agents_packet = {
198202
"slug": agent.slug,
199203
"name": agent.name,
@@ -265,6 +269,7 @@ async def update_hidden_agent(
265269
output_modes=body.output_modes,
266270
existing_agent=selected_agent,
267271
)
272+
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
268273

269274
agents_packet = {
270275
"slug": agent.slug,
@@ -320,6 +325,7 @@ async def create_hidden_agent(
320325
output_modes=body.output_modes,
321326
existing_agent=None,
322327
)
328+
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
323329

324330
conversation.agent = agent
325331
await conversation.asave()
@@ -374,6 +380,7 @@ async def create_agent(
374380
body.slug,
375381
body.is_hidden,
376382
)
383+
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
377384

378385
agents_packet = {
379386
"slug": agent.slug,
@@ -439,6 +446,7 @@ async def update_agent(
439446
body.output_modes,
440447
body.slug,
441448
)
449+
agent.chat_model = await AgentAdapters.aget_agent_chat_model(agent, user)
442450

443451
agents_packet = {
444452
"slug": agent.slug,

src/khoj/routers/helpers.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ async def aget_data_sources_and_output_format(
403403
personality_context=personality_context,
404404
)
405405

406-
agent_chat_model = agent.chat_model if agent else None
406+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
407407

408408
class PickTools(BaseModel):
409409
source: List[str] = Field(..., min_items=1)
@@ -492,7 +492,7 @@ async def infer_webpage_urls(
492492
personality_context=personality_context,
493493
)
494494

495-
agent_chat_model = agent.chat_model if agent else None
495+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
496496

497497
class WebpageUrls(BaseModel):
498498
links: List[str] = Field(..., min_items=1, max_items=max_webpages)
@@ -557,7 +557,7 @@ async def generate_online_subqueries(
557557
personality_context=personality_context,
558558
)
559559

560-
agent_chat_model = agent.chat_model if agent else None
560+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
561561

562562
class OnlineQueries(BaseModel):
563563
queries: List[str] = Field(..., min_items=1, max_items=max_queries)
@@ -666,7 +666,7 @@ async def extract_relevant_info(
666666
personality_context=personality_context,
667667
)
668668

669-
agent_chat_model = agent.chat_model if agent else None
669+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
670670

671671
response = await send_message_to_model_wrapper(
672672
extract_relevant_information,
@@ -707,7 +707,7 @@ async def extract_relevant_summary(
707707
personality_context=personality_context,
708708
)
709709

710-
agent_chat_model = agent.chat_model if agent else None
710+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
711711

712712
with timer("Chat actor: Extract relevant information from data", logger):
713713
response = await send_message_to_model_wrapper(
@@ -878,7 +878,7 @@ async def generate_better_diagram_description(
878878
personality_context=personality_context,
879879
)
880880

881-
agent_chat_model = agent.chat_model if agent else None
881+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
882882

883883
with timer("Chat actor: Generate better diagram description", logger):
884884
response = await send_message_to_model_wrapper(
@@ -911,7 +911,7 @@ async def generate_excalidraw_diagram_from_description(
911911
query=q,
912912
)
913913

914-
agent_chat_model = agent.chat_model if agent else None
914+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
915915

916916
with timer("Chat actor: Generate excalidraw diagram", logger):
917917
raw_response = await send_message_to_model_wrapper(
@@ -1029,7 +1029,7 @@ async def generate_better_mermaidjs_diagram_description(
10291029
personality_context=personality_context,
10301030
)
10311031

1032-
agent_chat_model = agent.chat_model if agent else None
1032+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
10331033

10341034
with timer("Chat actor: Generate better Mermaid.js diagram description", logger):
10351035
response = await send_message_to_model_wrapper(
@@ -1062,7 +1062,7 @@ async def generate_mermaidjs_diagram_from_description(
10621062
query=q,
10631063
)
10641064

1065-
agent_chat_model = agent.chat_model if agent else None
1065+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
10661066

10671067
with timer("Chat actor: Generate Mermaid.js diagram", logger):
10681068
raw_response = await send_message_to_model_wrapper(
@@ -1132,7 +1132,7 @@ async def generate_better_image_prompt(
11321132
personality_context=personality_context,
11331133
)
11341134

1135-
agent_chat_model = agent.chat_model if agent else None
1135+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
11361136

11371137
with timer("Chat actor: Generate contextual image prompt", logger):
11381138
response = await send_message_to_model_wrapper(

src/khoj/routers/research.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from fastapi import Request
99
from pydantic import BaseModel, Field
1010

11-
from khoj.database.adapters import EntryAdapters
11+
from khoj.database.adapters import AgentAdapters, EntryAdapters
1212
from khoj.database.models import Agent, KhojUser
1313
from khoj.processor.conversation import prompts
1414
from khoj.processor.conversation.utils import (
@@ -116,7 +116,7 @@ async def apick_next_tool(
116116

117117
today = datetime.today()
118118
location_data = f"{location}" if location else "Unknown"
119-
agent_chat_model = agent.chat_model if agent else None
119+
agent_chat_model = AgentAdapters.get_agent_chat_model(agent, user) if agent else None
120120
personality_context = (
121121
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
122122
)

0 commit comments

Comments
 (0)