14
14
PartialQuestionAnswer ,
15
15
PartialQuestions ,
16
16
QuestionAnswer ,
17
- TokenUsage ,
18
17
TokenUsageAggregate ,
19
18
TokenUsageByModel ,
20
19
)
21
20
from codegate .db .connection import alert_queue
22
- from codegate .db .models import Alert , GetPromptWithOutputsRow
21
+ from codegate .db .models import Alert , GetPromptWithOutputsRow , TokenUsage
23
22
24
23
logger = structlog .get_logger ("codegate" )
25
24
@@ -103,55 +102,54 @@ async def parse_request(request_str: str) -> Tuple[Optional[List[str]], str]:
103
102
return messages , model
104
103
105
104
106
- async def parse_output (output_str : str ) -> Tuple [ Optional [str ], TokenUsage ]:
105
+ async def parse_output (output_str : str ) -> Optional [str ]:
107
106
"""
108
107
Parse the output string from the pipeline and return the message.
109
108
"""
110
109
try :
111
110
if output_str is None :
112
- return None , TokenUsage ()
111
+ return None
113
112
114
113
output = json .loads (output_str )
115
114
except Exception as e :
116
115
logger .warning (f"Error parsing output: { output_str } . { e } " )
117
- return None , TokenUsage ()
116
+ return None
118
117
119
- def _parse_single_output (single_output : dict ) -> Tuple [ str , TokenUsage ] :
118
+ def _parse_single_output (single_output : dict ) -> str :
120
119
single_output_message = ""
121
120
for choice in single_output .get ("choices" , []):
122
121
if not isinstance (choice , dict ):
123
122
continue
124
123
content_dict = choice .get ("delta" , {}) or choice .get ("message" , {})
125
124
single_output_message += content_dict .get ("content" , "" )
126
- return single_output_message , TokenUsage . from_dict ( single_output . get ( "usage" , {}))
125
+ return single_output_message
127
126
128
127
full_output_message = ""
129
- full_token_usage = TokenUsage ()
130
128
if isinstance (output , list ):
131
129
for output_chunk in output :
132
130
output_message = ""
133
- token_usage = TokenUsage ()
134
131
if isinstance (output_chunk , dict ):
135
- output_message , token_usage = _parse_single_output (output_chunk )
132
+ output_message = _parse_single_output (output_chunk )
136
133
elif isinstance (output_chunk , str ):
137
134
try :
138
135
output_decoded = json .loads (output_chunk )
139
- output_message , token_usage = _parse_single_output (output_decoded )
136
+ output_message = _parse_single_output (output_decoded )
140
137
except Exception :
141
138
logger .error (f"Error reading chunk: { output_chunk } " )
142
139
else :
143
140
logger .warning (
144
141
f"Could not handle output: { output_chunk } " , out_type = type (output_chunk )
145
142
)
146
143
full_output_message += output_message
147
- full_token_usage += token_usage
148
144
elif isinstance (output , dict ):
149
- full_output_message , full_token_usage = _parse_single_output (output )
145
+ full_output_message = _parse_single_output (output )
150
146
151
- return full_output_message , full_token_usage
147
+ return full_output_message
152
148
153
149
154
- async def _get_question_answer (row : GetPromptWithOutputsRow ) -> Optional [PartialQuestionAnswer ]:
150
+ async def _get_partial_question_answer (
151
+ row : GetPromptWithOutputsRow ,
152
+ ) -> Optional [PartialQuestionAnswer ]:
155
153
"""
156
154
Parse a row from the get_prompt_with_outputs query and return a PartialConversation
157
155
@@ -162,7 +160,7 @@ async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[Partial
162
160
output_task = tg .create_task (parse_output (row .output ))
163
161
164
162
request_user_msgs , model = request_task .result ()
165
- output_msg_str , token_usage = output_task .result ()
163
+ output_msg_str = output_task .result ()
166
164
167
165
# If we couldn't parse the request, return None
168
166
if not request_user_msgs :
@@ -184,8 +182,13 @@ async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[Partial
184
182
else :
185
183
output_message = None
186
184
185
+ token_usage = TokenUsage .from_db (
186
+ input_cost = row .input_cost ,
187
+ input_tokens = row .input_tokens ,
188
+ output_tokens = row .output_tokens ,
189
+ output_cost = row .output_cost ,
190
+ )
187
191
# Use the model to update the token cost
188
- token_usage .update_token_cost (model )
189
192
provider = row .provider
190
193
# TODO: This should come from the database. For now, we are manually changing copilot to openai
191
194
# Change copilot provider to openai
@@ -297,7 +300,8 @@ def _get_question_answer_from_partial(
297
300
partial_question_answer : PartialQuestionAnswer ,
298
301
) -> QuestionAnswer :
299
302
"""
300
- Get a QuestionAnswer object from a PartialQuestionAnswer object.
303
+ Get a QuestionAnswer object from a PartialQuestionAnswer object. PartialQuestionAnswer
304
+ contains a list of messages as question. QuestionAnswer contains a single message as question.
301
305
"""
302
306
# Get the last user message as the question
303
307
question = ChatMessage (
@@ -315,11 +319,8 @@ async def match_conversations(
315
319
"""
316
320
Match partial conversations to form a complete conversation.
317
321
"""
318
- valid_partial_qas = [
319
- partial_qas for partial_qas in partial_question_answers if partial_qas is not None
320
- ]
321
322
grouped_partial_questions = _group_partial_messages (
322
- [partial_qs_a .partial_questions for partial_qs_a in valid_partial_qas ]
323
+ [partial_qs_a .partial_questions for partial_qs_a in partial_question_answers ]
323
324
)
324
325
325
326
# Create the conversation objects
@@ -333,7 +334,7 @@ async def match_conversations(
333
334
# Partial questions don't contain the answer, so we need to find the corresponding
334
335
# valid partial question answer
335
336
selected_partial_qa = None
336
- for partial_qa in valid_partial_qas :
337
+ for partial_qa in partial_question_answers :
337
338
if partial_question .message_id == partial_qa .partial_questions .message_id :
338
339
selected_partial_qa = partial_qa
339
340
break
@@ -367,17 +368,25 @@ async def match_conversations(
367
368
return conversations , map_q_id_to_conversation
368
369
369
370
371
+ async def _process_prompt_output_to_partial_qa (
372
+ prompts_outputs : List [GetPromptWithOutputsRow ],
373
+ ) -> List [PartialQuestionAnswer ]:
374
+ """
375
+ Process the prompts and outputs to PartialQuestionAnswer objects.
376
+ """
377
+ # Parse the prompts and outputs in parallel
378
+ async with asyncio .TaskGroup () as tg :
379
+ tasks = [tg .create_task (_get_partial_question_answer (row )) for row in prompts_outputs ]
380
+ return [task .result () for task in tasks if task .result () is not None ]
381
+
382
+
370
383
async def parse_messages_in_conversations (
371
384
prompts_outputs : List [GetPromptWithOutputsRow ],
372
385
) -> Tuple [List [Conversation ], Dict [str , Conversation ]]:
373
386
"""
374
387
Get all the messages from the database and return them as a list of conversations.
375
388
"""
376
-
377
- # Parse the prompts and outputs in parallel
378
- async with asyncio .TaskGroup () as tg :
379
- tasks = [tg .create_task (_get_question_answer (row )) for row in prompts_outputs ]
380
- partial_question_answers = [task .result () for task in tasks ]
389
+ partial_question_answers = await _process_prompt_output_to_partial_qa (prompts_outputs )
381
390
382
391
conversations , map_q_id_to_conversation = await match_conversations (partial_question_answers )
383
392
return conversations , map_q_id_to_conversation
@@ -430,3 +439,16 @@ async def parse_get_alert_conversation(
430
439
for row in alerts
431
440
]
432
441
return [task .result () for task in tasks if task .result () is not None ]
442
+
443
+
444
+ async def parse_workspace_token_usage (
445
+ prompts_outputs : List [GetPromptWithOutputsRow ],
446
+ ) -> TokenUsageAggregate :
447
+ """
448
+ Parse the token usage from the workspace.
449
+ """
450
+ partial_question_answers = await _process_prompt_output_to_partial_qa (prompts_outputs )
451
+ token_usage_agg = TokenUsageAggregate (tokens_by_model = {}, token_usage = TokenUsage ())
452
+ for p_qa in partial_question_answers :
453
+ token_usage_agg .add_model_token_usage (p_qa .model_token_usage )
454
+ return token_usage_agg
0 commit comments