Skip to content

Commit 02dbe8a

Browse files
author
Danilo Poccia
committed
fix(citations): Preserve tagged union structure for CitationLocation
The Bedrock API requires CitationLocation to be a tagged union with exactly one wrapper key (documentChar, documentPage, or documentChunk).

Changes:
- Refactor citation type definitions to use Inner types + wrapper types for proper tagged union modeling.
- Add _format_citation_location helper method to preserve union structure.
- Add tests for all document-based citation location types.

BREAKING CHANGE: CitationLocation types now use a wrapper structure (e.g., {'documentChar': {'documentIndex': 0, ...}}) instead of a flat structure. This matches the actual Bedrock API specification.

Fixes #1323
1 parent 2a02388 commit 02dbe8a

File tree

4 files changed

+196
-33
lines changed

4 files changed

+196
-33
lines changed

src/strands/models/bedrock.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,20 @@
88
import logging
99
import os
1010
import warnings
11-
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast
11+
from typing import (
12+
Any,
13+
AsyncGenerator,
14+
Callable,
15+
Iterable,
16+
Literal,
17+
Mapping,
18+
Optional,
19+
Type,
20+
TypeVar,
21+
Union,
22+
ValuesView,
23+
cast,
24+
)
1225

1326
import boto3
1427
from botocore.config import Config as BotocoreConfig
@@ -493,23 +506,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
493506
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
494507
if "citationsContent" in content:
495508
citations = content["citationsContent"]
496-
result = {}
509+
citations_result: dict[str, Any] = {}
497510

498511
if "citations" in citations:
499-
result["citations"] = []
512+
citations_result["citations"] = []
500513
for citation in citations["citations"]:
501514
filtered_citation: dict[str, Any] = {}
502515
if "location" in citation:
503-
location = citation["location"]
504-
filtered_location = {}
505-
# Filter location fields to only include Bedrock-supported ones
506-
if "documentIndex" in location:
507-
filtered_location["documentIndex"] = location["documentIndex"]
508-
if "start" in location:
509-
filtered_location["start"] = location["start"]
510-
if "end" in location:
511-
filtered_location["end"] = location["end"]
512-
filtered_citation["location"] = filtered_location
516+
filtered_location = self._format_citation_location(citation["location"])
517+
if filtered_location:
518+
filtered_citation["location"] = filtered_location
513519
if "sourceContent" in citation:
514520
filtered_source_content: list[dict[str, Any]] = []
515521
for source_content in citation["sourceContent"]:
@@ -519,20 +525,51 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
519525
filtered_citation["sourceContent"] = filtered_source_content
520526
if "title" in citation:
521527
filtered_citation["title"] = citation["title"]
522-
result["citations"].append(filtered_citation)
528+
citations_result["citations"].append(filtered_citation)
523529

524530
if "content" in citations:
525531
filtered_content: list[dict[str, Any]] = []
526532
for generated_content in citations["content"]:
527533
if "text" in generated_content:
528534
filtered_content.append({"text": generated_content["text"]})
529535
if filtered_content:
530-
result["content"] = filtered_content
536+
citations_result["content"] = filtered_content
531537

532-
return {"citationsContent": result}
538+
return {"citationsContent": citations_result}
533539

534540
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
535541

542+
def _format_citation_location(self, location: Mapping[str, Any]) -> dict[str, Any]:
    """Format a citation location preserving the tagged union structure.

    The Bedrock API requires CitationLocation to be a tagged union with exactly one
    of the following keys: documentChar, documentPage, or documentChunk. The first
    recognized wrapper key found in ``location`` is kept and its inner payload is
    reduced to the fields allowed for that variant.

    Malformed input (a ``location`` or wrapper payload that is not a mapping) is treated
    as invalid and skipped rather than raising, so one bad citation does not abort
    formatting of the whole request.

    Args:
        location: Citation location to format.

    Returns:
        Formatted location with tagged union structure preserved, or empty dict if invalid
        (not a mapping, unrecognized wrapper key, or no allowed inner fields present).

    See:
        https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html
    """
    # Allowed fields for each tagged union type
    allowed_fields = {
        "documentChar": ("documentIndex", "start", "end"),
        "documentPage": ("documentIndex", "start", "end"),
        "documentChunk": ("documentIndex", "start", "end"),
    }

    # Guard against e.g. {"location": None}; a membership test on a non-mapping would raise.
    if not isinstance(location, Mapping):
        logger.debug(
            "location_type=<%s> | citation location is not a mapping, skipping",
            type(location).__name__,
        )
        return {}

    for location_type, fields in allowed_fields.items():
        if location_type not in location:
            continue

        inner = location[location_type]
        # Guard against e.g. {"documentChar": None}; calling .items() on it would raise.
        if not isinstance(inner, Mapping):
            logger.debug(
                "location_type=<%s> | citation location payload is not a mapping, skipping",
                location_type,
            )
            return {}

        filtered = {k: v for k, v in inner.items() if k in fields}
        # An empty wrapper is not a valid union member, so drop the location entirely.
        return {location_type: filtered} if filtered else {}

    # Report the keys actually received so unsupported variants can be diagnosed.
    logger.debug(
        "location_type=<%s> | unrecognized citation location type, skipping",
        ", ".join(map(str, location)),
    )
    return {}
572+
536573
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
537574
"""Check if guardrail data contains any blocked policies.
538575

src/strands/types/citations.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Citation type definitions for the SDK.
22
33
These types are modeled after the Bedrock API.
4+
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html
45
"""
56

67
from typing import List, Union
@@ -18,11 +19,8 @@ class CitationsConfig(TypedDict):
1819
enabled: bool
1920

2021

21-
class DocumentCharLocation(TypedDict, total=False):
22-
"""Specifies a character-level location within a document.
23-
24-
Provides precise positioning information for cited content using
25-
start and end character indices.
22+
class DocumentCharLocationInner(TypedDict, total=False):
23+
"""Inner content for character-level location within a document.
2624
2725
Attributes:
2826
documentIndex: The index of the document within the array of documents
@@ -38,11 +36,8 @@ class DocumentCharLocation(TypedDict, total=False):
3836
end: int
3937

4038

41-
class DocumentChunkLocation(TypedDict, total=False):
42-
"""Specifies a chunk-level location within a document.
43-
44-
Provides positioning information for cited content using logical
45-
document segments or chunks.
39+
class DocumentChunkLocationInner(TypedDict, total=False):
40+
"""Inner content for chunk-level location within a document.
4641
4742
Attributes:
4843
documentIndex: The index of the document within the array of documents
@@ -58,10 +53,8 @@ class DocumentChunkLocation(TypedDict, total=False):
5853
end: int
5954

6055

61-
class DocumentPageLocation(TypedDict, total=False):
62-
"""Specifies a page-level location within a document.
63-
64-
Provides positioning information for cited content using page numbers.
56+
class DocumentPageLocationInner(TypedDict, total=False):
57+
"""Inner content for page-level location within a document.
6558
6659
Attributes:
6760
documentIndex: The index of the document within the array of documents
@@ -77,7 +70,37 @@ class DocumentPageLocation(TypedDict, total=False):
7770
end: int
7871

7972

80-
# Union type for citation locations
73+
class DocumentCharLocation(TypedDict, total=False):
    """Character-level citation location, nested under its wrapper key.

    The location data lives under the ``documentChar`` key, e.g.
    ``{"documentChar": {"documentIndex": 0, ...}}``.

    Attributes:
        documentChar: Inner character-level location data.
    """

    documentChar: DocumentCharLocationInner
81+
82+
83+
class DocumentChunkLocation(TypedDict, total=False):
    """Chunk-level citation location, nested under its wrapper key.

    The location data lives under the ``documentChunk`` key, e.g.
    ``{"documentChunk": {"documentIndex": 1, ...}}``.

    Attributes:
        documentChunk: Inner chunk-level location data.
    """

    documentChunk: DocumentChunkLocationInner
91+
92+
93+
class DocumentPageLocation(TypedDict, total=False):
    """Page-level citation location, nested under its wrapper key.

    The location data lives under the ``documentPage`` key, e.g.
    ``{"documentPage": {"documentIndex": 0, ...}}``.

    Attributes:
        documentPage: Inner page-level location data.
    """

    documentPage: DocumentPageLocationInner
101+
102+
103+
# Union type for citation locations - tagged union where exactly one key is present
81104
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
82105

83106

tests/strands/models/test_bedrock.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,3 +2070,106 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model
20702070
"system": [{"text": system_prompt}],
20712071
}
20722072
bedrock_client.converse_stream.assert_called_once_with(**expected_request)
2073+
2074+
2075+
def test_format_request_message_content_document_char_citation(model):
    """A documentChar location keeps its wrapper key after request formatting."""
    citation = {
        "title": "Doc Citation",
        "location": {"documentChar": {"documentIndex": 0, "start": 100, "end": 200}},
        "sourceContent": [{"text": "Excerpt"}],
    }
    content = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Generated text"}],
        }
    }

    formatted = model._format_request_message_content(content)

    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentChar": {"documentIndex": 0, "start": 100, "end": 200}}
2095+
2096+
2097+
def test_format_request_message_content_document_page_citation(model):
    """A documentPage location keeps its wrapper key after request formatting."""
    citation = {
        "title": "Page Citation",
        "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}},
        "sourceContent": [{"text": "Page content"}],
    }
    content = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Generated text"}],
        }
    }

    formatted = model._format_request_message_content(content)

    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}
2117+
2118+
2119+
def test_format_request_message_content_document_chunk_citation(model):
    """A documentChunk location keeps its wrapper key after request formatting."""
    citation = {
        "title": "Chunk Citation",
        "location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 10}},
        "sourceContent": [{"text": "Chunk content"}],
    }
    content = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Generated text"}],
        }
    }

    formatted = model._format_request_message_content(content)

    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentChunk": {"documentIndex": 1, "start": 5, "end": 10}}
2139+
2140+
2141+
def test_format_request_message_content_citation_filters_extra_fields(model):
    """Unsupported keys inside the location payload are dropped by request formatting."""
    location_with_extra = {
        "documentChar": {
            "documentIndex": 0,
            "start": 0,
            "end": 50,
            "extraField": "ignored",
        }
    }
    citation = {
        "title": "Citation with extra fields",
        "location": location_with_extra,
        "sourceContent": [{"text": "Content"}],
    }
    content = {
        "citationsContent": {
            "citations": [citation],
            "content": [{"text": "Text"}],
        }
    }

    formatted = model._format_request_message_content(content)

    # Only the supported inner fields survive; extraField must be gone.
    formatted_location = formatted["citationsContent"]["citations"][0]["location"]
    assert formatted_location == {"documentChar": {"documentIndex": 0, "start": 0, "end": 50}}
2162+
2163+
2164+
def test_format_request_message_content_citation_unknown_location_type(model):
    """A citation whose location uses an unrecognized wrapper key is emitted without a location."""
    unknown_citation = {
        "title": "Unknown location",
        "location": {"unknownType": {"field": "value"}},
    }
    content = {
        "citationsContent": {
            "citations": [unknown_citation],
            "content": [{"text": "Text"}],
        }
    }

    formatted = model._format_request_message_content(content)

    first_citation = formatted["citationsContent"]["citations"][0]
    assert "location" not in first_citation

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def test_stop_closes_event_loop():
533533
mock_thread.join = MagicMock()
534534
mock_event_loop = MagicMock()
535535
mock_event_loop.close = MagicMock()
536-
536+
537537
client._background_thread = mock_thread
538538
client._background_thread_event_loop = mock_event_loop
539539

@@ -542,7 +542,7 @@ def test_stop_closes_event_loop():
542542

543543
# Verify thread was joined
544544
mock_thread.join.assert_called_once()
545-
545+
546546
# Verify event loop was closed
547547
mock_event_loop.close.assert_called_once()
548548

0 commit comments

Comments
 (0)