17
17
commit_conversation_trace ,
18
18
get_image_from_url ,
19
19
)
20
- from khoj .utils import state
21
20
from khoj .utils .helpers import (
22
21
get_chat_usage_metrics ,
23
- in_debug_mode ,
24
22
is_none_or_empty ,
25
23
is_promptrace_enabled ,
26
24
)
30
28
anthropic_clients : Dict [str , anthropic .Anthropic ] = {}
31
29
32
30
33
- DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
31
+ DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
32
+ MAX_REASONING_TOKENS_ANTHROPIC = 12000
34
33
35
34
36
35
@retry (
42
41
def anthropic_completion_with_backoff (
43
42
messages ,
44
43
system_prompt ,
45
- model_name ,
44
+ model_name : str ,
46
45
temperature = 0 ,
47
46
api_key = None ,
48
47
model_kwargs = None ,
49
48
max_tokens = None ,
50
49
response_type = "text" ,
50
+ deepthought = False ,
51
51
tracer = {},
52
52
) -> str :
53
53
if api_key not in anthropic_clients :
@@ -57,18 +57,24 @@ def anthropic_completion_with_backoff(
57
57
client = anthropic_clients [api_key ]
58
58
59
59
formatted_messages = [{"role" : message .role , "content" : message .content } for message in messages ]
60
- if response_type == "json_object" :
61
- # Prefill model response with '{' to make it output a valid JSON object
60
+ aggregated_response = ""
61
+ if response_type == "json_object" and not deepthought :
62
+ # Prefill model response with '{' to make it output a valid JSON object. Not supported with extended thinking.
62
63
formatted_messages += [{"role" : "assistant" , "content" : "{" }]
63
-
64
- aggregated_response = "{" if response_type == "json_object" else ""
65
- max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
64
+ aggregated_response += "{"
66
65
67
66
final_message = None
68
67
model_kwargs = model_kwargs or dict ()
69
68
if system_prompt :
70
69
model_kwargs ["system" ] = system_prompt
71
70
71
+ max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
72
+ if deepthought and model_name .startswith ("claude-3-7" ):
73
+ model_kwargs ["thinking" ] = {"type" : "enabled" , "budget_tokens" : MAX_REASONING_TOKENS_ANTHROPIC }
74
+ max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
75
+ # Temperature control not supported when using extended thinking
76
+ temperature = 1.0
77
+
72
78
with client .messages .stream (
73
79
messages = formatted_messages ,
74
80
model = model_name , # type: ignore
@@ -111,20 +117,41 @@ def anthropic_chat_completion_with_backoff(
111
117
system_prompt ,
112
118
max_prompt_size = None ,
113
119
completion_func = None ,
120
+ deepthought = False ,
114
121
model_kwargs = None ,
115
122
tracer = {},
116
123
):
117
124
g = ThreadedGenerator (compiled_references , online_results , completion_func = completion_func )
118
125
t = Thread (
119
126
target = anthropic_llm_thread ,
120
- args = (g , messages , system_prompt , model_name , temperature , api_key , max_prompt_size , model_kwargs , tracer ),
127
+ args = (
128
+ g ,
129
+ messages ,
130
+ system_prompt ,
131
+ model_name ,
132
+ temperature ,
133
+ api_key ,
134
+ max_prompt_size ,
135
+ deepthought ,
136
+ model_kwargs ,
137
+ tracer ,
138
+ ),
121
139
)
122
140
t .start ()
123
141
return g
124
142
125
143
126
144
def anthropic_llm_thread (
127
- g , messages , system_prompt , model_name , temperature , api_key , max_prompt_size = None , model_kwargs = None , tracer = {}
145
+ g ,
146
+ messages ,
147
+ system_prompt ,
148
+ model_name ,
149
+ temperature ,
150
+ api_key ,
151
+ max_prompt_size = None ,
152
+ deepthought = False ,
153
+ model_kwargs = None ,
154
+ tracer = {},
128
155
):
129
156
try :
130
157
if api_key not in anthropic_clients :
@@ -133,6 +160,14 @@ def anthropic_llm_thread(
133
160
else :
134
161
client : anthropic .Anthropic = anthropic_clients [api_key ]
135
162
163
+ model_kwargs = model_kwargs or dict ()
164
+ max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
165
+ if deepthought and model_name .startswith ("claude-3-7" ):
166
+ model_kwargs ["thinking" ] = {"type" : "enabled" , "budget_tokens" : MAX_REASONING_TOKENS_ANTHROPIC }
167
+ max_tokens += MAX_REASONING_TOKENS_ANTHROPIC
168
+ # Temperature control not supported when using extended thinking
169
+ temperature = 1.0
170
+
136
171
formatted_messages : List [anthropic .types .MessageParam ] = [
137
172
anthropic .types .MessageParam (role = message .role , content = message .content ) for message in messages
138
173
]
@@ -145,8 +180,8 @@ def anthropic_llm_thread(
145
180
temperature = temperature ,
146
181
system = system_prompt ,
147
182
timeout = 20 ,
148
- max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC ,
149
- ** ( model_kwargs or dict ()) ,
183
+ max_tokens = max_tokens ,
184
+ ** model_kwargs ,
150
185
) as stream :
151
186
for text in stream .text_stream :
152
187
aggregated_response += text
0 commit comments