Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/strands/types/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,55 @@ class DocumentPageLocation(TypedDict, total=False):
end: int


class SearchResultLocation(TypedDict, total=False):
"""Specifies a search result location within the content array.
Provides positioning information for cited content using search result
index and block positions.
Attributes:
searchResultIndex: The index of the search result content block where
the cited content is found. Minimum value of 0.
start: The starting position in the content array where the cited
content begins. Minimum value of 0.
end: The ending position in the content array where the cited
content ends. Minimum value of 0.
"""

searchResultIndex: int
start: int
end: int


class WebLocation(TypedDict, total=False):
"""Provides the URL and domain information for a cited website.
Contains information about the website that was cited when performing
a web search.
Attributes:
url: The URL that was cited when performing a web search.
domain: The domain that was cited when performing a web search.
"""

url: str
domain: str


# Tagged union type aliases following the ToolChoice pattern
DocumentCharLocationDict = dict[Literal["documentChar"], DocumentCharLocation]
DocumentPageLocationDict = dict[Literal["documentPage"], DocumentPageLocation]
DocumentChunkLocationDict = dict[Literal["documentChunk"], DocumentChunkLocation]
SearchResultLocationDict = dict[Literal["searchResultLocation"], SearchResultLocation]
WebLocationDict = dict[Literal["web"], WebLocation]

# Union type for citation locations - tagged union format matching AWS Bedrock API
CitationLocation = Union[
DocumentCharLocationDict,
DocumentPageLocationDict,
DocumentChunkLocationDict,
SearchResultLocationDict,
WebLocationDict,
]


Expand Down
63 changes: 58 additions & 5 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,14 +2078,15 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client
This test verifies that when messages contain citationsContent with tagged union CitationLocation objects,
the structure is preserved when sent to AWS Bedrock API. AWS Bedrock expects CitationLocation to be a
tagged union with exactly one wrapper key (documentChar, documentPage, etc.) containing the location fields.
tagged union with exactly one wrapper key (documentChar, documentPage, documentChunk, searchResultLocation, web)
containing the location fields.
"""
# Mock the Bedrock response
bedrock_client.converse_stream.return_value = {"stream": []}

# Messages with citationsContent using tagged union CitationLocation structure
# Messages with citationsContent using all tagged union CitationLocation types
messages = [
{"role": "user", "content": [{"text": "Analyze this document"}]},
{"role": "user", "content": [{"text": "Analyze multiple sources"}]},
{
"role": "assistant",
"content": [
Expand All @@ -2104,8 +2105,34 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client
"sourceContent": [{"text": "Vacation policy allows 15 days per year"}],
"title": "Vacation Policy",
},
{
"location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 8}},
"sourceContent": [{"text": "Company culture emphasizes work-life balance"}],
"title": "Culture Section",
},
{
"location": {
"searchResultLocation": {
"searchResultIndex": 0,
"start": 25,
"end": 150,
}
},
"sourceContent": [{"text": "Search results show industry best practices"}],
"title": "Search Results",
},
{
"location": {
"web": {
"url": "https://example.com/hr-policies",
"domain": "example.com",
}
},
"sourceContent": [{"text": "External HR policy guidelines"}],
"title": "External Reference",
},
],
"content": [{"text": "Based on the document, employees receive comprehensive benefits."}],
"content": [{"text": "Based on multiple sources, the company offers comprehensive benefits."}],
}
}
],
Expand All @@ -2123,7 +2150,7 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client
formatted_messages = call_args["messages"]
citations_content = formatted_messages[1]["content"][0]["citationsContent"]

# Verify the tagged union structure is preserved
# Verify the tagged union structure is preserved for all location types
expected_citations = [
{
"location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}},
Expand All @@ -2135,6 +2162,32 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client
"sourceContent": [{"text": "Vacation policy allows 15 days per year"}],
"title": "Vacation Policy",
},
{
"location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 8}},
"sourceContent": [{"text": "Company culture emphasizes work-life balance"}],
"title": "Culture Section",
},
{
"location": {
"searchResultLocation": {
"searchResultIndex": 0,
"start": 25,
"end": 150,
}
},
"sourceContent": [{"text": "Search results show industry best practices"}],
"title": "Search Results",
},
{
"location": {
"web": {
"url": "https://example.com/hr-policies",
"domain": "example.com",
}
},
"sourceContent": [{"text": "External HR policy guidelines"}],
"title": "External Reference",
},
]

assert citations_content["citations"] == expected_citations, (
Expand Down
Loading