Skip to content

Commit ef94176

Browse files
committed
Add RAG chunks in query response
Signed-off-by: Anxhela Coba <[email protected]>
1 parent b14d91c commit ef94176

File tree

6 files changed

+253
-24
lines changed

6 files changed

+253
-24
lines changed

pyproject.toml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,26 @@ dependencies = [
4545
"email-validator>=2.2.0",
4646
"openai==1.99.9",
4747
# Used by database interface
48-
"sqlalchemy>=2.0.42",
48+
"sqlalchemy>=2.0.41",
4949
# Used by Llama Stack version checker
5050
"semver<4.0.0",
5151
# Used by authorization resolvers
5252
"jsonpath-ng>=1.6.1",
53+
"opentelemetry-sdk>=1.34.0",
54+
"opentelemetry-exporter-otlp>=1.34.0",
55+
"opentelemetry-instrumentation>=0.55b0",
56+
"aiosqlite>=0.21.0",
57+
"litellm>=1.72.1",
58+
"blobfile>=3.0.0",
59+
"datasets>=3.6.0",
60+
"faiss-cpu>=1.11.0",
61+
"mcp>=1.9.4",
62+
"autoevals>=0.0.129",
63+
"psutil>=7.0.0",
64+
"torch>=2.7.1",
65+
"peft>=0.15.2",
66+
"trl>=0.18.2",
67+
"sentence-transformers>=5.1.0",
5368
]
5469

5570

run.yaml

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ providers:
6060
provider_id: meta-reference
6161
provider_type: inline::meta-reference
6262
inference:
63+
- provider_id: sentence-transformers # Can be any embedding provider
64+
provider_type: inline::sentence-transformers
65+
config: {}
6366
- provider_id: openai
6467
provider_type: remote::openai
6568
config:
@@ -99,14 +102,17 @@ providers:
99102
- provider_id: model-context-protocol
100103
provider_type: remote::model-context-protocol
101104
config: {}
105+
- provider_id: rag-runtime
106+
provider_type: inline::rag-runtime
107+
config: {}
102108
vector_io:
103109
- config:
104110
kvstore:
105-
db_path: .llama/distributions/ollama/faiss_store.db
111+
db_path: /path/to/your/vector/store.db
106112
namespace: null
107113
type: sqlite
108-
provider_id: faiss
109-
provider_type: inline::faiss
114+
provider_id: my_vector_db
115+
provider_type: inline::faiss # Or preferred vector DB
110116
scoring_fns: []
111117
server:
112118
auth: null
@@ -117,10 +123,23 @@ server:
117123
tls_certfile: null
118124
tls_keyfile: null
119125
shields: []
120-
vector_dbs: []
121-
126+
vector_dbs:
127+
- vector_db_id: my_knowledge_base
128+
embedding_model: sentence-transformers/all-mpnet-base-v2
129+
embedding_dimension: 768
130+
provider_id: my_vector_db
122131
models:
132+
- metadata:
133+
embedding_dimension: 768 # Depends on chosen model
134+
model_id: sentence-transformers/all-mpnet-base-v2 # Example model
135+
provider_id: sentence-transformers
136+
provider_model_id: path/to/model
137+
model_type: embedding
123138
- model_id: gpt-4-turbo
124139
provider_id: openai
125140
model_type: llm
126141
provider_model_id: gpt-4-turbo
142+
143+
tool_groups:
144+
- toolgroup_id: builtin::rag
145+
provider_id: rag-runtime

src/app/endpoints/query.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from models.config import Action
3131
from models.database.conversations import UserConversation
3232
from models.requests import QueryRequest, Attachment
33-
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
33+
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse, RAGChunk, ReferencedDocument, ToolCall
3434
from utils.endpoints import (
3535
check_configuration_loaded,
3636
get_agent,
@@ -243,17 +243,61 @@ async def query_endpoint_handler(
243243
attachments=query_request.attachments or [],
244244
)
245245

246+
logger.info("Persisting conversation details...")
246247
persist_user_conversation_details(
247248
user_id=user_id,
248249
conversation_id=conversation_id,
249250
model=model_id,
250251
provider_id=provider_id,
251252
)
252253

253-
return QueryResponse(
254+
# Convert tool calls and RAG chunks to response format
255+
logger.info("Processing tool calls...")
256+
tool_calls = [
257+
ToolCall(
258+
tool_name=tc.name,
259+
arguments=tc.args if isinstance(tc.args, dict) else {"query": str(tc.args)},
260+
result={"response": tc.response} if tc.response else None
261+
)
262+
for tc in summary.tool_calls
263+
]
264+
265+
266+
logger.info("Processing RAG chunks...")
267+
rag_chunks = [
268+
RAGChunk(
269+
content=chunk.content,
270+
source=chunk.source,
271+
score=chunk.score
272+
)
273+
for chunk in summary.rag_chunks
274+
]
275+
276+
# Extract referenced documents from RAG chunks
277+
logger.info("Extracting referenced documents...")
278+
referenced_docs = []
279+
doc_sources = set()
280+
for chunk in summary.rag_chunks:
281+
if chunk.source and chunk.source not in doc_sources:
282+
doc_sources.add(chunk.source)
283+
referenced_docs.append(
284+
ReferencedDocument(
285+
url=chunk.source if chunk.source.startswith("http") else None,
286+
title=chunk.source,
287+
chunk_count=sum(1 for c in summary.rag_chunks if c.source == chunk.source)
288+
)
289+
)
290+
291+
logger.info("Building final response...")
292+
response = QueryResponse(
254293
conversation_id=conversation_id,
255294
response=summary.llm_response,
295+
rag_chunks=rag_chunks if rag_chunks else None,
296+
referenced_documents=referenced_docs if referenced_docs else None,
297+
tool_calls=tool_calls if tool_calls else None,
256298
)
299+
logger.info("Query processing completed successfully!")
300+
return response
257301

258302
# connection to Llama Stack server
259303
except APIConnectionError as e:

src/models/responses.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Models for REST API responses."""
22

3-
from typing import Any, Optional
3+
from typing import Any, Optional, List
44

55
from pydantic import BaseModel, Field
66

@@ -34,23 +34,45 @@ class ModelsResponse(BaseModel):
3434
)
3535

3636

37-
# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
38-
# we are keeping it simple. The missing fields are:
39-
# - referenced_documents: The optional URLs and titles for the documents used
40-
# to generate the response.
41-
# - truncated: Set to True if conversation history was truncated to be within context window.
42-
# - input_tokens: Number of tokens sent to LLM
43-
# - output_tokens: Number of tokens received from LLM
44-
# - available_quotas: Quota available as measured by all configured quota limiters
45-
# - tool_calls: List of tool requests.
46-
# - tool_results: List of tool results.
47-
# See LLMResponse in ols-service for more details.
37+
class RAGChunk(BaseModel):
    """A single RAG chunk reported alongside a query response.

    Attributes:
        content: The text of the chunk (required).
        source: Source document or URL the chunk came from, when available.
        score: Relevance score of the chunk, when available.
    """

    content: str = Field(
        ...,
        description="The content of the chunk",
    )
    source: Optional[str] = Field(
        default=None,
        description="Source document or URL",
    )
    score: Optional[float] = Field(
        default=None,
        description="Relevance score",
    )
43+
44+
45+
class ReferencedDocument(BaseModel):
    """A document that contributed to the generated response.

    Attributes:
        url: URL of the document, when available.
        title: Title of the document, when available.
        chunk_count: Number of chunks taken from this document, when available.
    """

    url: Optional[str] = Field(
        default=None,
        description="URL of the document",
    )
    title: Optional[str] = Field(
        default=None,
        description="Title of the document",
    )
    chunk_count: Optional[int] = Field(
        default=None,
        description="Number of chunks from this document",
    )
51+
52+
53+
class ToolCall(BaseModel):
    """A tool invocation performed while generating the response.

    Attributes:
        tool_name: Name of the tool that was called (required).
        arguments: Arguments passed to the tool (required).
        result: Result returned by the tool, when available.
    """

    tool_name: str = Field(
        ...,
        description="Name of the tool called",
    )
    arguments: dict[str, Any] = Field(
        ...,
        description="Arguments passed to the tool",
    )
    result: Optional[dict[str, Any]] = Field(
        default=None,
        description="Result from the tool",
    )
59+
60+
4861
class QueryResponse(BaseModel):
4962
"""Model representing LLM response to a query.
5063
5164
Attributes:
5265
conversation_id: The optional conversation ID (UUID).
5366
response: The response.
67+
rag_chunks: List of RAG chunks used to generate the response.
68+
referenced_documents: List of documents referenced in the response.
69+
tool_calls: List of tool calls made during response generation.
70+
TODO: truncated: Whether conversation history was truncated.
71+
TODO: input_tokens: Number of tokens sent to LLM.
72+
TODO: output_tokens: Number of tokens received from LLM.
73+
TODO: available_quotas: Quota available as measured by all configured quota limiters
74+
TODO: tool_results: List of tool results.
75+
5476
"""
5577

5678
conversation_id: Optional[str] = Field(
@@ -66,13 +88,48 @@ class QueryResponse(BaseModel):
6688
],
6789
)
6890

91+
rag_chunks: Optional[List[RAGChunk]] = Field(
92+
None,
93+
description="List of RAG chunks used to generate the response",
94+
)
95+
96+
referenced_documents: Optional[List[ReferencedDocument]] = Field(
97+
None,
98+
description="List of documents referenced in the response",
99+
)
100+
101+
tool_calls: Optional[List[ToolCall]] = Field(
102+
None,
103+
description="List of tool calls made during response generation",
104+
)
69105
# provides examples for /docs endpoint
70106
model_config = {
71107
"json_schema_extra": {
72108
"examples": [
73109
{
74110
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
75111
"response": "Operator Lifecycle Manager (OLM) helps users install...",
112+
"rag_chunks": [
113+
{
114+
"content": "OLM is a component of the Operator Framework toolkit...",
115+
"source": "kubernetes-docs/operators.md",
116+
"score": 0.95
117+
}
118+
],
119+
"referenced_documents": [
120+
{
121+
"url": "https://kubernetes.io/docs/concepts/extend-kubernetes/operator/",
122+
"title": "Operator Pattern",
123+
"chunk_count": 2
124+
}
125+
],
126+
"tool_calls": [
127+
{
128+
"tool_name": "knowledge_search",
129+
"arguments": {"query": "operator lifecycle manager"},
130+
"result": {"chunks_found": 5}
131+
}
132+
],
76133
}
77134
]
78135
}

src/utils/types.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Common types for the project."""
22

3-
from typing import Any, Optional
4-
3+
from typing import Any, Optional, List
4+
import json
55
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
66
from llama_stack_client.lib.agents.tool_parser import ToolParser
77
from llama_stack_client.types.shared.completion_message import CompletionMessage
@@ -56,23 +56,86 @@ class ToolCallSummary(BaseModel):
5656
response: str | None
5757

5858

59+
class RAGChunkData(BaseModel):
    """One RAG chunk pulled out of a tool response.

    Attributes:
        content: Text of the chunk (required).
        source: Where the chunk came from; defaults to None.
        score: Score attached to the chunk; defaults to None.
    """

    content: str
    source: str | None = None
    score: float | None = None
65+
66+
5967
class TurnSummary(BaseModel):
    """Summary of a turn in llama stack.

    Attributes:
        llm_response: The LLM response text for the turn.
        tool_calls: Summaries of the tool calls executed during the turn.
        rag_chunks: RAG chunks extracted from ``knowledge_search`` tool
            responses; empty when no such tool call produced content.
    """

    llm_response: str
    tool_calls: list[ToolCallSummary]
    rag_chunks: List[RAGChunkData] = []

    def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None:
        """Append the tool calls from a llama tool execution step.

        Each tool call is paired with its response (matched by ``call_id``)
        and recorded in ``tool_calls``. Responses of the ``knowledge_search``
        tool are additionally parsed into ``rag_chunks``.

        Parameters:
            tec: The tool execution step whose calls and responses are
                summarized.
        """
        calls_by_id = {tc.call_id: tc for tc in tec.tool_calls}
        responses_by_id = {tc.call_id: tc for tc in tec.tool_responses}
        for call_id, tc in calls_by_id.items():
            resp = responses_by_id.get(call_id)
            # Computed once: used both for the summary and for RAG extraction.
            response_content = (
                interleaved_content_as_str(resp.content) if resp else None
            )

            self.tool_calls.append(
                ToolCallSummary(
                    id=call_id,
                    name=tc.tool_name,
                    args=tc.arguments,
                    response=response_content,
                )
            )

            # Extract RAG chunks from knowledge_search tool responses
            if tc.tool_name == "knowledge_search" and resp and response_content:
                self._extract_rag_chunks_from_response(response_content)

    def _extract_rag_chunks_from_response(self, response_content: str) -> None:
        """Extract RAG chunks from tool response content.

        The content is first interpreted as JSON (see ``_parse_rag_chunks``
        for the accepted shapes). If it is not JSON, or the JSON payload is
        malformed, the whole non-blank response is recorded as one chunk with
        source ``"knowledge_search"``.

        Chunks are collected locally and committed to ``rag_chunks`` only
        once parsing has fully succeeded, so a failure part-way through never
        leaves partial results next to the whole-response fallback chunk.

        Parameters:
            response_content: Tool response rendered as a string.
        """
        try:
            extracted = self._parse_rag_chunks(json.loads(response_content))
        except (TypeError, ValueError):
            # ValueError covers both json.JSONDecodeError and pydantic
            # validation failures (e.g. null content, non-numeric score).
            extracted = None

        if extracted is None:
            # Best-effort fallback: keep the raw response as a single chunk.
            extracted = []
            if response_content.strip():
                extracted.append(
                    RAGChunkData(
                        content=response_content,
                        source="knowledge_search",
                        score=None,
                    )
                )

        self.rag_chunks.extend(extracted)

    @staticmethod
    def _parse_rag_chunks(data: Any) -> list[RAGChunkData]:
        """Build chunk models from an already-decoded JSON payload.

        Accepted shapes are a dict with a ``"chunks"`` list, or a bare list
        of chunk dicts. Entries that are not dicts are skipped. Any other
        payload shape yields no chunks.

        Parameters:
            data: The decoded JSON value.

        Returns:
            The chunks found in ``data``, in order.

        Raises:
            ValueError: If ``"chunks"`` is present but is not a list, or a
                chunk entry fails model validation.
        """
        if isinstance(data, dict) and "chunks" in data:
            entries = data["chunks"]
            if not isinstance(entries, list):
                raise ValueError("'chunks' must be a list")
            stringify_missing_content = False
        elif isinstance(data, list):
            entries = data
            # Bare-list entries without "content" keep their string form.
            stringify_missing_content = True
        else:
            return []

        chunks: list[RAGChunkData] = []
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            fallback_content = str(entry) if stringify_missing_content else ""
            chunks.append(
                RAGChunkData(
                    content=entry.get("content", fallback_content),
                    source=entry.get("source"),
                    score=entry.get("score"),
                )
            )
        return chunks

0 commit comments

Comments
 (0)