
Commit cd5bb03

Lint picking
1 parent 4ef7798 commit cd5bb03

4 files changed: 33 additions & 17 deletions

src/utils/token_counter.py

Lines changed: 10 additions & 3 deletions
@@ -11,6 +11,8 @@
 from typing import Sequence
 
 from cachetools import TTLCache  # type: ignore
+import tiktoken
+
 
 from llama_stack_client.types import (
     UserMessage,
@@ -19,7 +21,6 @@
     CompletionMessage,
 )
 from models.requests import QueryRequest
-import tiktoken
 
 from configuration import configuration, AppConfig
 from constants import DEFAULT_ESTIMATION_TOKENIZER
@@ -127,7 +128,11 @@ def count_turn_tokens(
         }
 
     def count_conversation_turn_tokens(
-        self, conversation_id: str, system_prompt: str, query_request: QueryRequest, response: str = ""
+        self,
+        conversation_id: str,
+        system_prompt: str,
+        query_request: QueryRequest,
+        response: str = "",
     ) -> dict[str, int]:
         """Count tokens for a conversation turn with cumulative tracking.
 
@@ -148,7 +153,9 @@ def count_conversation_turn_tokens(
         - 'output_tokens': Total tokens in the response message
         """
         # Get the current turn's token usage
-        turn_token_usage = self.count_turn_tokens(system_prompt, query_request, response)
+        turn_token_usage = self.count_turn_tokens(
+            system_prompt, query_request, response
+        )
 
         # Get cumulative input tokens for this conversation
         cumulative_input_tokens = _conversation_cache.get(conversation_id, 0)
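
For orientation, the reshaped signature above implies a call pattern like the one below. This is a minimal sketch assuming the imports and constructor argument visible in the test files further down; the conversation id, prompt text, and printed values are illustrative only, and the actual counts depend on the configured tokenizer.

    # Hedged usage sketch (illustrative values; assumes only the API shown in this diff).
    from utils.token_counter import TokenCounter
    from models.requests import QueryRequest

    counter = TokenCounter("llama3.2:1b")
    request = QueryRequest(query="Analyze these files for me")

    # Cumulative tracking: the first call for a new conversation_id starts from zero.
    usage = counter.count_conversation_turn_tokens(
        "conv_example",       # conversation_id (illustrative)
        "System prompt",      # system_prompt
        request,              # query_request
        "Analysis complete",  # response
    )
    print(usage["input_tokens"], usage["output_tokens"])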

tests/unit/app/endpoints/test_query.py

Lines changed: 6 additions & 2 deletions
@@ -760,7 +760,7 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
         },
     }
 
-    response, conversation_id, token_usage = retrieve_response(
+    response, conversation_id, _ = retrieve_response(
         mock_client,
         model_id,
         query_request,
@@ -1204,7 +1204,11 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker):
 
     mock_retrieve_response = mocker.patch(
         "app.endpoints.query.retrieve_response",
-        return_value=("test response", "test_conversation_id", {"input_tokens": 10, "output_tokens": 20}),
+        return_value=(
+            "test response",
+            "test_conversation_id",
+            {"input_tokens": 10, "output_tokens": 20},
+        ),
     )
 
     mocker.patch("app.endpoints.query.select_model_id", return_value="test_model")

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 8 additions & 3 deletions
@@ -151,7 +151,8 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
 
     # Mock the streaming response from LLama Stack
     mock_streaming_response = mocker.AsyncMock()
-    # Currently usage is not returned by the API, we simulate by using del to prevent pytest from returning a Mock
+    # Currently usage is not returned by the API
+    # we simulate by using del to prevent pytest from returning a Mock
     del mock_streaming_response.usage
     mock_streaming_response.__aiter__.return_value = [
         mocker.Mock(
@@ -862,7 +863,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
         },
     }
 
-    response, conversation_id, token_usage = await retrieve_response(
+    response, conversation_id, _ = await retrieve_response(
         mock_client,
         model_id,
         query_request,
@@ -1224,7 +1225,11 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker):
     mock_streaming_response.__aiter__.return_value = iter([])
     mock_retrieve_response = mocker.patch(
         "app.endpoints.streaming_query.retrieve_response",
-        return_value=(mock_streaming_response, "test_conversation_id", {"input_tokens": 10, "output_tokens": 20}),
+        return_value=(
+            mock_streaming_response,
+            "test_conversation_id",
+            {"input_tokens": 10, "output_tokens": 20},
+        ),
     )
 
     mocker.patch(
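
The comment split in the first hunk refers to a standard unittest.mock behavior the test relies on: deleting an attribute from a Mock (or AsyncMock) marks it as absent, so later access raises AttributeError instead of auto-creating a child Mock. A minimal sketch with illustrative names, not project code:

    # Sketch of the `del` trick the streaming test uses (illustrative only).
    from unittest.mock import Mock

    fake_chunk = Mock()
    del fake_chunk.usage                       # declare the attribute as missing
    assert not hasattr(fake_chunk, "usage")    # access would now raise AttributeError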

tests/unit/utils/test_token_counter.py

Lines changed: 9 additions & 9 deletions
@@ -1,7 +1,8 @@
 """Unit tests for token counter utilities."""
 
-from utils.token_counter import TokenCounter
 from llama_stack_client.types import UserMessage, CompletionMessage
+
+from utils.token_counter import TokenCounter
 from models.requests import QueryRequest, Attachment
 from configuration import AppConfig
 
@@ -31,6 +32,7 @@ class TestTokenCounter:
     """Test cases for TokenCounter class."""
 
     def setup_class(self):
+        """Setup the test class."""
         cfg = AppConfig()
         cfg.init_from_dict(config_dict)
 
@@ -40,6 +42,7 @@ def test_count_tokens_empty_string(self):
         assert counter.count_tokens("") == 0
 
     def test_count_tokens_simple(self):
+        """Test counting tokens for a simple message."""
         counter = TokenCounter("llama3.2:1b")
         assert counter.count_tokens("Hello World!") == 3
 
@@ -104,21 +107,18 @@ def test_count_conversation_turn_tokens_with_attachments(self):
             Attachment(
                 attachment_type="configuration",
                 content_type="application/yaml",
-                content="kind: Pod\nmetadata:\n name: test-pod\nspec:\n containers:\n - name: app",
+                content="kind: Pod\nmetadata:\n name: test-pod\nspec:\n"
+                + " containers:\n - name: app\n image: nginx:latest",
             ),
         ]
 
         query_request = QueryRequest(
-            query="Analyze these files for me",
-            attachments=attachments
+            query="Analyze these files for me", attachments=attachments
         )
 
         # Test the conversation turn with attachments
         result = counter.count_conversation_turn_tokens(
-            "conv_with_attachments",
-            "System prompt",
-            query_request,
-            "Analysis complete"
+            "conv_with_attachments", "System prompt", query_request, "Analysis complete"
         )
 
         # Verify that the result contains the expected structure
@@ -142,7 +142,7 @@ def test_count_conversation_turn_tokens_with_attachments(self):
             "conv_no_attachments",
             "System prompt",
             query_request_no_attachments,
-            "Analysis complete"
+            "Analysis complete",
        )
 
         # The version with attachments should have more input tokens
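
Context for the count_tokens("Hello World!") == 3 assertion: the module imports tiktoken and a DEFAULT_ESTIMATION_TOKENIZER constant, so counts appear to be estimated with a tiktoken encoding. A minimal sketch of that estimation style, assuming the cl100k_base encoding purely for illustration (the project's actual tokenizer constant may differ):

    # Illustrative token estimation with tiktoken; the project's
    # DEFAULT_ESTIMATION_TOKENIZER may name a different encoding.
    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode("Hello World!")
    print(len(tokens))  # 3 with this encoding, matching the test's expectation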
