Commit 0c5e610

LCORE-411: add token usage metrics

Signed-off-by: Haoyu Sun <[email protected]>
Parent: 9c51e77

File tree: 6 files changed, +128 −4 lines

src/app/endpoints/query.py

Lines changed: 10 additions & 1 deletion
@@ -24,6 +24,7 @@
 from configuration import configuration
 from app.database import get_session
 import metrics
+from metrics.utils import update_llm_token_count_from_turn
 import constants
 from authorization.middleware import authorize
 from models.config import Action
@@ -218,6 +219,7 @@ async def query_endpoint_handler(
         query_request,
         token,
         mcp_headers=mcp_headers,
+        provider_id=provider_id,
     )
     # Update metrics for the LLM call
     metrics.llm_calls_total.labels(provider_id, model_id).inc()
@@ -387,12 +389,14 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)


-async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
+async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
+    *,
+    provider_id: str = "",
 ) -> tuple[TurnSummary, str]:
     """
     Retrieve response from LLMs and agents.
@@ -411,6 +415,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche

     Parameters:
         model_id (str): The identifier of the LLM model to use.
+        provider_id (str): The identifier of the LLM provider to use.
         query_request (QueryRequest): The user's query and associated metadata.
         token (str): The authentication token for authorization.
         mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.
@@ -510,6 +515,10 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
         tool_calls=[],
     )

+    # Update token count metrics for the LLM call
+    model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
+    update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt)
+
    # Check for validation errors in the response
    steps = response.steps or []
    for step in steps:
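For context on the model_label derivation in the last hunk: llama-stack model identifiers are typically namespaced as provider/model, and the non-streaming path strips the provider prefix before labelling the token counters. A minimal illustration with a hypothetical model identifier (not taken from this commit):

# Illustrative only; "openai/gpt-4o-mini" is a hypothetical model_id.
model_id = "openai/gpt-4o-mini"
model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
assert model_label == "gpt-4o-mini"  # the provider prefix is dropped from the metric label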

src/app/endpoints/streaming_query.py

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,7 @@
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
 import metrics
+from metrics.utils import update_llm_token_count_from_turn
 from models.config import Action
 from models.requests import QueryRequest
 from models.database.conversations import UserConversation
@@ -621,6 +622,13 @@ async def response_generator(
                 summary.llm_response = interleaved_content_as_str(
                     p.turn.output_message.content
                 )
+                system_prompt = get_system_prompt(query_request, configuration)
+                try:
+                    update_llm_token_count_from_turn(
+                        p.turn, model_id, provider_id, system_prompt
+                    )
+                except Exception:  # pylint: disable=broad-except
+                    logger.exception("Failed to update token usage metrics")
             elif p.event_type == "step_complete":
                 if p.step_details.step_type == "tool_execution":
                     summary.append_tool_calls_from_llama(p.step_details)

src/metrics/utils.py

Lines changed: 29 additions & 2 deletions
@@ -1,9 +1,16 @@
 """Utility functions for metrics handling."""

-from configuration import configuration
+from typing import cast
+
+from llama_stack.models.llama.datatypes import RawMessage
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack_client.types.agents.turn import Turn
+
+import metrics
 from client import AsyncLlamaStackClientHolder
+from configuration import configuration
 from log import get_logger
-import metrics
 from utils.common import run_once_async

 logger = get_logger(__name__)
@@ -48,3 +55,23 @@ async def setup_model_metrics() -> None:
         default_model_value,
     )
     logger.info("Model metrics setup complete")
+
+
+def update_llm_token_count_from_turn(
+    turn: Turn, model: str, provider: str, system_prompt: str = ""
+) -> None:
+    """Update the LLM calls metrics from a turn."""
+    tokenizer = Tokenizer.get_instance()
+    formatter = ChatFormat(tokenizer)
+
+    raw_message = cast(RawMessage, turn.output_message)
+    encoded_output = formatter.encode_dialog_prompt([raw_message])
+    token_count = len(encoded_output.tokens) if encoded_output.tokens else 0
+    metrics.llm_token_received_total.labels(provider, model).inc(token_count)
+
+    input_messages = [RawMessage(role="user", content=system_prompt)] + cast(
+        list[RawMessage], turn.input_messages
+    )
+    encoded_input = formatter.encode_dialog_prompt(input_messages)
+    token_count = len(encoded_input.tokens) if encoded_input.tokens else 0
+    metrics.llm_token_sent_total.labels(provider, model).inc(token_count)
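The helper increments metrics.llm_token_sent_total and metrics.llm_token_received_total, whose definitions are outside this diff. A minimal sketch of how such counters could be declared with prometheus_client, assuming the label names provider and model implied by the .labels(provider, model) calls above; the actual definitions in the metrics package may differ:

# Hypothetical counter definitions (the real ones live in the metrics package).
from prometheus_client import Counter

llm_token_sent_total = Counter(
    "llm_token_sent_total",
    "Number of tokens sent to the LLM, by provider and model",
    ["provider", "model"],
)
llm_token_received_total = Counter(
    "llm_token_received_total",
    "Number of tokens received from the LLM, by provider and model",
    ["provider", "model"],
)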

tests/unit/app/endpoints/test_query.py

Lines changed: 23 additions & 0 deletions
@@ -46,6 +46,14 @@ def dummy_request() -> Request:
     return req


+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for query endpoints."""
+    mocker.patch(
+        "app.endpoints.query.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 def mock_database_operations(mocker):
     """Helper function to mock database operations for query endpoints."""
     mocker.patch(
@@ -443,6 +451,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -474,6 +483,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mo
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -506,6 +516,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -544,6 +555,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -593,6 +605,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -645,6 +658,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -699,6 +713,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -755,6 +770,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"
@@ -809,6 +825,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"
@@ -864,6 +881,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -933,6 +951,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -994,6 +1013,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"
@@ -1090,6 +1110,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?")

@@ -1326,6 +1347,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?", no_tools=True)
     model_id = "fake_model_id"
@@ -1376,6 +1398,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)

     query_request = QueryRequest(query="What is OpenStack?", no_tools=False)
     model_id = "fake_model_id"

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 10 additions & 0 deletions
@@ -58,6 +58,14 @@ def mock_database_operations(mocker):
     mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details")


+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for streaming query endpoints."""
+    mocker.patch(
+        "app.endpoints.streaming_query.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [
     """knowledge_search tool found 2 chunks:
 BEGIN of knowledge_search tool results.
@@ -346,12 +354,14 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler(mocker):
     """Test the streaming query endpoint handler with transcript storage disabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=False)


 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler_store_transcript(mocker):
     """Test the streaming query endpoint handler with transcript storage enabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=True)


tests/unit/metrics/test_utis.py

Lines changed: 48 additions & 1 deletion
@@ -1,6 +1,6 @@
 """Unit tests for functions defined in metrics/utils.py"""

-from metrics.utils import setup_model_metrics
+from metrics.utils import setup_model_metrics, update_llm_token_count_from_turn


 async def test_setup_model_metrics(mocker):
@@ -74,3 +74,50 @@ async def test_setup_model_metrics(mocker):
         ],
         any_order=False,  # Order matters here
     )
+
+
+def test_update_llm_token_count_from_turn(mocker):
+    """Test the update_llm_token_count_from_turn function."""
+    mocker.patch("metrics.utils.Tokenizer.get_instance")
+    mock_formatter_class = mocker.patch("metrics.utils.ChatFormat")
+    mock_formatter = mocker.Mock()
+    mock_formatter_class.return_value = mock_formatter
+
+    mock_received_metric = mocker.patch(
+        "metrics.utils.metrics.llm_token_received_total"
+    )
+    mock_sent_metric = mocker.patch("metrics.utils.metrics.llm_token_sent_total")
+
+    mock_turn = mocker.Mock()
+    # turn.output_message should satisfy the type RawMessage
+    mock_turn.output_message = {"role": "assistant", "content": "test response"}
+    # turn.input_messages should satisfy the type list[RawMessage]
+    mock_turn.input_messages = [{"role": "user", "content": "test input"}]
+
+    # Mock the encoded results with tokens
+    mock_encoded_output = mocker.Mock()
+    mock_encoded_output.tokens = ["token1", "token2", "token3"]  # 3 tokens
+    mock_encoded_input = mocker.Mock()
+    mock_encoded_input.tokens = ["token1", "token2"]  # 2 tokens
+    mock_formatter.encode_dialog_prompt.side_effect = [
+        mock_encoded_output,
+        mock_encoded_input,
+    ]
+
+    test_model = "test_model"
+    test_provider = "test_provider"
+    test_system_prompt = "test system prompt"
+
+    update_llm_token_count_from_turn(
+        mock_turn, test_model, test_provider, test_system_prompt
+    )
+
+    # Verify that llm_token_received_total.labels() was called with correct metrics
+    mock_received_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_received_metric.labels().inc.assert_called_once_with(
+        3
+    )  # token count from output
+
+    # Verify that llm_token_sent_total.labels() was called with correct metrics
+    mock_sent_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_sent_metric.labels().inc.assert_called_once_with(2)  # token count from input
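To observe the effect of this commit end to end, the new counters can be read back from the Prometheus registry after a turn has been recorded. A small sketch, assuming the counters are prometheus_client metrics registered in the default registry; the label values shown are hypothetical:

# Illustrative only: inspect the exposition output after update_llm_token_count_from_turn has run.
from prometheus_client import REGISTRY, generate_latest

exposition = generate_latest(REGISTRY).decode()
# Expect lines of the form (values are hypothetical):
#   llm_token_sent_total{model="gpt-4o-mini",provider="openai"} 42.0
#   llm_token_received_total{model="gpt-4o-mini",provider="openai"} 17.0
print(exposition)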
