Skip to content

Adding bedrock chat completion for anthropic models #6170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from ._anthropic_client import AnthropicChatCompletionClient, BaseAnthropicChatCompletionClient
from .config import (
AnthropicBedrockClientConfiguration,
AnthropicBedrockClientConfigurationConfigModel,
AnthropicClientConfiguration,
AnthropicClientConfigurationConfigModel,
BedrockInfo,
CreateArgumentsConfigModel,
)

__all__ = [
"AnthropicChatCompletionClient",
"BaseAnthropicChatCompletionClient",
"AnthropicClientConfiguration",
"AnthropicBedrockClientConfiguration",
"AnthropicClientConfigurationConfigModel",
"AnthropicBedrockClientConfigurationConfigModel",
"CreateArgumentsConfigModel",
"BedrockInfo",
]
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

import tiktoken
from anthropic import AsyncAnthropic, AsyncStream
from anthropic import AnthropicBedrock, AsyncAnthropic, AsyncStream
from anthropic.types import (
Base64ImageSourceParam,
ContentBlock,
Expand Down Expand Up @@ -67,7 +67,13 @@
from typing_extensions import Self, Unpack

from . import _model_info
from .config import AnthropicClientConfiguration, AnthropicClientConfigurationConfigModel
from .config import (
AnthropicBedrockClientConfiguration,
AnthropicBedrockClientConfigurationConfigModel,
AnthropicClientConfiguration,
AnthropicClientConfigurationConfigModel,
BedrockInfo,
)

logger = logging.getLogger(EVENT_LOGGER_NAME)
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
Expand Down Expand Up @@ -410,7 +416,7 @@
class BaseAnthropicChatCompletionClient(ChatCompletionClient):
def __init__(
self,
client: AsyncAnthropic,
client: Any,
*,
create_args: Dict[str, Any],
model_info: Optional[ModelInfo] = None,
Expand Down Expand Up @@ -1109,3 +1115,138 @@
copied_config["api_key"] = config.api_key.get_secret_value()

return cls(**copied_config)


class AnthropicBedrockChatCompletionClient(
BaseAnthropicChatCompletionClient, Component[AnthropicBedrockClientConfigurationConfigModel]
):
"""
Chat completion client for Anthropic's Claude models.

Args:
model (str): The Claude model to use (e.g., "claude-3-sonnet-20240229", "claude-3-opus-20240229")
api_key (str, optional): Anthropic API key. Required if not in environment variables.
base_url (str, optional): Override the default API endpoint.
max_tokens (int, optional): Maximum tokens in the response. Default is 4096.
temperature (float, optional): Controls randomness. Lower is more deterministic. Default is 1.0.
top_p (float, optional): Controls diversity via nucleus sampling. Default is 1.0.
top_k (int, optional): Controls diversity via top-k sampling. Default is -1 (disabled).
model_info (ModelInfo, optional): The capabilities of the model. Required if using a custom model.
bedrock_info (BedrockInfo, optional): The capabilities of the model in bedrock. Required if using a model from AWS bedrock.

To use this client, you must install the Anthropic extension:

.. code-block:: bash

pip install "autogen-ext[anthropic]"

Example:

.. code-block:: python

import asyncio
from autogen_ext.models.anthropic import AnthropicBedrockChatCompletionClient
from autogen_core.models import UserMessage


async def main():
config = {
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"temperature": 0.1,
"model_info": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
},
"bedrock_info": {
"aws_access_key": "<aws_access_key>",
"aws_secret_key": "<aws_secret_key>",
"aws_session_token": "<aws_session_token>",
"aws_region": "<aws_region>",
},
}
anthropic_client = AnthropicBedrockChatCompletionClient(**config)

result = await anthropic_client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore
print(result)


if __name__ == "__main__":
asyncio.run(main())
"""

component_type = "model"
component_config_schema = AnthropicBedrockClientConfigurationConfigModel
component_provider_override = "autogen_ext.models.anthropic.AnthropicChatCompletionClient"

def __init__(self, **kwargs: Unpack[AnthropicBedrockClientConfiguration]):
if "model" not in kwargs:
raise ValueError("model is required for AnthropicChatCompletionClient")

Check warning on line 1185 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1184-L1185

Added lines #L1184 - L1185 were not covered by tests

self._raw_config: Dict[str, Any] = dict(kwargs).copy()
copied_args = dict(kwargs).copy()

Check warning on line 1188 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1187-L1188

Added lines #L1187 - L1188 were not covered by tests

model_info: Optional[ModelInfo] = None
if "model_info" in kwargs:
model_info = kwargs["model_info"]
del copied_args["model_info"]

Check warning on line 1193 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1190-L1193

Added lines #L1190 - L1193 were not covered by tests

bedrock_info: Optional[BedrockInfo] = None
if "bedrock_info" in kwargs:
bedrock_info = kwargs["bedrock_info"]

Check warning on line 1197 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1195-L1197

Added lines #L1195 - L1197 were not covered by tests

if bedrock_info is None:
raise ValueError("bedrock_info is required for AnthropicBedrockChatCompletionClient")

Check warning on line 1200 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1199-L1200

Added lines #L1199 - L1200 were not covered by tests

# Handle bedrock_info as secretestr
aws_region = bedrock_info["aws_region"]
aws_access_key = bedrock_info["aws_access_key"].get_secret_value()
aws_secret_key = bedrock_info["aws_secret_key"].get_secret_value()
aws_session_token = bedrock_info["aws_session_token"].get_secret_value()

Check warning on line 1206 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1203-L1206

Added lines #L1203 - L1206 were not covered by tests

client = AnthropicBedrock(

Check warning on line 1208 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1208

Added line #L1208 was not covered by tests
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
aws_region=aws_region,
)
create_args = _create_args_from_config(copied_args)

Check warning on line 1214 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1214

Added line #L1214 was not covered by tests

super().__init__(

Check warning on line 1216 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1216

Added line #L1216 was not covered by tests
client=client,
create_args=create_args,
model_info=model_info,
)

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_client"] = None
return state

Check warning on line 1225 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1223-L1225

Added lines #L1223 - L1225 were not covered by tests

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
self._client = _anthropic_client_from_config(state["_raw_config"])

Check warning on line 1229 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1228-L1229

Added lines #L1228 - L1229 were not covered by tests

def _to_config(self) -> AnthropicBedrockClientConfigurationConfigModel:
copied_config = self._raw_config.copy()
return AnthropicBedrockClientConfigurationConfigModel(**copied_config)

Check warning on line 1233 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1232-L1233

Added lines #L1232 - L1233 were not covered by tests

@classmethod
def _from_config(cls, config: AnthropicBedrockClientConfigurationConfigModel) -> Self:
copied_config = config.model_copy().model_dump(exclude_none=True)

Check warning on line 1237 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1237

Added line #L1237 was not covered by tests

# Handle api_key as SecretStr
if "api_key" in copied_config and isinstance(config.api_key, SecretStr):
copied_config["api_key"] = config.api_key.get_secret_value()

Check warning on line 1241 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1240-L1241

Added lines #L1240 - L1241 were not covered by tests

# Handle bedrock_info as SecretStr
if "bedrock_info" in copied_config and isinstance(config.bedrock_info, dict):
copied_config["bedrock_info"] = {

Check warning on line 1245 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1244-L1245

Added lines #L1244 - L1245 were not covered by tests
"aws_access_key": config.bedrock_info["aws_access_key"].get_secret_value(),
"aws_secret_key": config.bedrock_info["aws_secret_key"].get_secret_value(),
"aws_session_token": config.bedrock_info["aws_session_token"].get_secret_value(),
"aws_region": config.bedrock_info["aws_region"],
}

return cls(**copied_config)

Check warning on line 1252 in python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py#L1252

Added line #L1252 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from autogen_core.models import ModelCapabilities, ModelInfo # type: ignore
from pydantic import BaseModel, SecretStr
from typing_extensions import TypedDict
from typing_extensions import Required, TypedDict


class ResponseFormat(TypedDict):
Expand All @@ -20,6 +20,22 @@ class CreateArguments(TypedDict, total=False):
metadata: Optional[Dict[str, str]]


class BedrockInfo(TypedDict):
"""BedrockInfo is a dictionary that contains information about a bedrock's properties.
It is expected to be used in the bedrock_info property of a model client.

"""

aws_access_key: Required[SecretStr]
"""Access key for the aws account to gain bedrock model access"""
aws_secret_key: Required[SecretStr]
"""Access secret key for the aws account to gain bedrock model access"""
aws_session_token: Required[SecretStr]
"""aws session token for the aws account to gain bedrock model access"""
aws_region: Required[str]
"""aws region for the aws account to gain bedrock model access"""


class BaseAnthropicClientConfiguration(CreateArguments, total=False):
api_key: str
base_url: Optional[str]
Expand All @@ -36,6 +52,10 @@ class AnthropicClientConfiguration(BaseAnthropicClientConfiguration, total=False
tool_choice: Optional[Union[Literal["auto", "any", "none"], Dict[str, Any]]]


class AnthropicBedrockClientConfiguration(AnthropicClientConfiguration, total=False):
bedrock_info: BedrockInfo


# Pydantic equivalents of the above TypedDicts
class CreateArgumentsConfigModel(BaseModel):
model: str
Expand All @@ -61,3 +81,7 @@ class BaseAnthropicClientConfigurationConfigModel(CreateArgumentsConfigModel):
class AnthropicClientConfigurationConfigModel(BaseAnthropicClientConfigurationConfigModel):
tools: List[Dict[str, Any]] | None = None
tool_choice: Union[Literal["auto", "any", "none"], Dict[str, Any]] | None = None


class AnthropicBedrockClientConfigurationConfigModel(AnthropicClientConfigurationConfigModel):
bedrock_info: BedrockInfo | None = None
Loading