@@ -2,6 +2,7 @@
 import os
 from threading import Thread
 from typing import Dict, List
+from urllib.parse import urlparse
 
 import openai
 from openai.types.chat.chat_completion import ChatCompletion
@@ -16,6 +17,7 @@
 )
 
 from khoj.processor.conversation.utils import (
+    JsonSupport,
     ThreadedGenerator,
     commit_conversation_trace,
 )
@@ -60,45 +62,29 @@ def completion_with_backoff(
 
     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
 
-    # Update request parameters for compatability with o1 model series
-    # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    stream = True
-    model_kwargs["stream_options"] = {"include_usage": True}
-    if model_name == "o1":
-        temperature = 1
-        stream = False
-        model_kwargs.pop("stream_options", None)
-    elif model_name.startswith("o1"):
-        temperature = 1
-        model_kwargs.pop("response_format", None)
-    elif model_name.startswith("o3-"):
+    # Tune reasoning model arguments
+    if model_name.startswith("o1") or model_name.startswith("o3"):
         temperature = 1
+        model_kwargs["reasoning_effort"] = "medium"
 
+    model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
 
-    chat: ChatCompletion | openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
+    aggregated_response = ""
+    with client.beta.chat.completions.stream(
         messages=formatted_messages,  # type: ignore
-        model=model_name,  # type: ignore
-        stream=stream,
+        model=model_name,
         temperature=temperature,
         timeout=20,
         **model_kwargs,
-    )
-
-    aggregated_response = ""
-    if not stream:
-        chunk = chat
-        aggregated_response = chunk.choices[0].message.content
-    else:
+    ) as chat:
         for chunk in chat:
-            if len(chunk.choices) == 0:
+            if chunk.type == "error":
+                logger.error(f"Openai api response error: {chunk.error}", exc_info=True)
                 continue
-            delta_chunk = chunk.choices[0].delta  # type: ignore
-            if isinstance(delta_chunk, str):
-                aggregated_response += delta_chunk
-            elif delta_chunk.content:
-                aggregated_response += delta_chunk.content
+            elif chunk.type == "content.delta":
+                aggregated_response += chunk.delta
 
     # Calculate cost of chat
     input_tokens = chunk.usage.prompt_tokens if hasattr(chunk, "usage") and chunk.usage else 0
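
For reference, here is a minimal standalone sketch of the event-based streaming pattern the new code adopts; the client construction and model name are placeholders, not Khoj's actual setup:

import openai

client = openai.OpenAI()  # placeholder; Khoj builds its client elsewhere with per-model config

aggregated_response = ""
# The beta streaming helper yields typed events rather than raw chunks.
with client.beta.chat.completions.stream(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello"}],
) as stream:
    for event in stream:
        if event.type == "content.delta":
            aggregated_response += event.delta  # incremental text
print(aggregated_response)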
@@ -172,20 +158,13 @@ def llm_thread(
 
     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
 
-    # Update request parameters for compatability with o1 model series
-    # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
-    stream = True
-    model_kwargs["stream_options"] = {"include_usage": True}
-    if model_name == "o1":
+    # Tune reasoning model arguments
+    if model_name.startswith("o1"):
         temperature = 1
-        stream = False
-        model_kwargs.pop("stream_options", None)
-    elif model_name.startswith("o1-"):
+    elif model_name.startswith("o3"):
         temperature = 1
-        model_kwargs.pop("response_format", None)
-    elif model_name.startswith("o3-"):
-        temperature = 1
-        # Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices
+        # Get the first system message and add the string `Formatting re-enabled` to it.
+        # See https://platform.openai.com/docs/guides/reasoning-best-practices
         if len(formatted_messages) > 0:
             system_messages = [
                 (i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
@@ -195,7 +174,6 @@ def llm_thread(
             formatted_messages[first_system_message_index][
                 "content"
             ] = f"{first_system_message} Formatting re-enabled"
-
     elif model_name.startswith("deepseek-reasoner"):
         # Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
         # The first message should always be a user message (except system message).
@@ -210,6 +188,8 @@ def llm_thread(
 
         formatted_messages = updated_messages
 
+    stream = True
+    model_kwargs["stream_options"] = {"include_usage": True}
     if os.getenv("KHOJ_LLM_SEED"):
         model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
 
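
The merge logic referenced by the deepseek-reasoner comments above falls outside these hunks. A rough sketch of what such a merge could look like, with the helper name and message shape being illustrative rather than Khoj's actual code:

from typing import Dict, List

def merge_consecutive_roles(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # deepseek-reasoner rejects two successive messages from the same role,
    # so fold back-to-back same-role messages into a single message.
    merged: List[Dict[str, str]] = []
    for message in messages:
        if merged and merged[-1]["role"] == message["role"]:
            merged[-1]["content"] += "\n" + message["content"]
        else:
            merged.append(dict(message))
    return merged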
@@ -258,3 +238,13 @@ def llm_thread(
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
         g.close()
+
+
+def get_openai_api_json_support(model_name: str, api_base_url: str = None) -> JsonSupport:
+    if model_name.startswith("deepseek-reasoner"):
+        return JsonSupport.NONE
+    if api_base_url:
+        host = urlparse(api_base_url).hostname
+        if host and host.endswith(".ai.azure.com"):
+            return JsonSupport.OBJECT
+    return JsonSupport.SCHEMA
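
A possible call site for the new helper, showing how the JsonSupport levels might drive request construction; the response_format payloads and the output_schema variable are illustrative assumptions, not code from this PR:

json_support = get_openai_api_json_support(model_name, api_base_url)
if json_support == JsonSupport.SCHEMA:
    # Full structured outputs: pass a JSON schema.
    model_kwargs["response_format"] = {
        "type": "json_schema",
        "json_schema": {"name": "output", "schema": output_schema},
    }
elif json_support == JsonSupport.OBJECT:
    # e.g. Azure AI endpoints: basic JSON object mode only.
    model_kwargs["response_format"] = {"type": "json_object"}
# JsonSupport.NONE (e.g. deepseek-reasoner): rely on prompting alone.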