
Commit 5a51e0c

Moved token recording to DB
1 parent 6b8c3b1 commit 5a51e0c

File tree

8 files changed: +300 -97 lines changed
Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+"""add token usage columns
+
+Revision ID: 0c3539f66339
+Revises: 0f9b8edc8e46
+Create Date: 2025-01-28 09:15:54.767311+00:00
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "0c3539f66339"
+down_revision: Union[str, None] = "0f9b8edc8e46"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Begin transaction
+    op.execute("BEGIN TRANSACTION;")
+
+    # We add the columns to the outputs table.
+    # Add the columns with default values to avoid issues with the existing data.
+    # The prices of the tokens may change in the future,
+    # so we need to store the cost of the tokens at the time of the request.
+    op.execute("ALTER TABLE outputs ADD COLUMN input_tokens INT DEFAULT NULL;")
+    op.execute("ALTER TABLE outputs ADD COLUMN output_tokens INT DEFAULT NULL;")
+    op.execute("ALTER TABLE outputs ADD COLUMN input_cost FLOAT DEFAULT NULL;")
+    op.execute("ALTER TABLE outputs ADD COLUMN output_cost FLOAT DEFAULT NULL;")
+
+    # Finish transaction
+    op.execute("COMMIT;")
+
+
+def downgrade() -> None:
+    # Begin transaction
+    op.execute("BEGIN TRANSACTION;")
+
+    op.execute("ALTER TABLE outputs DROP COLUMN input_tokens;")
+    op.execute("ALTER TABLE outputs DROP COLUMN output_tokens;")
+    op.execute("ALTER TABLE outputs DROP COLUMN input_cost;")
+    op.execute("ALTER TABLE outputs DROP COLUMN output_cost;")
+
+    # Finish transaction
+    op.execute("COMMIT;")

src/codegate/api/v1.py

Lines changed: 16 additions & 13 deletions

@@ -479,16 +479,19 @@ def version_check():
 )
 async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsageAggregate:
     """Get the token usage of a workspace."""
-    # TODO: This is a dummy implementation. In the future, we should have a proper
-    # implementation that fetches the token usage from the database.
-    return v1_models.TokenUsageAggregate(
-        used_tokens=50,
-        tokens_by_model=[
-            v1_models.TokenUsageByModel(
-                provider_type="openai", model="gpt-4o-mini", used_tokens=20
-            ),
-            v1_models.TokenUsageByModel(
-                provider_type="anthropic", model="claude-3-5-sonnet-20241022", used_tokens=30
-            ),
-        ],
-    )
+    try:
+        ws = await wscrud.get_workspace_by_name(workspace_name)
+    except crud.WorkspaceDoesNotExistError:
+        raise HTTPException(status_code=404, detail="Workspace does not exist")
+    except Exception:
+        logger.exception("Error while getting workspace")
+        raise HTTPException(status_code=500, detail="Internal server error")
+
+    try:
+        prompts_outputs = await dbreader.get_prompts_with_output(ws.id)
+        ws_token_usage = await v1_processing.parse_workspace_token_usage(prompts_outputs)
+        return ws_token_usage
+    except Exception:
+        logger.exception("Error while getting messages")
+        raise HTTPException(status_code=500, detail="Internal server error")
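With the hard-coded stub gone, the endpoint returns 404 for unknown workspaces and real aggregated usage otherwise. A hypothetical client call: the route path and port are not visible in this hunk, so both are assumptions based on the handler name.

import requests

BASE = "http://localhost:8989/api/v1"  # assumed CodeGate API address

resp = requests.get(f"{BASE}/workspaces/default/token-usage")  # assumed route
if resp.status_code == 404:
    print("Workspace does not exist")
else:
    resp.raise_for_status()
    # TokenUsageAggregate: per-model breakdown plus overall token_usage totals
    print(resp.json())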

src/codegate/api/v1_models.py

Lines changed: 15 additions & 47 deletions

@@ -3,15 +3,10 @@
 from typing import Any, Dict, List, Optional, Union

 import pydantic
-import requests
-from cachetools import TTLCache

 from codegate.db import models as db_models
 from codegate.pipeline.base import CodeSnippet

-# 1 day cache. Not keep all the models in the cache. Just the ones we have used recently.
-model_cost_cache = TTLCache(maxsize=2000, ttl=1 * 24 * 60 * 60)
-

 class Workspace(pydantic.BaseModel):
     name: str
@@ -118,46 +113,8 @@ class ProviderType(str, Enum):
     openai = "openai"
     anthropic = "anthropic"
     vllm = "vllm"
-
-
-class TokenUsage(pydantic.BaseModel):
-    input_tokens: int = 0
-    output_tokens: int = 0
-    input_cost: float = 0
-    output_cost: float = 0
-
-    @classmethod
-    def from_dict(cls, usage_dict: Dict) -> "TokenUsage":
-        return cls(
-            input_tokens=usage_dict.get("prompt_tokens", 0) or usage_dict.get("input_tokens", 0),
-            output_tokens=usage_dict.get("completion_tokens", 0)
-            or usage_dict.get("output_tokens", 0),
-            input_cost=0,
-            output_cost=0,
-        )
-
-    def __add__(self, other: "TokenUsage") -> "TokenUsage":
-        return TokenUsage(
-            input_tokens=self.input_tokens + other.input_tokens,
-            output_tokens=self.output_tokens + other.output_tokens,
-            input_cost=self.input_cost + other.input_cost,
-            output_cost=self.output_cost + other.output_cost,
-        )
-
-    def update_token_cost(self, model: str) -> None:
-        if not model_cost_cache:
-            model_cost = requests.get(
-                "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
-            )
-            model_cost_cache.update(model_cost.json())
-        model_cost = model_cost_cache.get(model, {})
-        input_cost_per_token = model_cost.get("input_cost_per_token", 0)
-        output_cost_per_token = model_cost.get("output_cost_per_token", 0)
-        self.input_cost = self.input_tokens * input_cost_per_token
-        self.output_cost = self.output_tokens * output_cost_per_token
-
-    def update_costs_based_on_model(self, model: str):
-        pass
+    llamacpp = "llamacpp"
+    ollama = "ollama"


@@ -167,7 +124,7 @@ class TokenUsageByModel(pydantic.BaseModel):

     provider_type: ProviderType
     model: str
-    token_usage: TokenUsage
+    token_usage: db_models.TokenUsage


@@ -177,9 +134,20 @@ class TokenUsageAggregate(pydantic.BaseModel):
     """

     tokens_by_model: Dict[str, TokenUsageByModel]
-    token_usage: TokenUsage
+    token_usage: db_models.TokenUsage

     def add_model_token_usage(self, model_token_usage: TokenUsageByModel) -> None:
+        # Copilot doesn't have a model name and we cannot obtain the tokens used. Skip it.
+        if model_token_usage.model == "":
+            return
+
+        # Skip if the model has not used any tokens.
+        if (
+            model_token_usage.token_usage.input_tokens == 0
+            and model_token_usage.token_usage.output_tokens == 0
+        ):
+            return
+
         if model_token_usage.model in self.tokens_by_model:
             self.tokens_by_model[
                 model_token_usage.model
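These hunks retype token_usage as db_models.TokenUsage, and v1_processing.py below calls TokenUsage.from_db, but src/codegate/db/models.py (also part of this commit) is not rendered on this page. A minimal sketch of the shape the rendered hunks appear to rely on; the field names come from the diffs above, while from_db and __add__ are inferred and may differ from the real definition:

from typing import Optional

import pydantic


class TokenUsage(pydantic.BaseModel):
    input_tokens: int = 0
    output_tokens: int = 0
    input_cost: float = 0
    output_cost: float = 0

    @classmethod
    def from_db(
        cls,
        input_tokens: Optional[int],
        output_tokens: Optional[int],
        input_cost: Optional[float],
        output_cost: Optional[float],
    ) -> "TokenUsage":
        # The new columns are nullable, so coalesce NULLs to zero.
        return cls(
            input_tokens=input_tokens or 0,
            output_tokens=output_tokens or 0,
            input_cost=input_cost or 0,
            output_cost=output_cost or 0,
        )

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        # Speculative: likely used when summing per-row usage into the aggregate.
        return TokenUsage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            input_cost=self.input_cost + other.input_cost,
            output_cost=self.output_cost + other.output_cost,
        )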

src/codegate/api/v1_processing.py

Lines changed: 50 additions & 28 deletions

@@ -14,12 +14,11 @@
     PartialQuestionAnswer,
     PartialQuestions,
     QuestionAnswer,
-    TokenUsage,
     TokenUsageAggregate,
     TokenUsageByModel,
 )
 from codegate.db.connection import alert_queue
-from codegate.db.models import Alert, GetPromptWithOutputsRow
+from codegate.db.models import Alert, GetPromptWithOutputsRow, TokenUsage

 logger = structlog.get_logger("codegate")

@@ -103,55 +102,54 @@ async def parse_request(request_str: str) -> Tuple[Optional[List[str]], str]:
     return messages, model


-async def parse_output(output_str: str) -> Tuple[Optional[str], TokenUsage]:
+async def parse_output(output_str: str) -> Optional[str]:
     """
     Parse the output string from the pipeline and return the message.
     """
     try:
         if output_str is None:
-            return None, TokenUsage()
+            return None

         output = json.loads(output_str)
     except Exception as e:
         logger.warning(f"Error parsing output: {output_str}. {e}")
-        return None, TokenUsage()
+        return None

-    def _parse_single_output(single_output: dict) -> Tuple[str, TokenUsage]:
+    def _parse_single_output(single_output: dict) -> str:
         single_output_message = ""
         for choice in single_output.get("choices", []):
             if not isinstance(choice, dict):
                 continue
             content_dict = choice.get("delta", {}) or choice.get("message", {})
             single_output_message += content_dict.get("content", "")
-        return single_output_message, TokenUsage.from_dict(single_output.get("usage", {}))
+        return single_output_message

     full_output_message = ""
-    full_token_usage = TokenUsage()
     if isinstance(output, list):
         for output_chunk in output:
             output_message = ""
-            token_usage = TokenUsage()
             if isinstance(output_chunk, dict):
-                output_message, token_usage = _parse_single_output(output_chunk)
+                output_message = _parse_single_output(output_chunk)
             elif isinstance(output_chunk, str):
                 try:
                     output_decoded = json.loads(output_chunk)
-                    output_message, token_usage = _parse_single_output(output_decoded)
+                    output_message = _parse_single_output(output_decoded)
                 except Exception:
                     logger.error(f"Error reading chunk: {output_chunk}")
             else:
                 logger.warning(
                     f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
                 )
             full_output_message += output_message
-            full_token_usage += token_usage
     elif isinstance(output, dict):
-        full_output_message, full_token_usage = _parse_single_output(output)
+        full_output_message = _parse_single_output(output)

-    return full_output_message, full_token_usage
+    return full_output_message


-async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[PartialQuestionAnswer]:
+async def _get_partial_question_answer(
+    row: GetPromptWithOutputsRow,
+) -> Optional[PartialQuestionAnswer]:
     """
     Parse a row from the get_prompt_with_outputs query and return a PartialConversation
@@ -162,7 +160,7 @@ async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[Partial
         output_task = tg.create_task(parse_output(row.output))

     request_user_msgs, model = request_task.result()
-    output_msg_str, token_usage = output_task.result()
+    output_msg_str = output_task.result()

     # If we couldn't parse the request, return None
     if not request_user_msgs:
@@ -184,8 +182,13 @@ async def _get_question_answer(row: GetPromptWithOutputsRow) -> Optional[Partial
     else:
         output_message = None

+    token_usage = TokenUsage.from_db(
+        input_cost=row.input_cost,
+        input_tokens=row.input_tokens,
+        output_tokens=row.output_tokens,
+        output_cost=row.output_cost,
+    )
     # Use the model to update the token cost
-    token_usage.update_token_cost(model)
     provider = row.provider
     # TODO: This should come from the database. For now, we are manually changing copilot to openai
     # Change copilot provider to openai
@@ -297,7 +300,8 @@ def _get_question_answer_from_partial(
     partial_question_answer: PartialQuestionAnswer,
 ) -> QuestionAnswer:
     """
-    Get a QuestionAnswer object from a PartialQuestionAnswer object.
+    Get a QuestionAnswer object from a PartialQuestionAnswer object. PartialQuestionAnswer
+    contains a list of messages as question. QuestionAnswer contains a single message as question.
     """
     # Get the last user message as the question
     question = ChatMessage(
@@ -315,11 +319,8 @@ async def match_conversations(
     """
     Match partial conversations to form a complete conversation.
     """
-    valid_partial_qas = [
-        partial_qas for partial_qas in partial_question_answers if partial_qas is not None
-    ]
     grouped_partial_questions = _group_partial_messages(
-        [partial_qs_a.partial_questions for partial_qs_a in valid_partial_qas]
+        [partial_qs_a.partial_questions for partial_qs_a in partial_question_answers]
     )

     # Create the conversation objects
@@ -333,7 +334,7 @@
         # Partial questions don't contain the answer, so we need to find the corresponding
         # valid partial question answer
         selected_partial_qa = None
-        for partial_qa in valid_partial_qas:
+        for partial_qa in partial_question_answers:
             if partial_question.message_id == partial_qa.partial_questions.message_id:
                 selected_partial_qa = partial_qa
                 break
@@ -367,17 +368,25 @@
     return conversations, map_q_id_to_conversation


+async def _process_prompt_output_to_partial_qa(
+    prompts_outputs: List[GetPromptWithOutputsRow],
+) -> List[PartialQuestionAnswer]:
+    """
+    Process the prompts and outputs to PartialQuestionAnswer objects.
+    """
+    # Parse the prompts and outputs in parallel
+    async with asyncio.TaskGroup() as tg:
+        tasks = [tg.create_task(_get_partial_question_answer(row)) for row in prompts_outputs]
+    return [task.result() for task in tasks if task.result() is not None]
+
+
 async def parse_messages_in_conversations(
     prompts_outputs: List[GetPromptWithOutputsRow],
 ) -> Tuple[List[Conversation], Dict[str, Conversation]]:
     """
     Get all the messages from the database and return them as a list of conversations.
     """
-
-    # Parse the prompts and outputs in parallel
-    async with asyncio.TaskGroup() as tg:
-        tasks = [tg.create_task(_get_question_answer(row)) for row in prompts_outputs]
-        partial_question_answers = [task.result() for task in tasks]
+    partial_question_answers = await _process_prompt_output_to_partial_qa(prompts_outputs)

     conversations, map_q_id_to_conversation = await match_conversations(partial_question_answers)
     return conversations, map_q_id_to_conversation
@@ -430,3 +439,16 @@ async def parse_get_alert_conversation(
         for row in alerts
     ]
     return [task.result() for task in tasks if task.result() is not None]
+
+
+async def parse_workspace_token_usage(
+    prompts_outputs: List[GetPromptWithOutputsRow],
+) -> TokenUsageAggregate:
+    """
+    Parse the token usage from the workspace.
+    """
+    partial_question_answers = await _process_prompt_output_to_partial_qa(prompts_outputs)
+    token_usage_agg = TokenUsageAggregate(tokens_by_model={}, token_usage=TokenUsage())
+    for p_qa in partial_question_answers:
+        token_usage_agg.add_model_token_usage(p_qa.model_token_usage)
+    return token_usage_agg
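The refactor moves the TaskGroup into _process_prompt_output_to_partial_qa and drops None results there, which is why match_conversations no longer needs its valid_partial_qas filtering pass. A self-contained sketch of that pattern (requires Python 3.11+ for asyncio.TaskGroup; the names are illustrative, not from the commit):

import asyncio


async def parse(n: int) -> int | None:
    await asyncio.sleep(0)  # stand-in for the real parse work
    return n if n % 2 else None  # None mimics a row that failed to parse


async def main() -> None:
    # Tasks run concurrently; the TaskGroup waits for all of them on exit.
    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(parse(n)) for n in range(5)]
    # Filter failed rows once, here, instead of in every downstream consumer.
    print([t.result() for t in tasks if t.result() is not None])  # [1, 3]


asyncio.run(main())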
