Skip to content

Commit 16de31e

Browse files
committed
update anthropic and bedrock to use new typed dicts
1 parent 1e88c56 commit 16de31e

File tree

7 files changed

+84
-84
lines changed

7 files changed

+84
-84
lines changed

ddtrace/contrib/internal/anthropic/utils.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

ddtrace/llmobs/_integrations/anthropic.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ddtrace.llmobs._integrations.base import BaseLLMIntegration
2525
from ddtrace.llmobs._integrations.utils import update_proxy_workflow_input_output_value
2626
from ddtrace.llmobs._utils import _get_attr
27-
from ddtrace.llmobs.utils import ToolCall
27+
from ddtrace.llmobs.utils import Message, ToolCall
2828
from ddtrace.llmobs.utils import ToolDefinition
2929
from ddtrace.llmobs.utils import ToolResult
3030
from ddtrace.trace import Span
@@ -71,7 +71,7 @@ def _llmobs_set_tags(
7171
system_prompt = kwargs.get("system")
7272
input_messages = self._extract_input_message(messages, system_prompt)
7373

74-
output_messages = [{"content": ""}]
74+
output_messages: List[Message] = [Message(content="")]
7575
if not span.error and response is not None:
7676
output_messages = self._extract_output_message(response)
7777
span_kind = "workflow" if span._get_ctx_item(PROXY_REQUEST) else "llm"
@@ -92,14 +92,14 @@ def _llmobs_set_tags(
9292
)
9393
update_proxy_workflow_input_output_value(span, span_kind)
9494

95-
def _extract_input_message(self, messages, system_prompt: Optional[Union[str, List[Dict[str, Any]]]] = None):
95+
def _extract_input_message(self, messages, system_prompt: Optional[Union[str, List[Dict[str, Any]]]] = None) -> List[Message]:
9696
"""Extract input messages from the stored prompt.
9797
Anthropic allows for messages and multiple texts in a message, which requires some special casing.
9898
"""
9999
if not isinstance(messages, Iterable):
100100
log.warning("Anthropic input must be a list of messages.")
101101

102-
input_messages = []
102+
input_messages: List[Message] = []
103103
if system_prompt is not None:
104104
messages = [{"content": system_prompt, "role": "system"}] + messages
105105

@@ -115,43 +115,43 @@ def _extract_input_message(self, messages, system_prompt: Optional[Union[str, Li
115115
log.warning("Anthropic input message must have content and role.")
116116

117117
if isinstance(content, str):
118-
input_messages.append({"content": content, "role": role})
118+
input_messages.append(Message(content=content, role=str(role)))
119119

120120
elif isinstance(content, list):
121121
for block in content:
122122
if _get_attr(block, "type", None) == "text":
123-
input_messages.append({"content": _get_attr(block, "text", ""), "role": role})
123+
input_messages.append(Message(content=_get_attr(block, "text", ""), role=str(role)))
124124

125125
elif _get_attr(block, "type", None) == "image":
126126
# Store a placeholder for potentially enormous binary image data.
127-
input_messages.append({"content": "([IMAGE DETECTED])", "role": role})
127+
input_messages.append(Message(content="([IMAGE DETECTED])", role=str(role)))
128128

129129
elif _get_attr(block, "type", None) == "tool_use":
130130
text = _get_attr(block, "text", None)
131131
input_data = _get_attr(block, "input", "")
132132
if isinstance(input_data, str):
133133
input_data = json.loads(input_data)
134134
tool_call_info = ToolCall(
135-
name=_get_attr(block, "name", ""),
135+
name=str(_get_attr(block, "name", "")),
136136
arguments=input_data,
137-
tool_id=_get_attr(block, "id", ""),
138-
type=_get_attr(block, "type", ""),
137+
tool_id=str(_get_attr(block, "id", "")),
138+
type=str(_get_attr(block, "type", "")),
139139
)
140140
if text is None:
141141
text = ""
142-
input_messages.append({"content": text, "role": role, "tool_calls": [tool_call_info]})
142+
input_messages.append(Message(content=str(text), role=str(role), tool_calls=[tool_call_info]))
143143

144144
elif _get_attr(block, "type", None) == "tool_result":
145145
content = _get_attr(block, "content", None)
146146
formatted_content = self._format_tool_result_content(content)
147147
tool_result_info = ToolResult(
148148
result=formatted_content,
149-
tool_id=_get_attr(block, "tool_use_id", ""),
149+
tool_id=str(_get_attr(block, "tool_use_id", "")),
150150
type="tool_result",
151151
)
152-
input_messages.append({"content": "", "role": role, "tool_results": [tool_result_info]})
152+
input_messages.append(Message(content="", role=str(role), tool_results=[tool_result_info]))
153153
else:
154-
input_messages.append({"content": str(block), "role": role})
154+
input_messages.append(Message(content=str(block), role=str(role)))
155155

156156
return input_messages
157157

@@ -169,34 +169,34 @@ def _format_tool_result_content(self, content) -> str:
169169
return ",".join(formatted_content)
170170
return str(content)
171171

172-
def _extract_output_message(self, response):
172+
def _extract_output_message(self, response) -> List[Message]:
173173
"""Extract output messages from the stored response."""
174-
output_messages = []
174+
output_messages: List[Message] = []
175175
content = _get_attr(response, "content", "")
176176
role = _get_attr(response, "role", "")
177177

178178
if isinstance(content, str):
179-
return [{"content": content, "role": role}]
179+
return [Message(content=content, role=str(role))]
180180

181181
elif isinstance(content, list):
182182
for completion in content:
183183
text = _get_attr(completion, "text", None)
184184
if isinstance(text, str):
185-
output_messages.append({"content": text, "role": role})
185+
output_messages.append(Message(content=text, role=str(role)))
186186
else:
187187
if _get_attr(completion, "type", None) == "tool_use":
188188
input_data = _get_attr(completion, "input", "")
189189
if isinstance(input_data, str):
190190
input_data = json.loads(input_data)
191191
tool_call_info = ToolCall(
192-
name=_get_attr(completion, "name", ""),
192+
name=str(_get_attr(completion, "name", "")),
193193
arguments=input_data,
194-
tool_id=_get_attr(completion, "id", ""),
195-
type=_get_attr(completion, "type", ""),
194+
tool_id=str(_get_attr(completion, "id", "")),
195+
type=str(_get_attr(completion, "type", "")),
196196
)
197197
if text is None:
198198
text = ""
199-
output_messages.append({"content": text, "role": role, "tool_calls": [tool_call_info]})
199+
output_messages.append(Message(content=str(text), role=str(role), tool_calls=[tool_call_info]))
200200
return output_messages
201201

202202
def _extract_usage(self, span: Span, usage: Dict[str, Any]):

ddtrace/llmobs/_integrations/bedrock.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from ddtrace.llmobs._telemetry import record_bedrock_agent_span_event_created
3636
from ddtrace.llmobs._utils import _get_attr
3737
from ddtrace.llmobs._writer import LLMObsSpanEvent
38-
from ddtrace.llmobs.utils import ToolDefinition
38+
from ddtrace.llmobs.utils import Message, ToolDefinition
3939
from ddtrace.trace import Span
4040

4141

@@ -191,7 +191,7 @@ def translate_bedrock_traces(self, traces, root_span) -> None:
191191
self._active_span_by_step_id.clear()
192192

193193
@staticmethod
194-
def _extract_input_message_for_converse(prompt: List[Dict[str, Any]]):
194+
def _extract_input_message_for_converse(prompt: List[Dict[str, Any]]) -> List[Message]:
195195
"""Extract input messages from the stored prompt for converse
196196
197197
`prompt` is an array of `message` objects. Each `message` has a role and content field.
@@ -203,7 +203,7 @@ def _extract_input_message_for_converse(prompt: List[Dict[str, Any]]):
203203
"""
204204
if not isinstance(prompt, list):
205205
log.warning("Bedrock input is not a list of messages or a string.")
206-
return [{"content": ""}]
206+
return [Message(content="")]
207207
input_messages = []
208208
for message in prompt:
209209
if not isinstance(message, dict):
@@ -342,41 +342,42 @@ def _converse_output_stream_processor() -> (
342342
return messages, metadata, usage_metrics
343343

344344
@staticmethod
345-
def _extract_input_message(prompt):
345+
def _extract_input_message(prompt) -> List[Message]:
346346
"""Extract input messages from the stored prompt.
347347
Anthropic allows for messages and multiple texts in a message, which requires some special casing.
348348
"""
349349
if isinstance(prompt, str):
350-
return [{"content": prompt}]
350+
return [Message(content=prompt)]
351351
if not isinstance(prompt, list):
352352
log.warning("Bedrock input is not a list of messages or a string.")
353-
return [{"content": ""}]
354-
input_messages = []
353+
return [Message(content="")]
354+
input_messages: List[Message] = []
355355
for p in prompt:
356356
content = p.get("content", "")
357357
if isinstance(content, list) and isinstance(content[0], dict):
358358
for entry in content:
359359
if entry.get("type") == "text":
360-
input_messages.append({"content": entry.get("text", ""), "role": str(p.get("role", ""))})
360+
input_messages.append(Message(content=entry.get("text", ""), role=str(p.get("role", ""))))
361361
elif entry.get("type") == "image":
362362
# Store a placeholder for potentially enormous binary image data.
363-
input_messages.append({"content": "([IMAGE DETECTED])", "role": str(p.get("role", ""))})
363+
input_messages.append(Message(content="([IMAGE DETECTED])", role=str(p.get("role", ""))))
364364
else:
365-
input_messages.append({"content": content, "role": str(p.get("role", ""))})
365+
input_messages.append(Message(content=str(content), role=str(p.get("role", ""))))
366366
return input_messages
367367

368368
@staticmethod
369-
def _extract_output_message(response):
369+
def _extract_output_message(response) -> List[Message]:
370370
"""Extract output messages from the stored response.
371371
Anthropic allows for chat messages, which requires some special casing.
372372
"""
373373
if isinstance(response["text"], str):
374-
return [{"content": response["text"]}]
374+
return [Message(content=response["text"])]
375375
if isinstance(response["text"], list):
376376
if isinstance(response["text"][0], str):
377-
return [{"content": str(content)} for content in response["text"]]
377+
return [Message(content=str(content)) for content in response["text"]]
378378
if isinstance(response["text"][0], dict):
379-
return [{"content": response["text"][0].get("text", "")}]
379+
return [Message(content=response["text"][0].get("text", ""))]
380+
return []
380381

381382
def _get_base_url(self, **kwargs: Dict[str, Any]) -> Optional[str]:
382383
instance = kwargs.get("instance")
@@ -396,8 +397,8 @@ def _extract_tool_definitions(self, tool_config: Dict[str, Any]) -> List[ToolDef
396397
for tool in tools:
397398
tool_spec = _get_attr(tool, "toolSpec", {})
398399
tool_definition_info = ToolDefinition(
399-
name=_get_attr(tool_spec, "name", ""),
400-
description=_get_attr(tool_spec, "description", ""),
400+
name=str(_get_attr(tool_spec, "name", "")),
401+
description=str(_get_attr(tool_spec, "description", "")),
401402
schema=_get_attr(tool_spec, "inputSchema", {}),
402403
)
403404
tool_definitions.append(tool_definition_info)

0 commit comments

Comments
 (0)